00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX__FFT_NAIVE__HPP
00014 #define __MMX__FFT_NAIVE__HPP
00015 #include <algebramix/fft_roots.hpp>
00016
00017 namespace mmx {
00018
00019
00020
00021
00022
00023 template<typename CC, typename UU, typename SS>
00024 struct roots_helper {
00025 typedef CC C;
00026 typedef UU U;
00027 typedef SS S;
00028
00029 static U*
00030 create_roots (nat n, const format<C>& fm) {
00031 nat k= primitive_root_max_order<C> (2); (void) k;
00032 VERIFY (k == 0 || n <= k, "maximum order exceeded");
00033 VERIFY (n >= 2, "size must be at least two");
00034 U* roots= mmx_new<U> (n);
00035 for (nat i=0; i<n; i+=2) {
00036 U temp = primitive_root<C> (n, bit_mirror (i, n), fm);
00037 roots[i] = temp;
00038 roots[i+1]= primitive_root<C> (n, i==0? 0: n - bit_mirror (i, n), fm);
00039 }
00040 return roots; }
00041
00042 static void
00043 destroy_roots (U* u, nat n) {
00044 mmx_delete<U> (u, n); }
00045
00046 static inline void
00047 fft_cross (C* c1, C* c2) {
00048 C temp= (*c2);
00049 *c2 = (*c1) - temp;
00050 *c1 = (*c1) + temp; }
00051
00052 static inline void
00053 dfft_cross (C* c1, C* c2, const U* u) {
00054 C temp= (*u) * (*c2);
00055 *c2 = (*c1) - temp;
00056 *c1 = (*c1) + temp; }
00057
00058 static inline void
00059 ifft_cross (C* c1, C* c2, const U* u) {
00060 C temp= *c2;
00061 *c2 = (*u ) * ((*c1) - temp);
00062 *c1 = (*c1) + temp; }
00063
00064 static inline void
00065 dtft_cross (C* c1, C* c2) {
00066 static S h= invert (S (2));
00067 fft_cross (c1, c2);
00068 *c1 *= h;
00069 *c2 *= h; }
00070
00071 static inline void
00072 dtft_cross (C* c1, C* c2, const U* u) {
00073 static S h= invert (S (2));
00074 dfft_cross (c1, c2, u);
00075 *c1 *= h;
00076 *c2 *= h; }
00077
00078 static inline void
00079 itft_flip (C* c1, C* c2, const U* u) {
00080 static S h= invert (S(2));
00081 C temp= (*u) * (*c2);
00082 *c1 += (*c1) - temp;
00083 *c2 = h * ((*c1) - temp); }
00084
00085 static inline void
00086 itft_flip (C* c1, C* c2) {
00087 static S h= invert (S(2));
00088 *c1 += (*c1) - (*c2);
00089 *c2 = h * ((*c1) - (*c2)); }
00090
00091 struct fft_mul_sc_op : mul_op {};
00092 };
00093
00094
00095
00096
00097
00098 template<typename C, typename V= std_roots<C> >
00099 class fft_naive_transformer {
00100 public:
00101 typedef implementation<vector_linear,vector_naive> NVec;
00102 typedef typename V::roots_type R;
00103 typedef typename R::U U;
00104 typedef typename R::S S;
00105
00106 format<C> fm;
00107 nat depth;
00108 nat len;
00109 U* roots;
00110
00111 public:
00112 inline fft_naive_transformer (nat n, const format<C>& fm2):
00113 fm (fm2),
00114 depth (log_2 (n)), len (n), roots (R::create_roots (n, fm)) {
00115 VERIFY (n == ((nat) 1 << depth), "power of two expected"); }
00116
00117 inline ~fft_naive_transformer () {
00118 R::destroy_roots (roots, len); }
00119
00120 inline void
00121 dfft (C* c, nat stride, nat shift, nat steps, nat step1, nat step2) {
00122
00123
00124
00125 for (nat step= step1; step < step2; step++) {
00126
00127 if (step == 0 && shift == 0) {
00128 nat todo= steps - 1;
00129 C* cc= c;
00130 for (nat k= 0; k < ((nat) 1<<todo); k++) {
00131 R::fft_cross (cc, cc + (stride<<todo));
00132 cc += stride;
00133 }
00134 }
00135 else {
00136 nat todo= steps - 1 - step;
00137 C* cc= c;
00138 U * uu= roots + ((shift >> todo) << 1);
00139 for (nat j= 0; j < ((nat) 1<<step); j++) {
00140 for (nat k= 0; k < ((nat) 1<<todo); k++) {
00141 R::dfft_cross (cc, cc + (stride<<todo), uu);
00142 cc += stride;
00143 }
00144 cc += (stride<<todo);
00145 uu += 2;
00146 }
00147 }
00148 }
00149 }
00150
00151 inline void
00152 ifft (C* c, nat stride, nat shift, nat steps, nat step1, nat step2) {
00153
00154
00155
00156 for (int step= step2-1; (int) step >= ((int) step1); step--) {
00157
00158 if (step == 0 && shift == 0) {
00159 nat todo= steps - 1;
00160 C* cc= c;
00161 for (nat k= 0; k < ((nat) 1<<todo); k++) {
00162 R::fft_cross (cc, cc + (stride<<todo));
00163 cc += stride;
00164 }
00165 }
00166 else {
00167 nat todo= steps - 1 - step;
00168 C* cc= c;
00169 U * uu= roots + 1 + ((shift >> todo) << 1);
00170 for (nat j= 0; j < ((nat) 1<<step); j++) {
00171 for (nat k= 0; k < ((nat) 1<<todo); k++) {
00172 R::ifft_cross (cc, cc + (stride<<todo), uu);
00173 cc += stride;
00174 }
00175 cc += (stride<<todo);
00176 uu += 2;
00177 }
00178 }
00179 }
00180 }
00181
00182 inline void
00183 dfft (C* c, nat stride, nat shift, nat steps) {
00184 dfft (c, stride, shift, steps, 0, steps); }
00185
00186 inline void
00187 ifft (C* c, nat stride, nat shift, nat steps) {
00188 ifft (c, stride, shift, steps, 0, steps); }
00189
00190 inline void
00191 direct_transform (C* c) {
00192 dfft (c, 1, 0, depth); }
00193
00194 inline void
00195 inverse_transform (C* c, bool divide=true) {
00196 ifft (c, 1, 0, depth);
00197 if (divide) {
00198 S x= binpow (S (2), depth);
00199 x= invert (x);
00200 NVec::template vec_unary_scalar<typename R::fft_mul_sc_op> (c, x, len);
00201 }
00202 }
00203 };
00204
00205
00206
00207
00208
00209 template<typename C> inline void
00210 direct_fft (C* dest, nat n) {
00211 if (n == 0) return;
00212 fft_naive_transformer<C> ffter (n, get_format (dest[0]));
00213 ffter.direct_transform (dest);
00214 }
00215
00216 template<typename C> inline void
00217 inverse_fft (C* dest, nat n) {
00218 if (n == 0) return;
00219 fft_naive_transformer<C> ffter (n, get_format (dest[0]));
00220 ffter.inverse_transform (dest);
00221 }
00222
00223 }
00224 #endif //__MMX__FFT_NAIVE__HPP