00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX__MATRIX_THREADS__HPP
00014 #define __MMX__MATRIX_THREADS__HPP
#include <vector>
#include <algebramix/matrix_unrolled.hpp>
#include <basix/threads.hpp>
00017
00018 namespace mmx {
00019
00020
00021
00022
00023
// Variant wrapper: matrix_threads<V> inherits everything from the matrix
// variant V and re-exports its associated variant types.
template<typename V>
struct matrix_threads: V {
  // Forwarded from V as-is.
  typedef typename V::Vec Vec;
  typedef typename V::Naive Naive;
  // Taken directly from V, i.e. WITHOUT the matrix_threads wrapper.
  typedef typename V::No_thread No_thread;
  // Derived variants of V, re-wrapped in matrix_threads.
  typedef matrix_threads<typename V::Positive> Positive;
  typedef matrix_threads<typename V::No_simd> No_simd;
  typedef matrix_threads<typename V::No_scaled> No_scaled;
};
00033
00034 template<typename F, typename V, typename W>
00035 struct implementation<F,V,matrix_threads<W> >:
00036 public implementation<F,V,W> {};
00037
00038
00039
00040
00041
00042 #ifdef BASIX_ENABLE_THREADS
00043
00044 template<typename V, typename W>
00045 struct implementation<matrix_multiply_base,V,matrix_threads<W> >:
00046 public implementation<matrix_linear,V>
00047 {
00048 static const nat thr= (1 << 10);
00049 static const nat sz = 4;
00050 typedef implementation<matrix_multiply,W> Mat;
00051
00052 template<typename Op, typename D, typename S1, typename S2>
00053 struct multiply_task_rep: public task_rep {
00054 D* d; const S1* s1; const S2* s2;
00055 nat r; nat rr; nat l; nat ll; nat c; nat cc;
00056 public:
00057 inline multiply_task_rep (D* d2, const S1* s1b, const S2* s2b,
00058 nat r2, nat rr2, nat l2, nat ll2, nat c2, nat cc2):
00059 d (d2), s1 (s1b), s2 (s2b),
00060 r (r2), rr (rr2), l (l2), ll (ll2), c (c2), cc (cc2) {}
00061 void execute () {
00062 Mat::template mul<mul_op> (d, s1, s2, r, rr, l, ll, c, cc);
00063 }
00064 };
00065
00066 template<typename Op, typename D, typename S1, typename S2> static inline void
00067 mul (D* d, const S1* s1, const S2* s2,
00068 nat r, nat rr, nat l, nat ll, nat c, nat cc)
00069 {
00070 typedef typename Op::acc_op Acc;
00071 if (r * c < thr) Mat::template mul<Op> (d, s1, s2, r, rr, l, ll, c, cc);
00072 else {
00073 nat tr= r, nr= 1, tc= c, nc= 1, tt= threads_number;
00074 while (tt != 1) {
00075 if ((tt & 1) == 0) {
00076 if (tr > tc) { tr= (tr+1) >> 1; nr <<= 1; }
00077 else { tc= (tc+1) >> 1; nc <<= 1; }
00078 tt >>= 1;
00079 }
00080 else {
00081 if (tr > tc) { tr= (tr+tt-1) / tt; nr *= tt; }
00082 else { tc= (tc+tt-1) / tt; nc *= tt; }
00083 tt= 1;
00084 }
00085 }
00086 tr= sz * ((tr + sz - 1) / sz);
00087 tc= sz * ((tc + sz - 1) / sz);
00088 task tasks[nr*nc];
00089
00090 for (nat ir=0; ir<nr; ir++)
00091 for (nat ic=0; ic<nc; ic++) {
00092 nat r1= ir * tr, r2= min (r1 + tr, r);
00093 nat c1= ic * tc, c2= min (c1 + tc, c);
00094 if (r1 < r && c1 < c) {
00095 D* td = d + Mat::index (r1, c1, rr, cc);
00096 const S1* ts1= s1 + Mat::index (r1, 0 , rr, ll);
00097 const S2* ts2= s2 + Mat::index (0 , c1, ll, cc);
00098 tasks[Mat::index (ir, ic, nr, nc)]=
00099 new multiply_task_rep<Op,D,S1,S2>
00100 (td, ts1, ts2, r2-r1, rr, l, ll, c2-c1, cc);
00101 }
00102 }
00103 threads_execute (tasks, nr*nc);
00104 }
00105 }
00106
00107 template<typename D, typename S1, typename S2> static inline void
00108 mul (D* dest, const S1* m1, const S2* m2, nat r, nat l, nat c) {
00109 mul<mul_op> (dest, m1, m2, r, r, l, l, c, c);
00110 }
00111
00112 };
00113
00114 #endif // BASIX_ENABLE_THREADS
00115
00116 }
00117 #endif //__MMX__MATRIX_THREADS__HPP