00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX__MATRIX_THREADS__HPP
00014 #define __MMX__MATRIX_THREADS__HPP
00015 #include <algebramix/matrix_unrolled.hpp>
00016 #include <basix/threads.hpp>
00017
00018 namespace mmx {
00019
00020
00021
00022
00023
// Variant wrapper around an underlying variant V. It inherits everything
// from V and redefines the nested variant names: Vec, Naive and No_thread
// are taken from V unchanged, whereas Positive, No_simd and No_scaled are
// the corresponding variants of V wrapped again in matrix_threads.
template<typename V>
struct matrix_threads: V {
  // Variants forwarded from V as they are.
  typedef typename V::Vec       Vec;
  typedef typename V::Naive     Naive;
  typedef typename V::No_thread No_thread;

  // Variants of V that keep the matrix_threads wrapper.
  typedef matrix_threads<typename V::No_scaled> No_scaled;
  typedef matrix_threads<typename V::No_simd>   No_simd;
  typedef matrix_threads<typename V::Positive>  Positive;
};
00033
00034 template<typename F, typename V, typename W>
00035 struct implementation<F,V,matrix_threads<W> >:
00036 public implementation<F,V,W> {};
00037
00038
00039
00040
00041
00042 #ifdef BASIX_ENABLE_THREADS
00043
00044 template<typename V, typename W>
00045 struct implementation<matrix_multiply_base,V,matrix_threads<W> >:
00046 public implementation<matrix_linear,V>
00047 {
00048 static const nat thr= 1024;
00049 static const nat sz = 4;
00050 typedef implementation<matrix_multiply,V,W> Mat;
00051
00052 template<typename Op, typename D, typename S1, typename S2>
00053 struct multiply_task_rep: public task_rep {
00054 D* d; const S1* s1; const S2* s2;
00055 nat r; nat rr; nat l; nat ll; nat c; nat cc;
00056
00057 public:
00058 inline multiply_task_rep (D* d2, const S1* s1b, const S2* s2b,
00059 nat r2, nat rr2, nat l2, nat ll2, nat c2, nat cc2):
00060 d (d2), s1 (s1b), s2 (s2b),
00061 r (r2), rr (rr2), l (l2), ll (ll2), c (c2), cc (cc2)
00062 {
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074 }
00075 inline ~multiply_task_rep () {
00076
00077
00078
00079
00080
00081
00082
00083
00084 }
00085 void execute () {
00086
00087 Mat::template mul<mul_op> (d, s1, s2, r, rr, l, ll, c, cc);
00088 }
00089 };
00090
00091 template<typename Op, typename D, typename S1, typename S2> static inline void
00092 mul (D* d, const S1* s1, const S2* s2,
00093 nat r, nat rr, nat l, nat ll, nat c, nat cc)
00094 {
00095 typedef typename Op::acc_op Acc;
00096 if (r * c < thr) Mat::template mul<Op> (d, s1, s2, r, rr, l, ll, c, cc);
00097 else {
00098 nat tr= r, nr= 1, tc= c, nc= 1, tt= threads_number;
00099 while (tt != 1) {
00100 if ((tt & 1) == 0) {
00101 if (tr > tc) { tr= (tr+1) >> 1; nr <<= 1; }
00102 else { tc= (tc+1) >> 1; nc <<= 1; }
00103 tt >>= 1;
00104 }
00105 else {
00106 if (tr > tc) { tr= (tr+tt-1) / tt; nr *= tt; }
00107 else { tc= (tc+tt-1) / tt; nc *= tt; }
00108 tt= 1;
00109 }
00110 }
00111 tr= sz * ((tr + sz - 1) / sz);
00112 tc= sz * ((tc + sz - 1) / sz);
00113
00114 task tasks[nr*nc];
00115 for (nat ir=0; ir<nr; ir++)
00116 for (nat ic=0; ic<nc; ic++) {
00117 nat r1= ir * tr, r2= min (r1 + tr, r);
00118 nat c1= ic * tc, c2= min (c1 + tc, c);
00119 if (r1 < r && c1 < c) {
00120 D* td = d + Mat::index (r1, c1, rr, cc);
00121 const S1* ts1= s1 + Mat::index (r1, 0 , rr, ll);
00122 const S2* ts2= s2 + Mat::index (0 , c1, ll, cc);
00123 tasks[Mat::index (ir, ic, nr, nc)]=
00124 new multiply_task_rep<Op,D,S1,S2>
00125 (td, ts1, ts2, r2-r1, rr, l, ll, c2-c1, cc);
00126 }
00127 }
00128 threads_execute (tasks, nr*nc);
00129 }
00130 }
00131
00132 };
00133
00134 #endif // BASIX_ENABLE_THREADS
00135
00136 }
00137 #endif //__MMX__MATRIX_THREADS__HPP