00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX__FFT_NAIVE__HPP
00014 #define __MMX__FFT_NAIVE__HPP
00015 #include <algebramix/fft_roots.hpp>
00016
00017 namespace mmx {
00018
00019
00020
00021
00022
00023 template<typename CC, typename UU, typename SS>
00024 struct roots_helper {
00025 typedef CC C;
00026 typedef UU U;
00027 typedef SS S;
00028
00029 static U*
00030 create_roots (nat n, const format<U>& fm) {
00031 nat k= primitive_root_max_order<U> (2); (void) k;
00032 VERIFY (k == 0 || n <= k, "maximum order exceeded");
00033 VERIFY (n >= 2, "size must be at least two");
00034 U* roots= mmx_new<U> (n);
00035 for (nat i=0; i<n; i+=2) {
00036 roots[i] = primitive_root<U> (n, bit_mirror (i, n), fm);
00037 roots[i+1]= primitive_root<U> (n, i==0? 0: n - bit_mirror (i, n), fm);
00038 }
00039 return roots; }
00040
00041 static void
00042 destroy_roots (U* u, nat n) {
00043 mmx_delete<U> (u, n); }
00044
00045 static inline void
00046 fft_cross (C* c1, C* c2) {
00047 C temp= (*c2);
00048 *c2 = (*c1) - temp;
00049 *c1 = (*c1) + temp; }
00050
00051 static inline void
00052 dfft_cross (C* c1, C* c2, const U* u) {
00053 C temp= (*u) * (*c2);
00054 *c2 = (*c1) - temp;
00055 *c1 = (*c1) + temp; }
00056
00057 static inline void
00058 ifft_cross (C* c1, C* c2, const U* u) {
00059 C temp= *c2;
00060 *c2 = (*u ) * ((*c1) - temp);
00061 *c1 = (*c1) + temp; }
00062
00063 static inline void
00064 dtft_cross (C* c1, C* c2) {
00065 static S h= invert (S (2));
00066 fft_cross (c1, c2);
00067 *c1 *= h;
00068 *c2 *= h; }
00069
00070 static inline void
00071 dtft_cross (C* c1, C* c2, const U* u) {
00072 static S h= invert (S (2));
00073 dfft_cross (c1, c2, u);
00074 *c1 *= h;
00075 *c2 *= h; }
00076
00077 static inline void
00078 itft_flip (C* c1, C* c2, const U* u) {
00079 static S h= invert (S(2));
00080 C temp= (*u) * (*c2);
00081 *c1 += (*c1) - temp;
00082 *c2 = h * ((*c1) - temp); }
00083
00084 static inline void
00085 itft_flip (C* c1, C* c2) {
00086 static S h= invert (S(2));
00087 *c1 += (*c1) - (*c2);
00088 *c2 = h * ((*c1) - (*c2)); }
00089
00090 struct fft_mul_sc_op : mul_op {};
00091 };
00092
00093
00094
00095
00096
00097 template<typename C, typename V= std_roots<C> >
00098 class fft_naive_transformer {
00099 public:
00100 typedef implementation<vector_linear,vector_naive> NVec;
00101 typedef typename V::roots_type R;
00102 typedef typename R::U U;
00103 typedef typename R::S S;
00104
00105 format<C> fm;
00106 nat depth;
00107 nat len;
00108 U* roots;
00109
00110 public:
00111 inline fft_naive_transformer (nat n, const format<C>& fm2):
00112 fm (fm2), depth (log_2 (n)), len (n),
00113 roots (R::create_roots (n, format<U> (no_format ()))) {
00114 VERIFY (n == ((nat) 1 << depth), "power of two expected"); }
00115
00116 inline ~fft_naive_transformer () {
00117 R::destroy_roots (roots, len); }
00118
00119 inline void
00120 dfft (C* c, nat stride, nat shift, nat steps, nat step1, nat step2) {
00121
00122
00123
00124 for (nat step= step1; step < step2; step++) {
00125
00126 if (step == 0 && shift == 0) {
00127 nat todo= steps - 1;
00128 C* cc= c;
00129 for (nat k= 0; k < ((nat) 1<<todo); k++) {
00130 R::fft_cross (cc, cc + (stride<<todo));
00131 cc += stride;
00132 }
00133 }
00134 else {
00135 nat todo= steps - 1 - step;
00136 C* cc= c;
00137 U * uu= roots + ((shift >> todo) << 1);
00138 for (nat j= 0; j < ((nat) 1<<step); j++) {
00139 for (nat k= 0; k < ((nat) 1<<todo); k++) {
00140 R::dfft_cross (cc, cc + (stride<<todo), uu);
00141 cc += stride;
00142 }
00143 cc += (stride<<todo);
00144 uu += 2;
00145 }
00146 }
00147 }
00148 }
00149
00150 inline void
00151 ifft (C* c, nat stride, nat shift, nat steps, nat step1, nat step2) {
00152
00153
00154
00155 for (int step= step2-1; (int) step >= ((int) step1); step--) {
00156
00157 if (step == 0 && shift == 0) {
00158 nat todo= steps - 1;
00159 C* cc= c;
00160 for (nat k= 0; k < ((nat) 1<<todo); k++) {
00161 R::fft_cross (cc, cc + (stride<<todo));
00162 cc += stride;
00163 }
00164 }
00165 else {
00166 nat todo= steps - 1 - step;
00167 C* cc= c;
00168 U * uu= roots + 1 + ((shift >> todo) << 1);
00169 for (nat j= 0; j < ((nat) 1<<step); j++) {
00170 for (nat k= 0; k < ((nat) 1<<todo); k++) {
00171 R::ifft_cross (cc, cc + (stride<<todo), uu);
00172 cc += stride;
00173 }
00174 cc += (stride<<todo);
00175 uu += 2;
00176 }
00177 }
00178 }
00179 }
00180
00181 inline void
00182 dfft (C* c, nat stride, nat shift, nat steps) {
00183 dfft (c, stride, shift, steps, 0, steps); }
00184
00185 inline void
00186 ifft (C* c, nat stride, nat shift, nat steps) {
00187 ifft (c, stride, shift, steps, 0, steps); }
00188
00189 inline void
00190 direct_transform (C* c) {
00191 dfft (c, 1, 0, depth); }
00192
00193 inline void
00194 inverse_transform (C* c, bool divide=true) {
00195 ifft (c, 1, 0, depth);
00196 if (divide) {
00197 S x= binpow (S (2), depth);
00198 x= invert (x);
00199 NVec::template vec_unary_scalar<typename R::fft_mul_sc_op> (c, x, len);
00200 }
00201 }
00202 };
00203
00204
00205
00206
00207
00208 template<typename C> inline void
00209 direct_fft (C* dest, nat n) {
00210 if (n == 0) return;
00211 fft_naive_transformer<C> ffter (n, get_format (dest[0]));
00212 ffter.direct_transform (dest);
00213 }
00214
00215 template<typename C> inline void
00216 inverse_fft (C* dest, nat n) {
00217 if (n == 0) return;
00218 fft_naive_transformer<C> ffter (n, get_format (dest[0]));
00219 ffter.inverse_transform (dest);
00220 }
00221
00222 }
00223 #endif //__MMX__FFT_NAIVE__HPP