00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX__MATRIX_STRASSEN__HPP
00014 #define __MMX__MATRIX_STRASSEN__HPP
00015 #include <algebramix/matrix_unrolled.hpp>
00016
00017 namespace mmx {
00018
00019
00020
00021
00022
// Matrix variant tag that layers Strassen multiplication over the variant V.
// It inherits everything from V, re-exports V's Vec, Naive and Positive
// sub-variants unchanged, and re-wraps the No_simd / No_thread / No_scaled
// sub-variants so that they remain matrix_strassen variants.
template<typename V>
struct matrix_strassen: V {
  // Sub-variants forwarded verbatim from the underlying variant.
  typedef typename V::Vec      Vec;
  typedef typename V::Naive    Naive;
  typedef typename V::Positive Positive;

  // Sub-variants that stay wrapped in matrix_strassen.
  typedef matrix_strassen<typename V::No_simd>   No_simd;
  typedef matrix_strassen<typename V::No_thread> No_thread;
  typedef matrix_strassen<typename V::No_scaled> No_scaled;
};
00032
00033 template<typename F, typename V, typename W>
00034 struct implementation<F,V,matrix_strassen<W> >:
00035 public implementation<F,V,W> {};
00036
00037
00038
00039
00040
00041 template<typename C, typename V>
00042 struct threshold_helper<C, matrix_multiply_threshold<matrix_strassen<V> > > {
00043 typedef fixed_value<nat,128> impl;
00044 };
00045
00046 template<typename V, typename W>
00047 struct implementation<matrix_multiply_base,V,matrix_strassen<W> >:
00048 public implementation<matrix_linear,V>
00049 {
00050 typedef implementation<vector_linear,V> Vec;
00051 typedef implementation<matrix_multiply,W> Mat;
00052 typedef matrix_multiply_threshold<matrix_strassen<W> > Thr;
00053
00054 template<typename C> static inline void
00055 mat_load (C* d, nat r, nat c, const C* s, nat rr, nat cc) {
00056 nat j= c;
00057 for (; j!=0; j--, d += Mat::index(0,1,r,c), s += Mat::index(0,2,rr,cc)) {
00058 nat i = r;
00059 C* dd= d;
00060 const C* ss= s;
00061 for (; i!=0; i--, dd += Mat::index(1,0,r,c), ss += Mat::index(2,0,rr,cc))
00062 *dd= *ss;
00063 }
00064 }
00065
00066 template<typename Op, typename C> static inline void
00067 mat_save (C* d, nat rr, nat cc, const C* s, nat r, nat c) {
00068 typedef typename Op::nomul_op Set;
00069 nat j= c;
00070 for (; j!=0; j--, d += Mat::index(0,2,rr,cc), s += Mat::index(0,1,r,c)) {
00071 nat i = r;
00072 C* dd= d;
00073 const C* ss= s;
00074 for (; i!=0; i--, dd += Mat::index(2,0,rr,cc), ss += Mat::index(1,0,r,c))
00075 Set::set_op (*dd, *ss);
00076 }
00077 }
00078
00079 template<typename Op, typename D, typename S1, typename S2> static void
00080 mul (D* d, const S1* a, const S2* b,
00081 nat r, nat rr, nat l, nat ll, nat c, nat cc)
00082 {
00083 static const nat thr= Threshold(D,Thr);
00084 if (r < thr || l < thr || c < thr) {
00085 Mat::template mul<Op> (d, a, b, r, rr, l, ll, c, cc);
00086 return;
00087 }
00088
00089 nat hr= r>>1, hl= l>>1, hc= c>>1, fr= hr<<1, fl= hl<<1, fc= hc<<1;
00090
00091 nat sza= aligned_size<S1,W> (hr * hl);
00092 S1* a11= mmx_new<S1> (5 * sza);
00093 S1* a12= a11 + sza;
00094 S1* a21= a12 + sza;
00095 S1* a22= a21 + sza;
00096 S1* aaa= a22 + sza;
00097
00098 nat szb= aligned_size<S2,W> (hl * hc);
00099 S2* b11= mmx_new<S2> (5 * szb);
00100 S2* b12= b11 + szb;
00101 S2* b21= b12 + szb;
00102 S2* b22= b21 + szb;
00103 S2* bbb= b22 + szb;
00104
00105 nat szd= aligned_size<D,W> (hr * hc);
00106 D* m1 = mmx_new<D> (11 * szd);
00107 D* m2 = m1 + szd;
00108 D* m3 = m2 + szd;
00109 D* m4 = m3 + szd;
00110 D* m5 = m4 + szd;
00111 D* m6 = m5 + szd;
00112 D* m7 = m6 + szd;
00113 D* d11= m7 + szd;
00114 D* d12= d11 + szd;
00115 D* d21= d12 + szd;
00116 D* d22= d21 + szd;
00117
00118 mat_load (a11, hr, hl, a , rr, ll);
00119 mat_load (a12, hr, hl, a + Mat::index (0, 1, rr, ll), rr, ll);
00120 mat_load (a21, hr, hl, a + Mat::index (1, 0, rr, ll), rr, ll);
00121 mat_load (a22, hr, hl, a + Mat::index (1, 1, rr, ll), rr, ll);
00122 mat_load (b11, hr, hl, b , rr, ll);
00123 mat_load (b12, hr, hl, b + Mat::index (0, 1, rr, ll), rr, ll);
00124 mat_load (b21, hr, hl, b + Mat::index (1, 0, rr, ll), rr, ll);
00125 mat_load (b22, hr, hl, b + Mat::index (1, 1, rr, ll), rr, ll);
00126
00127 Vec::add (aaa, a11, a22, hr * hl);
00128 Vec::add (bbb, b11, b22, hl * hc);
00129 mul<mul_op> (m1, aaa, bbb, hr, hr, hl, hl, hc, hc);
00130 Vec::add (aaa, a21, a22, hr * hl);
00131 mul<mul_op> (m2, aaa, b11, hr, hr, hl, hl, hc, hc);
00132 Vec::sub (bbb, b12, b22, hl * hc);
00133 mul<mul_op> (m3, a11, bbb, hr, hr, hl, hl, hc, hc);
00134 Vec::sub (bbb, b21, b11, hl * hc);
00135 mul<mul_op> (m4, a22, bbb, hr, hr, hl, hl, hc, hc);
00136 Vec::add (aaa, a11, a12, hr * hl);
00137 mul<mul_op> (m5, aaa, b22, hr, hr, hl, hl, hc, hc);
00138 Vec::sub (aaa, a21, a11, hr * hl);
00139 Vec::add (bbb, b11, b12, hl * hc);
00140 mul<mul_op> (m6, aaa, bbb, hr, hr, hl, hl, hc, hc);
00141 Vec::sub (aaa, a12, a22, hr * hl);
00142 Vec::add (bbb, b21, b22, hl * hc);
00143 mul<mul_op> (m7, aaa, bbb, hr, hr, hl, hl, hc, hc);
00144
00145 Vec::add (d11, m1, m4, hr * hc);
00146 Vec::sub (d11, m5, hr * hc);
00147 Vec::add (d11, m7, hr * hc);
00148 Vec::add (d12, m3, m5, hr * hc);
00149 Vec::add (d21, m2, m4, hr * hc);
00150 Vec::sub (d22, m1, m2, hr * hc);
00151 Vec::add (d22, m3, hr * hc);
00152 Vec::add (d22, m6, hr * hc);
00153
00154 mat_save<Op> (d , rr, cc, d11, hr, hc);
00155 mat_save<Op> (d + Mat::index (0, 1, rr, cc), rr, ll, d12, hr, hc);
00156 mat_save<Op> (d + Mat::index (1, 0, rr, cc), rr, ll, d21, hr, hc);
00157 mat_save<Op> (d + Mat::index (1, 1, rr, cc), rr, ll, d22, hr, hc);
00158
00159 mmx_delete<S1> (a11, 5 * sza);
00160 mmx_delete<S2> (b11, 5 * szb);
00161 mmx_delete<D > (m1 , 11 * szd);
00162
00163 mul_complete<Op,W> (d, a, b, r, rr, l, ll, c, cc, fr, fl, fc);
00164 }
00165
00166 template<typename D, typename S1, typename S2> static inline void
00167 mul (D* dest, const S1* m1, const S2* m2, nat r, nat l, nat c) {
00168 mul<mul_op> (dest, m1, m2, r, r, l, l, c, c);
00169 }
00170
00171 };
00172
00173 }
00174 #endif //__MMX__MATRIX_STRASSEN__HPP