00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX_MATRIX_TFT_HPP
00014 #define __MMX_MATRIX_TFT_HPP
00015 #include <algebramix/matrix.hpp>
00016 #include <algebramix/fft_truncated.hpp>
00017 namespace mmx {
00018
00019
00020
00021
00022
// Variant tag selecting the TFT-based matrix product.
//
// The tag derives from the underlying variant V, so everything V provides
// remains reachable through it.  The typedefs below fall in two groups:
//   - Vec, Naive, Positive are forwarded from V unchanged;
//   - No_simd, No_thread, No_scaled take the corresponding variant of V
//     and wrap it in matrix_tft again, so the tag is kept when one of
//     these features is stripped.
template<typename V>
struct matrix_tft: V {
  // Stripped variants, re-wrapped in matrix_tft.
  typedef matrix_tft<typename V::No_scaled> No_scaled;
  typedef matrix_tft<typename V::No_thread> No_thread;
  typedef matrix_tft<typename V::No_simd>   No_simd;

  // Typedefs forwarded verbatim from V.
  typedef typename V::Positive Positive;
  typedef typename V::Naive    Naive;
  typedef typename V::Vec      Vec;
};
00032
00033 template<typename F, typename V, typename W>
00034 struct implementation<F,V,matrix_tft<W> >:
00035 public implementation<F,V,W> {};
00036
00037
00038
00039
00040
00041 template<typename C>
00042 struct matrix_tft_multiply_helper {
00043
00044 typedef Scalar_type (C) F;
00045 typedef fft_truncated_transformer<F> tft_transformer;
00046 static nat size (const C* s1, nat s1_rs, nat s1_cs,
00047 const C* s2, nat s2_rs, nat s2_cs,
00048 nat r, nat l, nat c) {
00049 nat sz= 0;
00050 for (nat k= 0; k < l; k++) {
00051 nat sz1= 0, sz2= 0;
00052 const C* ss1= s1 + k * s1_cs;
00053 const C* ss2= s2 + k * s2_rs;
00054 for (nat i= 0; i < r; i++, ss1 += s1_rs) sz1= max (sz1, N (*ss1));
00055 for (nat j= 0; j < c; j++, ss2 += s2_cs) sz2= max (sz2, N (*ss2));
00056 sz= max (sz, (sz1 == 0 || sz2 == 0) ? 0 : (sz1 + sz2 - 1));
00057 }
00058 return sz; }
00059 };
00060
00061
00062
00063
00064
00065 template<typename V, typename W>
00066 struct implementation<matrix_multiply_base,V,matrix_tft<W> >:
00067 public implementation<matrix_linear,V>
00068 {
00069 typedef implementation<matrix_multiply,W> Mat;
00070
00071 template<typename C, typename F, typename MV, typename Tfter>
00072 static void
00073 mat_direct_tft (matrix<F,MV>* dest, const C* s,
00074 nat s_rs, nat s_cs, nat r, nat c, Tfter& tfter) {
00075 nat n= tfter.len;
00076 nat m= next_power_of_two (n);
00077 for (nat k= 0; k < n; k++)
00078 dest[k]= matrix<F,MV> (F (), r, c);
00079 F* aux= mmx_new<F> (m);
00080 for (nat i= 0; i < r; i++)
00081 for (nat j= 0; j < c; j++) {
00082 for (nat k= 0; k < n ; k++)
00083 aux[k]= s[i * s_rs + j * s_cs] [k];
00084 tfter.direct_transform (aux);
00085 for (nat k= 0; k < n ; k++)
00086 dest[k](i,j)= aux[k];
00087 }
00088 mmx_delete<F> (aux, m); }
00089
00090 template<typename Op,typename C, typename F, typename MV, typename Tfter>
00091 static void
00092 mat_inverse_tft (C* d, nat d_rs, nat d_cs, nat r, nat c,
00093 const matrix<F,MV>* s, Tfter& tfter) {
00094 typedef typename Op::nomul_op Acc;
00095 typedef typename Vector_variant(F) VV;
00096 nat n= tfter.len;
00097 nat m= next_power_of_two (n);
00098 nat l= aligned_size<F,VV> (m);
00099 for (nat i= 0; i < r; i++)
00100 for (nat j= 0; j < c; j++) {
00101 F* aux= mmx_new<F> (l);
00102 for (nat k= 0; k < n ; k++)
00103 aux[k]= s[k](i,j);
00104 tfter.inverse_transform (aux);
00105 vector<F,VV> tmp (new vector_rep<F,VV> (aux, n, l, false, format<F> ()));
00106 Acc::set_op (d[i * d_rs + j * d_cs], as<C> (tmp)); } }
00107
00108 template<typename Op, typename C, typename Tfter> static void
00109 mul (C* d, const C* s1, const C* s2,
00110 nat r, nat rr, nat l, nat ll, nat c, nat cc, Tfter& tfter) {
00111 typedef Scalar_type (C) F;
00112 typedef matrix<F> Matrix_F;
00113 nat n= tfter.len;
00114 Matrix_F* mm1= mmx_new<Matrix_F> (n);
00115 Matrix_F* mm2= mmx_new<Matrix_F> (n);
00116 Matrix_F* mmd= mmx_new<Matrix_F> (n);
00117 mat_direct_tft (mm1, s1, Mat::index (1, 0, rr, ll), Mat::index (0, 1, rr, ll),
00118 r, l, tfter);
00119 mat_direct_tft (mm2, s2, Mat::index (1, 0, ll, cc), Mat::index (0, 1, ll, cc),
00120 l, c, tfter);
00121 for (nat k= 0; k < n; k++)
00122 mmd[k]= mm1[k] * mm2[k];
00123 mat_inverse_tft<Op> (d, Mat::index (1, 0, rr, cc), Mat::index (0, 1, rr, cc),
00124 r, c, mmd, tfter);
00125 mmx_delete<Matrix_F> (mm1, n);
00126 mmx_delete<Matrix_F> (mm2, n);
00127 mmx_delete<Matrix_F> (mmd, n); }
00128
00129 template<typename Op, typename C> static void
00130 mul (C* d, const C* s1, const C* s2,
00131 nat r, nat rr, nat l, nat ll, nat c, nat cc) {
00132 typedef matrix_tft_multiply_helper<C> Matrix_tft;
00133 typedef typename Matrix_tft::tft_transformer Tfter;
00134 typedef typename Matrix_tft::F F;
00135 typedef typename Matrix_variant(F) MVF;
00136 typedef implementation<matrix_multiply,MVF> MatF;
00137 nat sz= matrix_tft_multiply_helper<C>
00138 ::size (s1, Mat::index (1, 0, rr, ll), Mat::index (0, 1, rr, ll),
00139 s2, Mat::index (1, 0, ll, cc), Mat::index (0, 1, ll, cc),
00140 r, l, c);
00141 if (sz == 0)
00142 Mat::template clr<Op> (d, r, rr, c, cc);
00143 else if (sz == 1) {
00144 nat l1= aligned_size<F,MVF> (r * l);
00145 nat l2= aligned_size<F,MVF> (l * c);
00146 nat ld= aligned_size<F,MVF> (r * c);
00147 F* m1= mmx_new<F> (l1);
00148 F* m2= mmx_new<F> (l2);
00149 F* md= mmx_new<F> (ld);
00150 Mat::template mat_binary_scalar_stride<access_op>
00151 (m1, Mat::index (1, 0, r, l), Mat::index (0, 1, r, l),
00152 s1, Mat::index (1, 0, rr, ll), Mat::index (0, 1, rr, ll), 0, r, l);
00153 Mat::template mat_binary_scalar_stride<access_op>
00154 (m2, Mat::index (1, 0, l, c), Mat::index (0, 1, l, c),
00155 s2, Mat::index (1, 0, ll, cc), Mat::index (0, 1, ll, cc), 0, l, c);
00156 Mat::template mat_binary_scalar_stride<access_op>
00157 (md, Mat::index (1, 0, r, c), Mat::index (0, 1, r, c),
00158 d, Mat::index (1, 0, rr, cc), Mat::index (0, 1, rr, cc), 0, r, c);
00159 MatF::template mul<Op> (md, m1, m2, r, r, l, l, c, c);
00160 Mat::template mat_unary_stride<id_op>
00161 (d, Mat::index (1, 0, rr, cc), Mat::index (0, 1, rr, cc),
00162 md, Mat::index (1, 0, r, c), Mat::index (0, 1, r, c), r, c);
00163 mmx_delete<F> (m1, l1);
00164 mmx_delete<F> (m2, l2);
00165 mmx_delete<F> (md, ld);
00166 }
00167 else {
00168 Tfter tfter (sz, format<F> ());
00169 mul<Op> (d, s1, s2, r, rr, l, ll, c, cc, tfter); } }
00170
00171 template<typename C> static void
00172 mul (C* d, const C* s1, const C* s2, nat r, nat l, nat c) {
00173 mul<mul_op> (d, s1, s2, r, r, l, l, c, c); }
00174
00175 template<typename Op, typename D, typename S1, typename S2>
00176 static inline void
00177 mul (D* d, const S1* s1, const S2* s2,
00178 nat r, nat rr, nat l, nat ll, nat c, nat cc) {
00179 Mat::template mul<Op> (d, s1, s2, r, rr, l, ll, c, cc); }
00180
00181 template<typename D, typename S1, typename S2>
00182 static inline void
00183 mul (D* d, const S1* s1, const S2* s2,
00184 nat r, nat l, nat c) {
00185 Mat::template mul<mul_op> (d, s1, s2, r, r, l, l, c, c); }
00186
00187 };
00188
00189 }
00190 #endif // __MMX_MATRIX_TFT_HPP