00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX__MATRIX_STRASSEN__HPP
00014 #define __MMX__MATRIX_STRASSEN__HPP
00015 #include <algebramix/matrix_unrolled.hpp>
00016
00017 namespace mmx {
00018
00019
00020
00021
00022
// Strassen-accelerated matrix variant wrapping an underlying variant V.
// Vec, Naive and Positive are re-exported from V unchanged, while the
// feature-stripped variants (No_simd, No_thread, No_scaled) remain wrapped
// in matrix_strassen around the corresponding variants of V.
template<typename V>
struct matrix_strassen : V {
  // Sub-variants taken verbatim from V.
  typedef typename V::Positive Positive;
  typedef typename V::Naive    Naive;
  typedef typename V::Vec      Vec;
  // Feature-stripped variants keep the Strassen wrapper.
  typedef matrix_strassen<typename V::No_scaled> No_scaled;
  typedef matrix_strassen<typename V::No_thread> No_thread;
  typedef matrix_strassen<typename V::No_simd>   No_simd;
};
00032
00033 template<typename F, typename V, typename W>
00034 struct implementation<F,V,matrix_strassen<W> >:
00035 public implementation<F,V,W> {};
00036
00037
00038
00039
00040
00041 template<typename V, typename W>
00042 struct implementation<matrix_multiply_base,V,matrix_strassen<W> >:
00043 public implementation<matrix_linear,V>
00044 {
00045 static const nat thr= 128;
00046 typedef implementation<vector_linear,W> Vec;
00047 typedef implementation<matrix_multiply,W> Mat;
00048
00049 template<typename C> static inline void
00050 mat_load (C* d, nat r, nat c, const C* s, nat rr, nat cc) {
00051 nat j= c;
00052 for (; j!=0; j--, d += Mat::index(0,1,r,c), s += Mat::index(0,2,rr,cc)) {
00053 nat i = r;
00054 C* dd= d;
00055 const C* ss= s;
00056 for (; i!=0; i--, dd += Mat::index(1,0,r,c), ss += Mat::index(2,0,rr,cc))
00057 *dd= *ss;
00058 }
00059 }
00060
00061 template<typename Op, typename C> static inline void
00062 mat_save (C* d, nat rr, nat cc, const C* s, nat r, nat c) {
00063 typedef typename Op::nomul_op Set;
00064 nat j= c;
00065 for (; j!=0; j--, d += Mat::index(0,2,rr,cc), s += Mat::index(0,1,r,c)) {
00066 nat i = r;
00067 C* dd= d;
00068 const C* ss= s;
00069 for (; i!=0; i--, dd += Mat::index(2,0,rr,cc), ss += Mat::index(1,0,r,c))
00070 Set::set_op (*dd, *ss);
00071 }
00072 }
00073
00074 template<typename Op, typename D, typename S1, typename S2> static void
00075 mul (D* d, const S1* a, const S2* b,
00076 nat r, nat rr, nat l, nat ll, nat c, nat cc)
00077 {
00078 if (r < thr || l < thr || c < thr) {
00079 Mat::template mul<Op> (d, a, b, r, rr, l, ll, c, cc);
00080 return;
00081 }
00082
00083 nat hr= r>>1, hl= l>>1, hc= c>>1, fr= hr<<1, fl= hl<<1, fc= hc<<1;
00084
00085 nat sza= aligned_size<S1,W> (hr * hl);
00086 S1* a11= mmx_new<S1> (5 * sza);
00087 S1* a12= a11 + sza;
00088 S1* a21= a12 + sza;
00089 S1* a22= a21 + sza;
00090 S1* aaa= a22 + sza;
00091
00092 nat szb= aligned_size<S2,W> (hl * hc);
00093 S2* b11= mmx_new<S2> (5 * szb);
00094 S2* b12= b11 + szb;
00095 S2* b21= b12 + szb;
00096 S2* b22= b21 + szb;
00097 S2* bbb= b22 + szb;
00098
00099 nat szd= aligned_size<D,W> (hr * hc);
00100 D* m1 = mmx_new<D> (11 * szd);
00101 D* m2 = m1 + szd;
00102 D* m3 = m2 + szd;
00103 D* m4 = m3 + szd;
00104 D* m5 = m4 + szd;
00105 D* m6 = m5 + szd;
00106 D* m7 = m6 + szd;
00107 D* d11= m7 + szd;
00108 D* d12= d11 + szd;
00109 D* d21= d12 + szd;
00110 D* d22= d21 + szd;
00111
00112 mat_load (a11, hr, hl, a , rr, ll);
00113 mat_load (a12, hr, hl, a + Mat::index (0, 1, rr, ll), rr, ll);
00114 mat_load (a21, hr, hl, a + Mat::index (1, 0, rr, ll), rr, ll);
00115 mat_load (a22, hr, hl, a + Mat::index (1, 1, rr, ll), rr, ll);
00116 mat_load (b11, hr, hl, b , rr, ll);
00117 mat_load (b12, hr, hl, b + Mat::index (0, 1, rr, ll), rr, ll);
00118 mat_load (b21, hr, hl, b + Mat::index (1, 0, rr, ll), rr, ll);
00119 mat_load (b22, hr, hl, b + Mat::index (1, 1, rr, ll), rr, ll);
00120
00121 Vec::add (aaa, a11, a22, hr * hl);
00122 Vec::add (bbb, b11, b22, hl * hc);
00123 mul<mul_op> (m1, aaa, bbb, hr, hr, hl, hl, hc, hc);
00124 Vec::add (aaa, a21, a22, hr * hl);
00125 mul<mul_op> (m2, aaa, b11, hr, hr, hl, hl, hc, hc);
00126 Vec::sub (bbb, b12, b22, hl * hc);
00127 mul<mul_op> (m3, a11, bbb, hr, hr, hl, hl, hc, hc);
00128 Vec::sub (bbb, b21, b11, hl * hc);
00129 mul<mul_op> (m4, a22, bbb, hr, hr, hl, hl, hc, hc);
00130 Vec::add (aaa, a11, a12, hr * hl);
00131 mul<mul_op> (m5, aaa, b22, hr, hr, hl, hl, hc, hc);
00132 Vec::sub (aaa, a21, a11, hr * hl);
00133 Vec::add (bbb, b11, b12, hl * hc);
00134 mul<mul_op> (m6, aaa, bbb, hr, hr, hl, hl, hc, hc);
00135 Vec::sub (aaa, a12, a22, hr * hl);
00136 Vec::add (bbb, b21, b22, hl * hc);
00137 mul<mul_op> (m7, aaa, bbb, hr, hr, hl, hl, hc, hc);
00138
00139 Vec::add (d11, m1, m4, hr * hc);
00140 Vec::sub (d11, m5, hr * hc);
00141 Vec::add (d11, m7, hr * hc);
00142 Vec::add (d12, m3, m5, hr * hc);
00143 Vec::add (d21, m2, m4, hr * hc);
00144 Vec::sub (d22, m1, m2, hr * hc);
00145 Vec::add (d22, m3, hr * hc);
00146 Vec::add (d22, m6, hr * hc);
00147
00148 mat_save<Op> (d , rr, cc, d11, hr, hc);
00149 mat_save<Op> (d + Mat::index (0, 1, rr, cc), rr, ll, d12, hr, hc);
00150 mat_save<Op> (d + Mat::index (1, 0, rr, cc), rr, ll, d21, hr, hc);
00151 mat_save<Op> (d + Mat::index (1, 1, rr, cc), rr, ll, d22, hr, hc);
00152
00153 mmx_delete<S1> (a11, 5 * sza);
00154 mmx_delete<S2> (b11, 5 * szb);
00155 mmx_delete<D > (m1 , 11 * szd);
00156
00157 mul_complete<Op,W> (d, a, b, r, rr, l, ll, c, cc, fr, fl, fc);
00158 }
00159 };
00160
00161 }
00162 #endif //__MMX__MATRIX_STRASSEN__HPP