00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX__MATRIX_UNROLLED__HPP
00014 #define __MMX__MATRIX_UNROLLED__HPP
00015 #include <algebramix/vector_unrolled.hpp>
00016 #include <algebramix/matrix_fixed.hpp>
00017
00018 namespace mmx {
00019
00020
00021
00022
00023
00024 template<nat sz, typename V=matrix_naive>
00025 struct matrix_unrolled: public V {
00026 typedef vector_unrolled<sz,typename V::Vec> Vec;
00027 typedef typename V::Naive Naive;
00028 typedef matrix_unrolled<sz,typename V::Positive> Positive;
00029 typedef matrix_unrolled<sz,typename V::No_aligned> No_aligned;
00030 typedef matrix_unrolled<sz,typename V::No_simd> No_simd;
00031 typedef matrix_unrolled<sz,typename V::No_thread> No_thread;
00032 typedef matrix_unrolled<sz,typename V::No_scaled> No_scaled;
00033 };
00034
00035 template<nat sz, typename F, typename V, typename W>
00036 struct implementation<F,V,matrix_unrolled<sz,W> >:
00037 public implementation<F,V,W> {};
00038
00039
00040
00041
00042
00043 template<typename Op, typename V, typename D, typename S1, typename S2> void
00044 mul_complete (D* dest, const S1* src1, const S2* src2,
00045 nat r, nat rr, nat l, nat ll, nat c, nat cc,
00046 nat hr, nat hl, nat hc)
00047 {
00048 typedef implementation<matrix_multiply,V> Mat;
00049 typedef typename Op::acc_op Acc;
00050 if (hr < r && hl != 0 && hc != 0)
00051 Mat::template mul<Op > (dest + Mat::index (hr, 0, rr, cc),
00052 src1 + Mat::index (hr, 0, rr, ll),
00053 src2,
00054 r-hr, rr, hl, ll, hc, cc);
00055 if (hc < c && hl != 0)
00056 Mat::template mul<Op > (dest + Mat::index (0, hc, rr, cc),
00057 src1,
00058 src2 + Mat::index (0, hc, ll, cc),
00059 r , rr, hl, ll, c-hc, cc);
00060 if (hl < l)
00061 Mat::template mul<Acc> (dest,
00062 src1 + Mat::index (0, hl, rr, ll),
00063 src2 + Mat::index (hl, 0, ll, cc),
00064 r , rr, l-hl, ll, c , cc);
00065 }
00066
00067
00068
00069
00070
00071 template<typename Op, nat ur, nat ul, nat uc, typename V,
00072 typename D, typename S1, typename S2> void
00073 mul_unrolled (D* dest, const S1* src1, const S2* src2,
00074 nat r, nat rr, nat l, nat ll, nat c, nat cc)
00075 {
00076 typedef implementation<matrix_multiply,V> Mat;
00077 typedef implementation<matrix_multiply_base,matrix_naive> NMat;
00078 typedef typename Op::acc_op Acc;
00079 nat nr= r/ur, nl= l/ul, nc= c/uc;
00080 if (nl == 0)
00081 NMat::template clr<Op> (dest, r, rr, c, cc);
00082 else
00083 for (nat ir=0; ir<nr; ir++)
00084 for (nat ic=0; ic<nc; ic++) {
00085 nat il=0;
00086 for (; il<1; il++)
00087 matrix_multiply_helper<Op,D,S1,S2,ur,ul,uc>::
00088 mul_stride (dest + Mat::index (ir*ur, ic*uc, rr, cc),
00089 src1 + Mat::index (ir*ur, il*ul, rr, ll),
00090 src2 + Mat::index (il*ul, ic*uc, ll, cc),
00091 rr, ll);
00092 for (; il<nl; il++)
00093 matrix_multiply_helper<Acc,D,S1,S2,ur,ul,uc>::
00094 mul_stride (dest + Mat::index (ir*ur, ic*uc, rr, cc),
00095 src1 + Mat::index (ir*ur, il*ul, rr, ll),
00096 src2 + Mat::index (il*ul, ic*uc, ll, cc),
00097 rr, ll);
00098 }
00099 mul_complete<Op,V> (dest, src1, src2, r, rr, l, ll, c, cc,
00100 ur*nr, ul*nl, uc*nc);
00101 }
00102
00103
00104
00105
00106
00107 template<nat sz, typename V, typename W>
00108 struct implementation<matrix_multiply_base,V,matrix_unrolled<sz,W> >:
00109 public implementation<matrix_linear,V>
00110 {
00111 const static nat ur= sz;
00112 const static nat ul= sz;
00113 const static nat uc= sz;
00114
00115 template<typename Op, typename D, typename S1, typename S2>
00116 static inline void
00117 mul (D* dest, const S1* src1, const S2* src2,
00118 nat r, nat rr, nat l, nat ll, nat c, nat cc)
00119 {
00120 mul_unrolled<Op,ur,ul,uc,W> (dest, src1, src2, r, rr, l, ll, c, cc);
00121 }
00122 };
00123
00124 }
00125 #endif //__MMX__MATRIX_UNROLLED__HPP