00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX__FFT_TRUNCATED__HPP
00014 #define __MMX__FFT_TRUNCATED__HPP
00015 #include <basix/vector_naive.hpp>
00016 #include <algebramix/fft_naive.hpp>
00017
00018 namespace mmx {
00019
00020 #define TMPL template<class C>
00021
00022
00023
00024
00025
00026 template<typename C, typename Ffter= fft_naive_transformer<C> >
00027 class fft_truncated_transformer {
00028 public:
00029 typedef implementation<vector_linear,vector_naive> NVec;
00030 typedef typename Ffter::R R;
00031 typedef typename Ffter::U U;
00032 typedef typename Ffter::S S;
00033 nat len;
00034 Ffter* ffter;
00035
00036 public:
00037 inline fft_truncated_transformer (nat s, const format<C>& fm): len (s) {
00038 nat n= next_power_of_two (s);
00039 VERIFY(s <= n, "maximum size exceeded");
00040 ffter= new Ffter (n, fm); }
00041
00042 inline ~fft_truncated_transformer () { delete ffter; }
00043
00044
00045 inline void
00046 fft_cross_range (C* lp, C* rp, nat stride, nat shift, nat nr) {
00047 U* a= ffter->roots + (shift << 1);
00048 if (shift == 0)
00049 for (nat i=nr; i!=0; i--) {
00050 R::fft_cross (lp, rp);
00051 lp += stride; rp +=stride;
00052 }
00053 else
00054 for (nat i=nr; i!=0; i--) {
00055 R::dfft_cross (lp, rp, a);
00056 lp += stride; rp += stride;
00057 }
00058 }
00059
00060 inline void
00061 dtft (C* c, nat stride, nat s, nat shift, nat steps) {
00062
00063
00064 if (s == 0 || steps == 0) return;
00065 nat todo= steps - 1;
00066 nat w = (nat) 1 << todo;
00067 fft_cross_range (c, c + (stride<<todo), stride, shift>>todo, w);
00068 steps--;
00069 if (s >= w) {
00070 ffter->dfft (c, stride, shift, steps);
00071 s -= w; shift += w >> 1; c += stride<<todo;
00072 }
00073 dtft (c, stride, s, shift, steps);
00074 }
00075
00076
00077 inline void
00078 itft_flip_range (C* lp, C* rp, nat stride, nat shift, nat nr) {
00079 U* a= ffter->roots + (shift << 1);
00080 if (shift == 0)
00081 for (nat i=nr; i!=0; i--) {
00082 R::itft_flip (lp, rp);
00083 lp += stride; rp += stride;
00084 }
00085 else
00086 for (nat i=nr; i!=0; i--) {
00087 R::itft_flip (lp, rp, a);
00088 lp += stride; rp += stride;
00089 }
00090 }
00091
00092 inline void
00093 tft_cross_range (C* lp, C* rp, nat stride, nat shift, nat nr) {
00094 U* a= ffter->roots + (shift << 1);
00095 if (shift == 0)
00096 for (nat i=nr; i!=0; i--) {
00097 R::dtft_cross (lp, rp);
00098 lp += stride; rp += stride;
00099 }
00100 else
00101 for (nat i=nr; i!=0; i--) {
00102 R::dtft_cross (lp, rp, a);
00103 lp += stride; rp += stride;
00104 }
00105 }
00106
00107 inline void
00108 ifft_cross_range (C* lp, C* rp, nat stride, nat shift, nat nr) {
00109 U* a= ffter->roots + (shift << 1);
00110 if (shift == 0)
00111 for (nat i=nr; i!=0; i--) {
00112 R::fft_cross (lp, rp);
00113 lp += stride; rp += stride;
00114 }
00115 else
00116 for (nat i=nr; i!=0; i--) {
00117 R::ifft_cross (lp, rp, a+1);
00118 lp += stride; rp += stride;
00119 }
00120 }
00121
00122 inline void
00123 itft (C* c, nat stride, nat s, nat shift, nat steps) {
00124
00125
00126 if (s == 0 || steps == 0) return;
00127 nat todo= steps - 1;
00128 nat l = (nat) 1 << steps;
00129 nat w = (nat) 1 << todo;
00130 if (s < w) {
00131 tft_cross_range (c + stride*s, c + stride*(s+w),
00132 stride, shift>>todo, w-s);
00133 itft (c, stride, s, shift, steps-1);
00134 itft_flip_range (c, c + (stride<<todo), stride, shift>>todo, s);
00135 }
00136 else {
00137 ffter->ifft (c, stride, shift, steps-1);
00138 itft_flip_range (c + stride*(s-w), c + stride*s,
00139 stride, shift>>todo, l-s);
00140 if (s == w) return;
00141 itft (c + (stride<<todo), stride, s-w, shift + (w>>1), steps-1);
00142 ifft_cross_range (c, c + (stride<<todo), stride, shift>>todo, s-w);
00143 }
00144 }
00145
00146 inline void
00147 dtft (C* c, nat stride, nat s, nat shift) {
00148 dtft (c, stride, s, shift, ffter->depth); }
00149
00150 inline void
00151 itft (C* c, nat stride, nat s, nat shift) {
00152 itft (c, stride, s, shift, ffter->depth); }
00153
00154 inline void
00155 direct_transform (C* c) {
00156 nat n= ((nat) 1) << ffter->depth;
00157 if (len == n)
00158 ffter->direct_transform (c);
00159 else {
00160 NVec::clear (c + len, n - len);
00161 dtft (c, 1, len, 0);
00162 NVec::clear (c + len, n - len);
00163 }
00164 }
00165
00166 inline void
00167 inverse_transform (C* c) {
00168 nat n= ((nat) 1) << ffter->depth;
00169 if (len == n)
00170 ffter->inverse_transform (c);
00171 else {
00172 NVec::clear (c + len, n - len);
00173 itft (c, 1, len, 0);
00174 S x= invert (binpow (S (2), ffter->depth));
00175 NVec::template vec_unary_scalar<typename R::fft_mul_sc_op> (c, x, len);
00176 NVec::clear (c + len, n - len);
00177 }
00178 }
00179 };
00180
00181 #undef TMPL
00182 }
00183 #endif //__MMX__FFT_TRUNCATED__HPP