00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX__MATRIX_BALANCED__HPP
00014 #define __MMX__MATRIX_BALANCED__HPP
00015 #include <basix/operators.hpp>
00016 #include <algebramix/matrix_naive.hpp>
00017
00018 namespace mmx {
00019
00020
00021
00022
00023
// Implementation variant that layers size balancing on top of the matrix
// variant V.  Everything not explicitly specialized for matrix_balanced<...>
// in this file is inherited from V.
template<typename V>
struct matrix_balanced: V {
  // Vector and naive variants are taken over from V as they are.
  typedef typename V::Naive Naive;
  typedef typename V::Vec Vec;

  // Derived variants keep the balancing layer on top of the corresponding
  // variant of V.
  typedef matrix_balanced<typename V::No_scaled> No_scaled;
  typedef matrix_balanced<typename V::No_thread> No_thread;
  typedef matrix_balanced<typename V::No_simd> No_simd;
  typedef matrix_balanced<typename V::Positive> Positive;
};
00033
00034 template<typename F, typename V, typename W>
00035 struct implementation<F,V,matrix_balanced<W> >:
00036 public implementation<F,V,W> {};
00037
00038
00039
00040
00041
// Empty tag type naming the "balance ratio" threshold of balanced matrix
// products; its value is supplied through threshold_helper specializations.
template<class V>
struct matrix_balanced_threshold { };
00044
00045 template<typename W>
00046 struct threshold_helper<integer,matrix_balanced_threshold<W> > {
00047 typedef fixed_value<nat,2> impl;
00048 };
00049
00050
00051
00052
00053
00054 template<typename V, typename W>
00055 struct implementation<matrix_multiply,V,matrix_balanced<W> >:
00056 public implementation<matrix_multiply_base,V>
00057 {
00058 typedef matrix_multiply_threshold<matrix_balanced<W> > Th;
00059 typedef matrix_balanced_threshold<matrix_balanced<W> > BTh;
00060 typedef implementation<matrix_multiply,W> Mat;
00061 typedef implementation<vector_linear, typename W::Vec> Vec;
00062
00063 template<typename C>
00064 static inline nat mat_size (const C* s, nat r, nat c) {
00065 nat sz= 0;
00066 for (nat i= 0; i < r * c; i++) sz= max (sz, N(s[i]));
00067 return sz; }
00068
00069 template<typename C>
00070 static inline void mat_rshift (C* d, const C* s, nat r, nat c, nat h) {
00071 for (nat i= 0; i < r * c; i++) d[i]= rshiftz (s[i], h); }
00072
00073 template<typename C>
00074 static inline void mat_co_rshift (C* d, const C* s, nat r, nat c, nat h) {
00075 for (nat i= 0; i < r * c; i++) d[i]= co_rshiftz (s[i], h); }
00076
00077 template<typename C>
00078 static inline void mat_lshift (C* d, const C* s, nat r, nat c, nat h) {
00079 for (nat i= 0; i < r * c; i++) d[i]= lshiftz (s[i], h); }
00080
00081 template<typename Op, typename D, typename S1, typename S2>
00082 static inline void
00083 mul (D* d, const S1* s1, const S2* s2,
00084 nat r, nat rr, nat l, nat ll, nat c, nat cc) {
00085 Mat::template mul<Op> (d, s1, s2, r, rr, l, ll, c, cc); }
00086
00087 template<typename D, typename S1, typename S2>
00088 static inline void
00089 mul (D* d, const S1* s1, const S2* s2,
00090 nat r, nat l, nat c) {
00091 Mat::template mul<mul_op> (d, s1, s2, r, r, l, l, c, c); }
00092
00093 template<typename C>
00094 static void
00095 mul (C* d, const C* s1, const C* s2,
00096 nat r, nat l, nat c) {
00097 const nat th= Threshold(C,Th);
00098 const nat balance_threshold= Threshold(C,BTh);
00099 nat sz1= mat_size (s1, r, l);
00100 nat sz2= mat_size (s2, l, c);
00101 if (sz1 == 0 || sz2 == 0)
00102 Mat::clear (d, r, c);
00103 else if (sz1 <= th && sz2 <= th)
00104 Mat::template mul (d, s1, s2, r, l, c);
00105 else if (balance_threshold * sz1 < sz2) {
00106 nat h= sz2 >> 1;
00107 nat len_aux= aligned_size<C,V> (l, c);
00108 C* aux= mmx_new<C> (len_aux);
00109 nat len_tmp= aligned_size<C,V> (r, c);
00110 C* tmp= mmx_new<C> (len_tmp);
00111 mat_rshift (aux, s2, l, c, h);
00112 mul (tmp, s1, aux, r, l, c);
00113 mat_lshift (d, tmp, r, c, h);
00114 mat_co_rshift (aux, s2, l, c, h);
00115 mul (tmp, s1, aux, r, l, c);
00116 Vec::add (d, tmp, r * c);
00117 mmx_delete<C> (aux, len_aux);
00118 mmx_delete<C> (tmp, len_tmp);
00119 }
00120 else if (balance_threshold * sz2 < sz1) {
00121 nat h= sz1 >> 1;
00122 nat len_aux= aligned_size<C,V> (r, l);
00123 C* aux= mmx_new<C> (len_aux);
00124 nat len_tmp= aligned_size<C,V> (r, c);
00125 C* tmp= mmx_new<C> (len_tmp);
00126 mat_rshift (aux, s1, r, l, h);
00127 mul (tmp, aux, s2, r, l, c);
00128 mat_lshift (d, tmp, r, c, h);
00129 mat_co_rshift (aux, s1, r, l, h);
00130 mul (tmp, aux, s2, r, l, c);
00131 Vec::add (d, tmp, r * c);
00132 mmx_delete<C> (aux, len_aux);
00133 mmx_delete<C> (tmp, len_tmp);
00134 }
00135 else {
00136 Mat::template mul (d, s1, s2, r, l, c); } }
00137
00138 };
00139
00140 }
00141 #endif //__MMX__MATRIX_BALANCED__HPP