00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX__FFT_BLOCKS__HPP
00014 #define __MMX__FFT_BLOCKS__HPP
00015 #include <basix/vector.hpp>
00016 #include <basix/threads.hpp>
00017 #include <algebramix/fft_naive.hpp>
00018
00019 namespace mmx {
00020
00021
00022
00023
00024
/* Blocked FFT transformer wrapping an inner transformer FFTER
 * (fft_naive_transformer<C> by default).  Large transforms are split into
 * passes: strided slices of the coefficient array are gathered into a
 * small contiguous scratch buffer (delocate), transformed there with
 * stride 1 by the inner transformer, and scattered back (relocate), so
 * that the inner transforms always operate on contiguous data.
 *
 * Template parameters (sizes are given as log2):
 *   log2_outer_block_size  -- at most this many butterfly levels are
 *                             handled per blocked pass (slice length
 *                             2^log2_outer_block_size);
 *   log2_block_number      -- number of slices gathered into the scratch
 *                             buffer at a time;
 *   log2_inner_block_size  -- the final levels, once each sub-transform
 *                             has depth <= this, are delegated directly
 *                             to FFTER on the original array;
 *   threshold              -- transforms with
 *                             steps <= max (threshold, log2_inner_block_size)
 *                             are delegated wholesale to FFTER.
 */
template<typename C, typename FFTER= fft_naive_transformer<C>,
         nat log2_outer_block_size= 9,
         nat log2_block_number= 7,

         nat log2_inner_block_size= 11,
         nat threshold= 12>
class fft_blocks_transformer {
public:
  // Naive vector-operation implementation, used only for the final
  // scalar multiplication in inverse_transform.
  typedef implementation<vector_linear,vector_naive> NVec;
  typedef typename FFTER::R R;  // root/ring helper type of the inner transformer
  typedef typename R::U U;      // type of the precomputed roots table
  typedef typename R::S S;      // scalar type (used for the 1 / 2^depth factor)

  FFTER* ffter;  // owned inner transformer (new'ed in ctor, deleted in dtor)
  nat depth;     // copied from ffter->depth (log2 of transform length)
  nat len;       // copied from ffter->len (transform length)
  U* roots;      // borrowed from ffter->roots; freed by ffter, not by us

public:
  // Build the inner transformer for length n over format fm and mirror
  // its depth / len / roots fields for direct access.
  inline fft_blocks_transformer (nat n, const format<C>& fm):
    ffter (new FFTER (n, fm)),
    depth (ffter->depth), len (ffter->len), roots (ffter->roots) {}

  // NOTE(review): ffter is a raw owning pointer and no copy constructor
  // or assignment operator is declared, so copying an instance would
  // double-delete ffter -- confirm instances are never copied.
  inline ~fft_blocks_transformer () { delete ffter; }

  // Gather n strided slices of c0 into the contiguous buffer c1:
  //   c1[i + k*s1] = c0[i*s0 + k]   for i < s1, k < n,
  // i.e. chunk k of c1 (length s1, contiguous) receives the stride-s0
  // slice of c0 that starts at offset k.
  static inline void
  delocate (C* c1, nat s1, C* c0, nat s0, nat n) {
    for (nat i= 0; i < s1; i++, c0 += s0) {
      C* cc0= c0, * cc1= c1 + i;
      for (nat k= 0; k < n; k++, cc0++, cc1+= s1)
        *cc1= *cc0; } }

  // Exact inverse copy of delocate: scatter the contiguous chunks of c1
  // back into the strided slices of c0 (c0[i*s0 + k] = c1[i + k*s1]).
  static inline void
  relocate (C* c1, nat s1, C* c0, nat s0, nat n) {
    for (nat i= 0; i < s1; i++, c0 += s0) {
      C* cc0= c0, * cc1= c1 + i;
      for (nat k= 0; k < n; k++, cc0++, cc1+= s1)
        *cc0= *cc1; } }

  // Forward transform of 2^steps coefficients at the given stride,
  // performing only the butterfly levels [step1, step2); shift is
  // forwarded to the inner transformer (its semantics follow FFTER --
  // presumably a twiddle offset for partial transforms; not re-derived
  // here).  Small transforms go straight to FFTER; larger ones run
  // blocked passes of up to log2_outer_block_size levels each, then
  // finish the last (< log2_inner_block_size) levels directly.
  inline void
  dfft (C* c, nat stride, nat shift, nat steps, nat step1, nat step2) {
    if (steps <= max (threshold, log2_inner_block_size))
      ffter->dfft (c, stride, shift, steps, step1, step2);
    else {
      // Scratch buffer: 2^log2_block_number contiguous slices of
      // 2^log2_outer_block_size elements each.
      nat temp_len= default_aligned_size<C>
        ((nat) 1<<(log2_outer_block_size+log2_block_number));
      C* c0, * temp= mmx_new<C> (temp_len);
      nat step, block, blocks, todo, k_beg, k_end;

      // Blocked passes: consume `block` levels at a time until we reach
      // either step2 or the final log2_inner_block_size levels.
      for (step= step1;
           step < min (steps - log2_inner_block_size, step2);
           step += block) {
        block = min (min (steps - log2_inner_block_size - step, step2 - step),
                     log2_outer_block_size);
        todo  = steps - 1 - step;  // levels left after this one, minus one
        blocks= todo - block + 1;  // log2 of the element stride inside a slice
        c0= c;
        // Visit each of the 2^step sub-transforms produced so far.
        for (nat j= 0; j < ((nat) 1<<step); j++) {
          // Handle the 2^blocks slice origins in chunks of 2^log2_block_number.
          for (k_beg= 0; k_beg < ((nat) 1<<blocks); k_beg= k_end) {
            k_end= min (((nat) 1<<blocks),
                        k_beg + ((nat) 1<< log2_block_number));
            delocate (temp, (nat) 1<<block,
                      c0 + k_beg, stride<<blocks, k_end - k_beg);
            // Each gathered slice is now contiguous: depth-`block`
            // transform with stride 1.
            for (nat k= 0; k < k_end - k_beg; k++)
              ffter->dfft (temp+(k<<block), 1,
                           ((shift>>todo) + j)<<(block-1), block);
            relocate (temp, (nat) 1<<block,
                      c0+k_beg, stride<<blocks, k_end - k_beg);
          }
          c0 += (stride<<(todo+1));  // advance to the next sub-transform
        }
      }
      mmx_delete<C> (temp, temp_len);
      // Tail: remaining levels fit in log2_inner_block_size, so run them
      // directly on the (strided) array, one sub-transform at a time.
      if (step < step2) {
        c0= c;
        todo= steps - 1 - step;
        for (nat j= 0; j < ((nat) 1<<step); j++) {
          ffter->dfft (c0, stride, ((shift>>todo) + j) << todo,
                       todo+1, 0, step2-step);
          c0 += (stride<<(todo+1));
        }
      }
    }
  }

  // Inverse counterpart of dfft: performs levels [step1, step2) in the
  // reverse order -- the innermost (deepest) levels first, directly on
  // the array, then blocked passes walking `step` downward to step1.
  // Does NOT divide by the transform length (see inverse_transform).
  inline void
  ifft (C* c, nat stride, nat shift, nat steps, nat step1, nat step2) {
    if (steps <= max (threshold, log2_inner_block_size))
      ffter->ifft (c, stride, shift, steps, step1, step2);
    else {
      // Same scratch layout as in dfft.
      nat temp_len= default_aligned_size<C>
        ((nat) 1<<(log2_outer_block_size+log2_block_number));
      C* c0, * temp= mmx_new<C> (temp_len);
      nat step, block, blocks, todo, k_beg, k_end;

      // If [step1, step2) reaches into the last log2_inner_block_size
      // levels, undo those first, directly on the strided data.
      if (log2_inner_block_size > steps - step2) {
        step= steps - log2_inner_block_size;
        c0= c;
        todo= steps - 1 - step;
        for (nat j= 0; j < ((nat) 1<<step); j++) {
          // max of signed ints: step1 may be smaller than step.
          ffter->ifft (c0, stride, ((shift>>todo) + j) <<todo,
                       todo+1, max(0, (int) step1- (int) step), step2-step);
          c0 += (stride<<(todo+1));
        }
      }
      else
        step= step2;

      // Blocked passes, mirroring dfft but with `step` walking down.
      while (step > step1) {
        block = min (step - step1, log2_outer_block_size);
        step -= block;
        todo  = steps - 1 - step;  // levels after this one, minus one
        blocks= todo - block + 1;  // log2 of the element stride inside a slice
        c0= c;
        for (nat j= 0; j < ((nat) 1<<step); j++) {
          for (k_beg= 0; k_beg < ((nat) 1<<blocks); k_beg= k_end) {
            k_end= min (((nat) 1<<blocks),
                        k_beg + ((nat) 1<< log2_block_number));
            delocate (temp, (nat) 1<<block,
                      c0 + k_beg, stride<<blocks, k_end - k_beg);
            // Inverse transform of each contiguous slice.
            for (nat k= 0; k < k_end - k_beg; k++)
              ffter->ifft (temp+(k<<block), 1,
                           ((shift>>todo) + j)<<(block-1), block);
            relocate (temp, (nat) 1<<block,
                      c0+k_beg, stride<<blocks, k_end - k_beg);
          }
          c0 += (stride<<(todo+1));
        }
      }
      mmx_delete<C> (temp, temp_len);
    }
  }

  // Convenience overload: forward transform over all levels [0, steps).
  inline void
  dfft (C* c, nat stride, nat shift, nat steps) {
    dfft (c, stride, shift, steps, 0, steps); }

  // Convenience overload: inverse transform over all levels [0, steps).
  inline void
  ifft (C* c, nat stride, nat shift, nat steps) {
    ifft (c, stride, shift, steps, 0, steps); }

  // Full forward transform of the len = 2^depth contiguous coefficients.
  inline void
  direct_transform (C* c) {
    dfft (c, 1, 0, depth); }

  // Full inverse transform; when divide is true, also multiplies every
  // coefficient by 1 / 2^depth so that inverse_transform composed with
  // direct_transform is the identity.
  inline void
  inverse_transform (C* c, bool divide=true) {
    ifft (c, 1, 0, depth);
    if (divide) {
      S x= binpow (S (2), depth);
      x= invert (x);
      NVec::template vec_unary_scalar<typename R::fft_mul_sc_op> (c, x, len);
    }
  }
};
00180
00181 }
00182 #endif //__MMX__FFT_BLOCKS__HPP