00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX_MATRIX_TFT_VECTORIZED_HPP
00014 #define __MMX_MATRIX_TFT_VECTORIZED_HPP
00015 #include <algebramix/matrix_tft.hpp>
00016 #include <algebramix/fft_truncated.hpp>
00017 #include <algebramix/vector_simd.hpp>
00018
00019 namespace mmx {
00020
00021
00022
00023
00024
00025
00026
00027
00028 template<typename V>
00029 struct matrix_tft_vectorized: public V {
00030 typedef typename V::Vec Vec;
00031 typedef typename V::Naive Naive;
00032 typedef typename V::Positive Positive;
00033 typedef matrix_tft<typename V::No_simd> No_simd;
00034 typedef matrix_tft_vectorized<typename V::No_thread> No_thread;
00035 typedef matrix_tft_vectorized<typename V::No_scaled> No_scaled;
00036 };
00037
00038 template<typename F, typename V, typename W>
00039 struct implementation<F,V,matrix_tft_vectorized<W> >:
00040 public implementation<F,V,W> {};
00041
00042
00043
00044
00045
00046 namespace _matrix_tft_vectorized {
00047
/// Removes one level of pointer: unptr_helper<T*>::type names T.
/// The primary template is deliberately empty, so asking for ::type on a
/// non-pointer type fails at compile time.
template<typename C>
struct unptr_helper {
};

template<typename C>
struct unptr_helper<C*> {
  typedef C type;
};
00050
00051 template<typename Cp>
00052 struct vectorized_roots_helper {
00053
00054 typedef typename unptr_helper<Cp>::type CC;
00055 typedef Cp C;
00056 typedef CC U;
00057 typedef CC S;
00058 typedef typename Vector_simd_variant(CC) VV;
00059 typedef implementation<vector_linear,VV> Vec;
00060
00061
00062 static inline nat&
00063 get_size () { static nat h= 0; return h; }
00064
00065 static inline nat&
00066 get_allocated_size () { static nat l= 0; return l; }
00067
00068 static inline C&
00069 get_temp () { static C temp= NULL; return temp; }
00070
00071 static U*
00072 create_roots (nat n, const format<U>& fm) {
00073 nat k= primitive_root_max_order<U> (2); (void) k;
00074 VERIFY (k == 0 || n <= k, "maximum order exceeded");
00075 VERIFY (n >= 2, "size must be at least two");
00076 nat h= get_size ();
00077 nat l= get_allocated_size ()= aligned_size<CC,VV> (h);
00078 get_temp ()= mmx_new<CC> (l);
00079 U* roots= mmx_new<U> (n);
00080 for (nat i=0; i<n; i+=2) {
00081 roots[i] = primitive_root<U> (n, bit_mirror (i, n), fm);
00082 roots[i+1]= primitive_root<U> (n, i==0 ? 0: n - bit_mirror (i, n), fm);
00083 }
00084 return roots; }
00085
00086 static void
00087 destroy_roots (U* u, nat n) {
00088 C& temp= get_temp ();
00089 if (temp != NULL) {
00090 mmx_delete<CC> (temp, get_allocated_size ());
00091 temp= NULL;
00092 }
00093 mmx_delete<U> (u, n); }
00094
00095 static inline void
00096 fft_cross (C* c1, C* c2) {
00097 C temp= get_temp ();
00098 nat h= get_size ();
00099 Vec::copy (temp, *c2, h);
00100 Vec::sub (*c2, *c1, temp, h);
00101 Vec::add (*c1, temp, h); }
00102
00103 static inline void
00104 dfft_cross (C* c1, C* c2, const U* u) {
00105 C temp= get_temp ();
00106 nat h= get_size ();
00107 Vec::mul (temp, *c2, *u, h);
00108 Vec::sub (*c2, *c1, temp, h);
00109 Vec::add (*c1, temp, h); }
00110
00111 static inline void
00112 ifft_cross (C* c1, C* c2, const U* u) {
00113 C temp= get_temp ();
00114 nat h= get_size ();
00115 Vec::copy (temp, *c2, h);
00116 Vec::sub (*c2, *c1, temp, h);
00117 Vec::mul (*c2, *u, h);
00118 Vec::add (*c1, temp, h); }
00119
00120 static inline void
00121 dtft_cross (C* c1, C* c2) {
00122 static S a= invert (S(2));
00123 nat h= get_size ();
00124 fft_cross (c1, c2);
00125 Vec::mul (*c1, a, h);
00126 Vec::mul (*c2, a, h); }
00127
00128 static inline void
00129 dtft_cross (C* c1, C* c2, const U* u) {
00130 static S a= invert (S(2));
00131 nat h= get_size ();
00132 dfft_cross (c1, c2, u);
00133 Vec::mul (*c1, a, h);
00134 Vec::mul (*c2, a, h); }
00135
00136 static inline void
00137 itft_flip (C* c1, C* c2, const U* u) {
00138 static S a= invert (S(2));
00139 C temp= get_temp ();
00140 nat h= get_size ();
00141 Vec::mul (temp, *c2, *u, h);
00142 Vec::add (*c2, *c1, *c1, h);
00143 Vec::sub (*c1, *c2, temp, h);
00144 Vec::sub (*c2, *c1, temp, h);
00145 Vec::mul (*c2, a, h); }
00146
00147 static inline void
00148 itft_flip (C* c1, C* c2) {
00149 static S a= invert (S(2));
00150 C temp= get_temp ();
00151 nat h= get_size ();
00152 Vec::add (temp, *c1, *c1, h);
00153 Vec::sub (*c1, temp, *c2, h);
00154 Vec::sub (*c2, *c1, *c2, h);
00155 Vec::mul (*c2, a, h); }
00156
00157 struct fft_mul_sc_op : mul_op {
00158 static inline void
00159 set_op (C& x, const S& y) {
00160 Vec::mul (x, y, get_size ()); }
00161 };
00162 };
00163
00164 template<typename C>
00165 struct vectorized_roots {
00166 typedef vectorized_roots_helper<C> roots_type;
00167 };
00168
00169 }
00170
00171
00172
00173
00174
00175 template<typename C>
00176 struct matrix_tft_vectorized_multiply_helper {
00177 typedef Scalar_type (C) F;
00178 typedef typename Vector_simd_variant(F) VF;
00179 typedef _matrix_tft_vectorized::vectorized_roots<F*> _vectorized_roots;
00180 typedef fft_naive_transformer<F*,_vectorized_roots> _fft_transformer;
00181 typedef fft_truncated_transformer<F*,_fft_transformer> tft_transformer;
00182 static nat size (const C* s1, nat s1_rs, nat s1_cs,
00183 const C* s2, nat s2_rs, nat s2_cs,
00184 nat r, nat l, nat c) {
00185 nat sz= 0;
00186 for (nat k= 0; k < l; k++) {
00187 nat sz1= 0, sz2= 0;
00188 const C* ss1= s1 + k * s1_cs;
00189 const C* ss2= s2 + k * s2_rs;
00190 for (nat i= 0; i < r; i++, ss1 += s1_rs) sz1= max (sz1, N (*ss1));
00191 for (nat j= 0; j < c; j++, ss2 += s2_cs) sz2= max (sz2, N (*ss2));
00192 sz= max (sz, (sz1 == 0 || sz2 == 0) ? 0 : (sz1 + sz2 - 1));
00193 }
00194 return sz; }
00195 };
00196
00197
00198
00199
00200
00201 template<typename V, typename W>
00202 struct implementation<matrix_multiply,V,matrix_tft_vectorized<W> >:
00203 public implementation<matrix_multiply_base,V>
00204 {
00205 typedef implementation<matrix_multiply,W> Mat;
00206
00207 template<typename C, typename F, typename Tfter>
00208 static void
00209 mat_direct_tft (F** dest, const C* s,
00210 nat s_rs, nat s_cs, nat r, nat c, Tfter& tfter) {
00211 nat n= tfter.len;
00212 for (nat k= 0; k < n; k++) {
00213 F* d= dest[k];
00214 for (nat j= 0; j < c; j++)
00215 for (nat i= 0; i < r; i++, d++)
00216 *d= s[i * s_rs + j * s_cs][k];
00217 }
00218 Tfter::R::get_size ()= r * c;
00219 tfter.dtft (dest, 1, n, 0); }
00220
00221 template<typename Op, typename C, typename F, typename Tfter>
00222 static void
00223 mat_inverse_tft (C* d, nat d_rs, nat d_cs, nat r, nat c,
00224 F** s, Tfter& tfter) {
00225 typedef typename Op::nomul_op Acc;
00226 typedef typename Vector_simd_variant(F) VF;
00227 typedef implementation<vector_linear,VF> Vec;
00228 nat n= tfter.len;
00229 nat m= next_power_of_two (n);
00230 Tfter::R::get_size ()= r * c;
00231 tfter.itft (s, 1, n, 0);
00232 F a= binpow (invert (F(2)), log_2 (m));
00233 for (nat k= 0; k < n; k++)
00234 Vec::mul (s[k], a, r * c);
00235 for (nat i= 0; i < r; i++)
00236 for (nat j= 0; j < c; j++) {
00237 vector<F> tmp (F (0), n);
00238 for (nat k= 0; k < n ; k++)
00239 tmp[k]= s[k][j * r + i];
00240 Acc::set_op (d[i * d_rs + j * d_cs], as<C> (tmp)); } }
00241
00242 template<typename Op, typename C, typename Tfter> static void
00243 mul (C* d, const C* s1, const C* s2,
00244 nat r, nat rr, nat l, nat ll, nat c, nat cc,
00245 Tfter& tfter) {
00246 typedef Scalar_type (C) F;
00247 typedef typename Vector_simd_variant(F) VF;
00248 typedef typename Matrix_variant(F) MVF;
00249 typedef implementation<matrix_multiply,MVF> MatF;
00250 nat n= tfter.len;
00251 nat m= next_power_of_two (n);
00252 nat hrl= r * l, hlc= l * c, hrc= r * c;
00253 nat lrl= aligned_size<F,VF> (hrl);
00254 nat llc= aligned_size<F,VF> (hlc);
00255 nat lrc= aligned_size<F,VF> (hrc);
00256 F** mm1= mmx_new<F*> (m);
00257 F** mm2= mmx_new<F*> (m);
00258 F** mmd= mmx_new<F*> (m);
00259 for (nat i= 0; i < m; i++) {
00260 mm1[i]= mmx_new<F> (lrl, F(0));
00261 mm2[i]= mmx_new<F> (llc, F(0));
00262 mmd[i]= mmx_new<F> (lrc, F(0));
00263 }
00264 mat_direct_tft (mm1, s1, Mat::index (1, 0, rr, ll), Mat::index (0, 1, rr, ll),
00265 r, l, tfter);
00266 mat_direct_tft (mm2, s2, Mat::index (1, 0, ll, cc), Mat::index (0, 1, ll, cc),
00267 l, c, tfter);
00268 for (nat k= 0; k < n; k++)
00269 MatF::template mul<Op> (mmd[k], mm1[k], mm2[k], r, r, l, l, c, c);
00270 mat_inverse_tft<Op> (d, Mat::index (1, 0, rr, cc), Mat::index (0, 1, rr, cc),
00271 r, c, mmd, tfter);
00272 for (nat i= 0; i < m; i++) {
00273 mmx_delete<F> (mm1[i], lrl);
00274 mmx_delete<F> (mm2[i], llc);
00275 mmx_delete<F> (mmd[i], lrc);
00276 }
00277 mmx_delete<F*> (mm1, m);
00278 mmx_delete<F*> (mm2, m);
00279 mmx_delete<F*> (mmd, m); }
00280
00281 template<typename Op, typename C> static void
00282 mul (C* d, const C* s1, const C* s2,
00283 nat r, nat rr, nat l, nat ll, nat c, nat cc) {
00284 typedef matrix_tft_vectorized_multiply_helper<C> Matrix_tft;
00285 typedef typename Matrix_tft::tft_transformer Tfter;
00286 typedef typename Matrix_tft::F F;
00287 typedef typename Matrix_variant(F) MVF;
00288 typedef implementation<matrix_multiply,MVF> MatF;
00289 nat sz= matrix_tft_multiply_helper<C>
00290 ::size (s1, Mat::index (1, 0, r, l), Mat::index (0, 1, r, l),
00291 s2, Mat::index (1, 0, l, c), Mat::index (0, 1, l, c),
00292 r, l, c);
00293 if (sz == 0)
00294 Mat::template clr<Op> (d, r, rr, c, cc);
00295 else if (sz == 1) {
00296 nat l1= aligned_size<F,MVF> (r * l);
00297 nat l2= aligned_size<F,MVF> (l * c);
00298 nat ld= aligned_size<F,MVF> (r * c);
00299 F* m1= mmx_new<F> (l1);
00300 F* m2= mmx_new<F> (l2);
00301 F* md= mmx_new<F> (ld);
00302 Mat::template mat_binary_scalar_stride<access_op>
00303 (m1, Mat::index (1, 0, r, l), Mat::index (0, 1, r, l),
00304 s1, Mat::index (1, 0, rr, ll), Mat::index (0, 1, rr, ll), 0, r, l);
00305 Mat::template mat_binary_scalar_stride<access_op>
00306 (m2, Mat::index (1, 0, l, c), Mat::index (0, 1, l, c),
00307 s2, Mat::index (1, 0, ll, cc), Mat::index (0, 1, ll, cc), 0, l, c);
00308 Mat::template mat_binary_scalar_stride<access_op>
00309 (md, Mat::index (1, 0, r, c), Mat::index (0, 1, r, c),
00310 d, Mat::index (1, 0, rr, cc), Mat::index (0, 1, rr, cc), 0, r, c);
00311 MatF::template mul<Op> (md, m1, m2, r, r, l, l, c, c);
00312 Mat::template mat_unary_stride<id_op>
00313 (d, Mat::index (1, 0, rr, cc), Mat::index (0, 1, rr, cc),
00314 md, Mat::index (1, 0, r, c), Mat::index (0, 1, r, c), r, c);
00315 mmx_delete<F> (m1, l1);
00316 mmx_delete<F> (m2, l2);
00317 mmx_delete<F> (md, ld);
00318 }
00319 else {
00320 Tfter::R::get_size ()= max (max (r * l, l * c), r * c);
00321 Tfter tfter (sz, format<C*> ());
00322 mul<Op> (d, s1, s2, r, rr, l, ll, c, cc, tfter); } }
00323
00324 template<typename C> static void
00325 mul (C* d, const C* s1, const C* s2, nat r, nat l, nat c) {
00326 mul<mul_op> (d, s1, s2, r, r, l, l, c, c); }
00327
00328 template<typename Op, typename D, typename S1, typename S2>
00329 static inline void
00330 mul (D* d, const S1* s1, const S2* s2,
00331 nat r, nat rr, nat l, nat ll, nat c, nat cc) {
00332 Mat::template mul<Op> (d, s1, s2, r, rr, l, ll, c, cc); }
00333
00334 template<typename D, typename S1, typename S2>
00335 static inline void
00336 mul (D* d, const S1* s1, const S2* s2,
00337 nat r, nat l, nat c) {
00338 Mat::template mul<mul_op> (d, s1, s2, r, r, l, l, c, c); }
00339
00340 };
00341
00342 }
00343 #endif // __MMX_MATRIX_TFT_VECTORIZED