00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 #ifndef __MMX__MATRIX_THREADS__HPP
00014 #define __MMX__MATRIX_THREADS__HPP
#include <vector>
#include <algebramix/matrix_unrolled.hpp>
#include <basix/threads.hpp>
00017 
00018 namespace mmx {
00019 
00020 
00021 
00022 
00023 
// Matrix implementation variant that layers threading on top of V.
// Derived variants (Positive, No_simd, No_scaled) stay wrapped in
// matrix_threads; Vec, Naive and No_thread are forwarded unchanged from V.
template<typename V>
struct matrix_threads: V {
  // Forwarded as-is from the underlying variant.
  typedef typename V::Vec       Vec;
  typedef typename V::Naive     Naive;
  typedef typename V::No_thread No_thread;
  // Re-wrapped so that the threading layer is preserved.
  typedef matrix_threads<typename V::Positive>  Positive;
  typedef matrix_threads<typename V::No_simd>   No_simd;
  typedef matrix_threads<typename V::No_scaled> No_scaled;
};
00033 
00034 template<typename F, typename V, typename W>
00035 struct implementation<F,V,matrix_threads<W> >:
00036   public implementation<F,V,W> {};
00037 
00038 
00039 
00040 
00041 
00042 #ifdef BASIX_ENABLE_THREADS
00043 
00044 template<typename V, typename W>
00045 struct implementation<matrix_multiply_base,V,matrix_threads<W> >:
00046   public implementation<matrix_linear,V>
00047 {
00048   static const nat thr= (1 << 10);
00049   static const nat sz = 4;
00050   typedef implementation<matrix_multiply,W> Mat;
00051 
00052 template<typename Op, typename D, typename S1, typename S2>
00053 struct multiply_task_rep: public task_rep {
00054   D* d; const S1* s1; const S2* s2;
00055   nat r; nat rr; nat l; nat ll; nat c; nat cc;
00056 public:
00057   inline multiply_task_rep (D* d2, const S1* s1b, const S2* s2b,
00058                             nat r2, nat rr2, nat l2, nat ll2, nat c2, nat cc2):
00059     d (d2), s1 (s1b), s2 (s2b),
00060     r (r2), rr (rr2), l (l2), ll (ll2), c (c2), cc (cc2) {}
00061   void execute () {
00062     Mat::template mul<mul_op> (d, s1, s2, r, rr, l, ll, c, cc);
00063   }
00064 };
00065 
00066 template<typename Op, typename D, typename S1, typename S2> static inline void
00067 mul (D* d, const S1* s1, const S2* s2,
00068      nat r, nat rr, nat l, nat ll, nat c, nat cc)
00069 {
00070   typedef typename Op::acc_op Acc;
00071   if (r * c < thr) Mat::template mul<Op> (d, s1, s2, r, rr, l, ll, c, cc);
00072   else {
00073     nat tr= r, nr= 1, tc= c, nc= 1, tt= threads_number;
00074     while (tt != 1) {
00075       if ((tt & 1) == 0) {
00076         if (tr > tc) { tr= (tr+1) >> 1; nr <<= 1; }
00077         else { tc= (tc+1) >> 1; nc <<= 1; }
00078         tt >>= 1;
00079       }
00080       else {
00081         if (tr > tc) { tr= (tr+tt-1) / tt; nr *= tt; }
00082         else { tc= (tc+tt-1) / tt; nc *= tt; }
00083         tt= 1;
00084       }
00085     }
00086     tr= sz * ((tr + sz - 1) / sz);
00087     tc= sz * ((tc + sz - 1) / sz);
00088     task tasks[nr*nc];
00089     
00090     for (nat ir=0; ir<nr; ir++)
00091       for (nat ic=0; ic<nc; ic++) {
00092         nat r1= ir * tr, r2= min (r1 + tr, r);
00093         nat c1= ic * tc, c2= min (c1 + tc, c);
00094         if (r1 < r && c1 < c) {
00095           D*        td = d  + Mat::index (r1, c1, rr, cc);
00096           const S1* ts1= s1 + Mat::index (r1, 0 , rr, ll);
00097           const S2* ts2= s2 + Mat::index (0 , c1, ll, cc);
00098           tasks[Mat::index (ir, ic, nr, nc)]=
00099             new multiply_task_rep<Op,D,S1,S2>
00100                   (td, ts1, ts2, r2-r1, rr, l, ll, c2-c1, cc);
00101         }
00102       }
00103     threads_execute (tasks, nr*nc);
00104   }
00105 }
00106 
00107 template<typename D, typename S1, typename S2> static inline void
00108 mul (D* dest, const S1* m1, const S2* m2, nat r, nat l, nat c) {
00109   mul<mul_op> (dest, m1, m2, r, r, l, l, c, c);
00110 }
00111 
00112 }; 
00113 
00114 #endif // BASIX_ENABLE_THREADS
00115 
00116 } 
00117 #endif //__MMX__MATRIX_THREADS__HPP