00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 #ifndef __MMX_MATRIX_TFT_HPP
00014 #define __MMX_MATRIX_TFT_HPP
00015 #include <algebramix/matrix.hpp>
00016 #include <algebramix/fft_truncated.hpp>
00017 namespace mmx {
00018 
00019 
00020 
00021 
00022 
// Variant tag layered on top of a base variant V.  It inherits from V,
// re-exports V's Vec, Naive and Positive typedefs unchanged, and maps each
// of the No_simd / No_thread / No_scaled typedefs to matrix_tft wrapped
// around the corresponding typedef of V.
template<typename V>
struct matrix_tft: V {
  // Typedefs forwarded verbatim from the base variant.
  typedef typename V::Positive Positive;
  typedef typename V::Naive    Naive;
  typedef typename V::Vec      Vec;

  // Derived variants: same wrapper, applied to the derived variant of V.
  typedef matrix_tft<typename V::No_scaled> No_scaled;
  typedef matrix_tft<typename V::No_thread> No_thread;
  typedef matrix_tft<typename V::No_simd>   No_simd;
};
00032 
00033 template<typename F, typename V, typename W>
00034 struct implementation<F,V,matrix_tft<W> >:
00035   public implementation<F,V,W> {};
00036 
00037 
00038 
00039 
00040 
00041 template<typename C>
00042 struct matrix_tft_multiply_helper {
00043 
00044   typedef Scalar_type (C) F; 
00045   typedef fft_truncated_transformer<F> tft_transformer;
00046   static nat size (const C* s1, nat s1_rs, nat s1_cs,
00047                    const C* s2, nat s2_rs, nat s2_cs,
00048                    nat r, nat l, nat c) {
00049     nat sz= 0;
00050     for (nat k= 0; k < l; k++) {
00051       nat sz1= 0, sz2= 0;
00052       const C* ss1= s1 + k * s1_cs;
00053       const C* ss2= s2 + k * s2_rs;
00054       for (nat i= 0; i < r; i++, ss1 += s1_rs) sz1= max (sz1, N (*ss1));
00055       for (nat j= 0; j < c; j++, ss2 += s2_cs) sz2= max (sz2, N (*ss2));
00056       sz= max (sz, (sz1 == 0 || sz2 == 0) ? 0 : (sz1 + sz2 - 1));
00057     }
00058     return sz; }
00059 };
00060 
00061 
00062 
00063 
00064 
00065 template<typename V, typename W>
00066 struct implementation<matrix_multiply_base,V,matrix_tft<W> >:
00067   public implementation<matrix_linear,V>
00068 {
00069   typedef implementation<matrix_multiply,W> Mat;
00070 
00071 template<typename C, typename F, typename MV, typename Tfter>
00072 static void
00073 mat_direct_tft (matrix<F,MV>* dest, const C* s,
00074                 nat s_rs, nat s_cs, nat r, nat c, Tfter& tfter) {
00075   nat n= tfter.len;
00076   nat m= next_power_of_two (n);
00077   for (nat k= 0; k < n; k++)
00078     dest[k]= matrix<F,MV> (F (), r, c);
00079   F* aux= mmx_new<F> (m);
00080   for (nat i= 0; i < r; i++)
00081     for (nat j= 0; j < c; j++) {
00082       for (nat k= 0; k < n ; k++)
00083         aux[k]= s[i * s_rs + j * s_cs] [k];
00084       tfter.direct_transform (aux);
00085       for (nat k= 0; k < n ; k++)
00086         dest[k](i,j)= aux[k];
00087     }
00088   mmx_delete<F> (aux, m); }
00089 
00090 template<typename Op,typename C, typename F, typename MV, typename Tfter>
00091 static void
00092 mat_inverse_tft (C* d, nat d_rs, nat d_cs, nat r, nat c,
00093                  const matrix<F,MV>* s, Tfter& tfter) {
00094   typedef typename Op::nomul_op Acc;
00095   typedef typename Vector_variant(F) VV;
00096   nat n= tfter.len;
00097   nat m= next_power_of_two (n);
00098   nat l= aligned_size<F,VV> (m);
00099   for (nat i= 0; i < r; i++)
00100     for (nat j= 0; j < c; j++) {
00101       F* aux= mmx_new<F> (l);
00102       for (nat k= 0; k < n ; k++)
00103         aux[k]= s[k](i,j);
00104       tfter.inverse_transform (aux);
00105       vector<F,VV> tmp (new vector_rep<F,VV> (aux, n, l, false, format<F> ()));
00106       Acc::set_op (d[i * d_rs + j * d_cs], as<C> (tmp)); } }
00107 
00108 template<typename Op, typename C, typename Tfter> static void
00109 mul (C* d, const C* s1, const C* s2,
00110      nat r, nat rr, nat l, nat ll, nat c, nat cc, Tfter& tfter) {
00111   typedef Scalar_type (C) F;
00112   typedef matrix<F> Matrix_F;
00113   nat n= tfter.len;
00114   Matrix_F* mm1= mmx_new<Matrix_F> (n);
00115   Matrix_F* mm2= mmx_new<Matrix_F> (n);
00116   Matrix_F* mmd= mmx_new<Matrix_F> (n);
00117   mat_direct_tft (mm1, s1, Mat::index (1, 0, rr, ll), Mat::index (0, 1, rr, ll),
00118                   r, l, tfter);
00119   mat_direct_tft (mm2, s2, Mat::index (1, 0, ll, cc), Mat::index (0, 1, ll, cc),
00120                   l, c, tfter);
00121   for (nat k= 0; k < n; k++)
00122     mmd[k]= mm1[k] * mm2[k];
00123   mat_inverse_tft<Op> (d, Mat::index (1, 0, rr, cc), Mat::index (0, 1, rr, cc),
00124                        r, c, mmd, tfter);
00125   mmx_delete<Matrix_F> (mm1, n);
00126   mmx_delete<Matrix_F> (mm2, n);
00127   mmx_delete<Matrix_F> (mmd, n); }
00128 
00129 template<typename Op, typename C> static void
00130 mul (C* d, const C* s1, const C* s2,
00131      nat r, nat rr, nat l, nat ll, nat c, nat cc) {
00132   typedef matrix_tft_multiply_helper<C> Matrix_tft;
00133   typedef typename Matrix_tft::tft_transformer Tfter;
00134   typedef typename Matrix_tft::F F;
00135   typedef typename Matrix_variant(F) MVF;
00136   typedef implementation<matrix_multiply,MVF> MatF;
00137   nat sz= matrix_tft_multiply_helper<C>
00138     ::size (s1, Mat::index (1, 0, rr, ll), Mat::index (0, 1, rr, ll),
00139             s2, Mat::index (1, 0, ll, cc), Mat::index (0, 1, ll, cc),
00140             r, l, c);
00141   if (sz == 0)
00142     Mat::template clr<Op> (d, r, rr, c, cc);
00143   else if (sz == 1) {
00144     nat l1= aligned_size<F,MVF> (r * l);
00145     nat l2= aligned_size<F,MVF> (l * c);
00146     nat ld= aligned_size<F,MVF> (r * c);
00147     F* m1= mmx_new<F> (l1);
00148     F* m2= mmx_new<F> (l2);
00149     F* md= mmx_new<F> (ld);
00150     Mat::template mat_binary_scalar_stride<access_op>
00151       (m1, Mat::index (1, 0, r, l), Mat::index (0, 1, r, l),
00152        s1, Mat::index (1, 0, rr, ll), Mat::index (0, 1, rr, ll), 0, r, l);
00153     Mat::template mat_binary_scalar_stride<access_op>
00154       (m2, Mat::index (1, 0, l, c), Mat::index (0, 1, l, c),
00155        s2, Mat::index (1, 0, ll, cc), Mat::index (0, 1, ll, cc), 0, l, c);
00156     Mat::template mat_binary_scalar_stride<access_op>
00157       (md, Mat::index (1, 0, r, c), Mat::index (0, 1, r, c),
00158        d, Mat::index (1, 0, rr, cc), Mat::index (0, 1, rr, cc), 0, r, c);
00159     MatF::template mul<Op> (md, m1, m2, r, r, l, l, c, c);
00160     Mat::template mat_unary_stride<id_op>
00161       (d, Mat::index (1, 0, rr, cc), Mat::index (0, 1, rr, cc),
00162        md, Mat::index (1, 0, r, c), Mat::index (0, 1, r, c), r, c);
00163     mmx_delete<F> (m1, l1);
00164     mmx_delete<F> (m2, l2);
00165     mmx_delete<F> (md, ld);
00166   }
00167   else {
00168     Tfter tfter (sz, format<F> ());
00169     mul<Op> (d, s1, s2, r, rr, l, ll, c, cc, tfter); } }
00170 
00171 template<typename C> static void
00172 mul (C* d, const C* s1, const C* s2, nat r, nat l, nat c) {
00173   mul<mul_op> (d, s1, s2, r, r, l, l, c, c); }
00174 
00175 template<typename Op, typename D, typename S1, typename S2>
00176 static inline void
00177 mul (D* d, const S1* s1, const S2* s2,
00178      nat r, nat rr, nat l, nat ll, nat c, nat cc) {
00179   Mat::template mul<Op> (d, s1, s2, r, rr, l, ll, c, cc); }
00180 
00181 template<typename D, typename S1, typename S2>
00182 static inline void
00183 mul (D* d, const S1* s1, const S2* s2,
00184      nat r, nat l, nat c) {
00185   Mat::template mul<mul_op> (d, s1, s2, r, r, l, l, c, c); }
00186 
00187 }; 
00188 
00189 } 
00190 #endif // __MMX_MATRIX_TFT_HPP