00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX__FFT_SIMD__HPP
00014 #define __MMX__FFT_SIMD__HPP
00015 #include <numerix/simd.hpp>
00016 #include <numerix/sse.hpp>
00017 #include <algebramix/fft_naive.hpp>
00018
00019 namespace mmx {
00020
00021
00022 template<typename C,
00023 typename FFTER= fft_naive_transformer<C>,
00024 typename FFTER_SIMD= fft_naive_transformer<typename Simd_type(C)>,
00025 nat thr= 2>
00026 class fft_simd_transformer {
00027 public:
00028 typedef typename FFTER::R R;
00029 typedef typename R::U U;
00030 typedef typename R::S S;
00031
00032 FFTER* ffter;
00033 nat depth;
00034 nat len;
00035 U* roots;
00036
00037 public:
00038 inline fft_simd_transformer (nat n, const format<C>& fm):
00039 ffter (new FFTER (n, fm)),
00040 depth (ffter->depth), len (ffter->len), roots (ffter->roots) {}
00041
00042 inline ~fft_simd_transformer () { delete ffter; }
00043
00044 template<typename CC> inline void
00045 dfft (CC* c, nat stride, nat shift, nat steps, nat step1, nat step2) {
00046 ffter->dfft (c, stride, shift, steps, step1, step2); }
00047
00048 template<typename CC> inline void
00049 ifft (CC* c, nat stride, nat shift, nat steps, nat step1, nat step2) {
00050 ffter->ifft (c, stride, shift, steps, step1, step2); }
00051
00052 template<typename CC> inline void
00053 dfft (CC* c, nat stride, nat shift, nat steps) {
00054 dfft (c, stride, shift, steps, 0, steps); }
00055
00056 template<typename CC> inline void
00057 ifft (CC* c, nat stride, nat shift, nat steps) {
00058 ifft (c, stride, shift, steps, 0, steps); }
00059
00060 inline void
00061 direct_transform (C* c) {
00062 dfft (c, 1, 0, depth); }
00063
00064 inline void
00065 inverse_transform (C* c, bool divide=true) {
00066 typedef implementation<vector_linear,vector_naive> NVec;
00067 ifft (c, 1, 0, depth);
00068 if (divide) NVec::mul (c, invert (binpow (S (2), depth)), len); }
00069 };
00070
00071 #ifdef NUMERIX_ENABLE_SIMD
00072
00073
00074
00075
00076
00077 #ifdef __SSE2__
00078
00079 inline void
00080 simd_encode (complex<double>* c, nat n) {
00081 double temp;
00082 double* r= (double*) ((void*) c);
00083 for (; n!=0; r+=8, n-=4) {
00084 temp= r[1];
00085 r[1]= r[2];
00086 r[2]= temp;
00087 temp= r[5];
00088 r[5]= r[6];
00089 r[6]= temp;
00090 }
00091 }
00092
00093 inline void
00094 simd_decode (complex<double>* c, nat n) {
00095 double temp;
00096 double* r= (double*) ((void*) c);
00097 for (; n!=0; r+=8, n-=4) {
00098 temp= r[1];
00099 r[1]= r[2];
00100 r[2]= temp;
00101 temp= r[5];
00102 r[5]= r[6];
00103 r[6]= temp;
00104 }
00105 }
00106
STMPL
struct roots_helper<complex<sse_double> >:
  public roots_helper<complex<double> > {

  // Butterfly primitives for SSE2-packed complex numbers.  A
  // complex<sse_double> packs two scalar complex values as
  // [re0,re1][im0,im1]: the first sse_double holds both real parts and
  // the second both imaginary parts (this is the layout produced by
  // simd_encode above), so each call below performs two scalar
  // butterflies at once.  The twiddle factor 'u' stays a scalar
  // complex<double> and is broadcast to both lanes with
  // simd_load_duplicate.  All operands are accessed through
  // simd_load_aligned / simd_save_aligned, so c1, c2 must be suitably
  // aligned for sse_double.

  typedef complex<sse_double> C;
  typedef complex<double> U;
  typedef double S;

  // Twiddle-free butterfly: (c1, c2) <- (c1 + c2, c1 - c2).
  // Both inputs are fully loaded before any store, so c1 == c2
  // aliasing aside, the store order is irrelevant.
  static inline void
  fft_cross (C* c1, C* c2) {
    double* z1= (double*) ((void*) c1);
    double* z2= (double*) ((void*) c2);
    sse_double re_z1= simd_load_aligned (z1);
    sse_double im_z1= simd_load_aligned (z1+2);
    sse_double re_z2= simd_load_aligned (z2);
    sse_double im_z2= simd_load_aligned (z2+2);
    simd_save_aligned (z2 , re_z1 - re_z2);
    simd_save_aligned (z2+2, im_z1 - im_z2);
    simd_save_aligned (z1 , re_z1 + re_z2);
    simd_save_aligned (z1+2, im_z1 + im_z2);
  }

  // Direct (decimation) butterfly with twiddle factor:
  //   (c1, c2) <- (c1 + u*c2, c1 - u*c2).
  // re_u2/im_u2 are the real/imaginary parts of the complex product
  // u * c2, expanded by hand on the packed lanes.
  static inline void
  dfft_cross (C* c1, C* c2, const U* u) {
    double* z1= (double*) ((void*) c1);
    double* z2= (double*) ((void*) c2);
    double* u1= (double*) ((void*) u );
    sse_double re_z1= simd_load_aligned (z1);
    sse_double im_z1= simd_load_aligned (z1+2);
    sse_double re_z2= simd_load_aligned (z2);
    sse_double im_z2= simd_load_aligned (z2+2);
    sse_double re_u1= simd_load_duplicate (u1);     // broadcast Re(u)
    sse_double im_u1= simd_load_duplicate (u1+1);   // broadcast Im(u)
    sse_double re_u2= re_u1 * re_z2 - im_u1 * im_z2;
    sse_double im_u2= re_u1 * im_z2 + im_u1 * re_z2;
    simd_save_aligned (z2 , re_z1 - re_u2);
    simd_save_aligned (z2+2, im_z1 - im_u2);
    simd_save_aligned (z1 , re_z1 + re_u2);
    simd_save_aligned (z1+2, im_z1 + im_u2);
  }

  // Inverse butterfly with twiddle factor:
  //   (c1, c2) <- (c1 + c2, u * (c1 - c2)).
  // re_u2/im_u2 here hold the difference c1 - c2, which is then
  // multiplied by the broadcast twiddle factor on the way out.
  static inline void
  ifft_cross (C* c1, C* c2, const U* u) {
    double* z1= (double*) ((void*) c1);
    double* z2= (double*) ((void*) c2);
    double* u1= (double*) ((void*) u );
    sse_double re_z1= simd_load_aligned (z1);
    sse_double im_z1= simd_load_aligned (z1+2);
    sse_double re_z2= simd_load_aligned (z2);
    sse_double im_z2= simd_load_aligned (z2+2);
    sse_double re_u1= simd_load_duplicate (u1);     // broadcast Re(u)
    sse_double im_u1= simd_load_duplicate (u1+1);   // broadcast Im(u)
    sse_double re_u2= re_z1 - re_z2;
    sse_double im_u2= im_z1 - im_z2;
    simd_save_aligned (z2 , re_u1 * re_u2 - im_u1 * im_u2);
    simd_save_aligned (z2+2, re_u1 * im_u2 + im_u1 * re_u2);
    simd_save_aligned (z1 , re_z1 + re_z2);
    simd_save_aligned (z1+2, im_z1 + im_z2);
  }
};
00167
// Select the roots policy for SIMD-packed complex numbers: root tables
// for complex<sse_double> are cached (memoized) around the packed
// roots_helper defined above.
STMPL
struct std_roots<complex<sse_double> > {
  typedef complex<sse_double> C;
  typedef cached_roots_helper<roots_helper<C> > roots_type;
};
00173
00174
00175
00176
00177
00178 template<typename FFTER, typename FFTER_SIMD, nat thr>
00179 class fft_simd_transformer<complex<double>, FFTER, FFTER_SIMD, thr> {
00180 public:
00181 typedef complex<double> C;
00182 typedef typename FFTER::R R;
00183 typedef typename R::U U;
00184 typedef typename R::S S;
00185
00186 FFTER* ffter;
00187 nat depth;
00188 nat len;
00189 U* roots;
00190
00191 typedef complex<sse_double> C_SIMD;
00192 FFTER_SIMD* ffter_simd;
00193
00194 public:
00195 inline fft_simd_transformer (nat n, const format<complex<double> >& fm):
00196 ffter (new FFTER (n, fm)),
00197 depth (ffter->depth), len (ffter->len), roots (ffter->roots),
00198 ffter_simd (new FFTER_SIMD (n, format<C_SIMD> ())) {}
00199
00200 inline ~fft_simd_transformer () { delete ffter; delete ffter_simd; }
00201
00202 inline void
00203 dfft (C* c, nat stride, nat shift, nat steps, nat step1, nat step2) {
00204 if (steps <= thr || stride != 1)
00205 ffter->dfft (c, stride, shift, steps, step1, step2);
00206 else {
00207 if (steps == step2) {
00208 simd_encode (c, 1 << steps);
00209 ffter_simd->dfft ((C_SIMD*) ((void*) c), 1, shift>>1,
00210 steps-1, step1, steps-1);
00211 simd_decode (c, 1 << steps);
00212 ffter->dfft (c, 1, shift, steps, steps-1, steps);
00213 }
00214 else {
00215 simd_encode (c, 1 << steps);
00216 ffter_simd->dfft ((C_SIMD*) ((void*) c), 1,
00217 shift>>1, steps-1, step1, step2);
00218 simd_decode (c, 1 << steps);
00219 }
00220 }
00221 }
00222
00223 inline void
00224 ifft (C* c, nat stride, nat shift, nat steps, nat step1, nat step2) {
00225 if (steps <= thr || stride != 1)
00226 ffter->ifft (c, stride, shift, steps, step1, step2);
00227 else {
00228 if (steps == step2) {
00229 ffter->ifft (c, 1, shift, steps, steps-1, steps);
00230 simd_encode (c, 1 << steps);
00231 ffter_simd->ifft ((C_SIMD*) ((void*) c), 1, shift>>1,
00232 steps-1, step1, steps-1);
00233 simd_decode (c, 1 << steps);
00234 }
00235 else {
00236 simd_encode (c, 1 << steps);
00237 ffter_simd->ifft ((C_SIMD*) ((void*) c), 1,
00238 shift>>1, steps-1, step1, step2);
00239 simd_decode (c, 1 << steps);
00240 }
00241 }
00242 }
00243
00244 inline void
00245 dfft (C* c, nat stride, nat shift, nat steps) {
00246 dfft (c, stride, shift, steps, 0, steps); }
00247
00248 inline void
00249 ifft (C* c, nat stride, nat shift, nat steps) {
00250 ifft (c, stride, shift, steps, 0, steps); }
00251
00252 inline void
00253 direct_transform (C* c) {
00254 dfft (c, 1, 0, ffter->depth); }
00255
00256 inline void
00257 inverse_transform (C* c, bool divide=true) {
00258 typedef implementation<vector_linear,vector_naive> NVec;
00259 ifft (c, 1, 0, depth);
00260 if (divide) NVec::mul (c, invert (binpow (S (2), depth)), len); }
00261 };
00262
00263 #endif //__SSE2__
00264 #endif // NUMERIX_ENABLE_SIMD
00265 }
00266 #endif //__MMX__FFT_SIMD__HPP