00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX__FFT_THREADS__HPP
00014 #define __MMX__FFT_THREADS__HPP
00015 #include <basix/threads.hpp>
00016 #include <algebramix/vector_threads.hpp>
00017 #include <algebramix/fft_blocks.hpp>
00018
00019 namespace mmx {
00020
00021
00022
00023
00024
00025 template<typename C,
00026 typename FFTER= fft_blocks_transformer<C,
00027 fft_naive_transformer<C>,
00028 8, 6, 10>,
00029 nat thr= 14>
00030 class fft_threads_transformer {
00031 public:
00032 typedef implementation<vector_linear,vector_threads<vector_naive> > Vec;
00033 typedef typename FFTER::R R;
00034 typedef typename R::U U;
00035 typedef typename R::S S;
00036 static const nat min_reps= 16;
00037
00038 FFTER* ffter;
00039 nat depth;
00040 nat len;
00041 U* roots;
00042
00043 public:
00044 inline fft_threads_transformer (nat n, const format<C>& fm):
00045 ffter (new FFTER (n, fm)),
00046 depth (ffter->depth), len (ffter->len), roots (ffter->roots) {}
00047 inline ~fft_threads_transformer () { delete ffter; }
00048
00049 static inline void
00050 copy (C* d, nat drs, nat dcs, C* s, nat srs, nat scs, nat r, nat c) {
00051 for (nat j=0; j<c; j++, d+=dcs, s+=scs) {
00052 C* dd= d; C* ss= s;
00053 for (nat i=0; i<r; i++, dd+=drs, ss+=srs)
00054 *dd= *ss;
00055 }
00056 }
00057
00058 struct outer_fft_task_rep: public task_rep {
00059 FFTER* ffter;
00060 bool direct;
00061 C *buf;
00062 nat thr_reps, tot_reps;
00063 C *c;
00064 nat stride, shift, steps;
00065 public:
00066 inline outer_fft_task_rep
00067 (FFTER* ffter2, bool direct2,
00068 C* buf2, nat thr_reps2, nat tot_reps2,
00069 C* c2, nat stride2, nat shift2, nat steps2):
00070 ffter (ffter2), direct (direct2),
00071 buf (buf2), thr_reps (thr_reps2), tot_reps (tot_reps2),
00072 c (c2), stride (stride2), shift (shift2), steps (steps2) {}
00073 inline ~outer_fft_task_rep () {}
00074 void execute () {
00075 for (nat i=0; i*min_reps < thr_reps; i++, c += stride * min_reps) {
00076 C* aux= buf;
00077 copy (buf, 1, min_reps, c, stride, stride * tot_reps,
00078 min_reps, 1 << steps);
00079 if (direct)
00080 for (nat j=0; j<min_reps; j++, aux++)
00081 ffter->dfft (aux, min_reps, shift, steps);
00082 else
00083 for (nat j=0; j<min_reps; j++, aux++)
00084 ffter->ifft (aux, min_reps, shift, steps);
00085 copy (c, stride, stride * tot_reps, buf, 1, min_reps,
00086 min_reps, 1 << steps);
00087 }
00088 }
00089 };
00090
00091 struct inner_fft_task_rep: public task_rep {
00092 FFTER* ffter;
00093 bool direct;
00094 nat start, inc, total;
00095 C *c;
00096 nat stride, shift, steps;
00097 public:
00098 inline inner_fft_task_rep
00099 (FFTER* ffter2, bool direct2, nat start2, nat inc2, nat total2,
00100 C* c2, nat stride2, nat shift2, nat steps2):
00101 ffter (ffter2), direct (direct2),
00102 start (start2), inc (inc2), total (total2),
00103 c (c2), stride (stride2), shift (shift2), steps (steps2) {}
00104 inline ~inner_fft_task_rep () {}
00105 void execute () {
00106 if (direct)
00107 for (nat i=start; i<total; i+=inc)
00108 ffter->dfft (c + (i<<steps) * stride, stride,
00109 shift + ((i << steps) >> 1), steps);
00110 else
00111 for (nat i=start; i<total; i+=inc)
00112 ffter->ifft (c + (i<<steps) * stride, stride,
00113 shift + ((i << steps) >> 1), steps);
00114 }
00115 };
00116
00117 void
00118 fft (bool direct, C* c, nat stride, nat shift, nat steps) {
00119 if (steps <= thr) {
00120 if (direct) ffter->dfft (c, stride, shift, steps);
00121 else ffter->ifft (c, stride, shift, steps);
00122 }
00123 else {
00124 nat half1= min(log_2 (threads_number), steps);
00125 nat half2= steps - half1;
00126
00127 for (nat stage=0; stage<2; stage++) {
00128 if ((stage == 0) ^ (!direct)) {
00129
00130
00131
00132
00133
00134 nat nt = threads_number;
00135 nat bsz= (min_reps << half1);
00136 C* buf= mmx_new<C> (nt * bsz);
00137 task tasks[nt];
00138 for (nat i=0; i<nt; i++) {
00139 nat tot= 1 << half2;
00140 nat sta= min_reps * (( i * tot) / (min_reps * nt));
00141 nat end= min_reps * (((i+1) * tot) / (min_reps * nt));
00142 tasks[i]= new outer_fft_task_rep
00143 (ffter, direct, buf + i*bsz, end-sta, tot,
00144 c + sta*stride, stride, shift >> half2, half1);
00145 }
00146 threads_execute (tasks, nt);
00147 mmx_delete<C> (buf, nt * bsz);
00148
00149 }
00150 else {
00151
00152
00153
00154
00155 nat nt= threads_number;
00156 task tasks[nt];
00157 for (nat i=0; i<nt; i++)
00158 tasks[i]= new inner_fft_task_rep
00159 (ffter, direct, i, nt, 1 << half1, c, stride, shift, half2);
00160 threads_execute (tasks, nt);
00161 }
00162 }
00163 }
00164 }
00165
00166 inline void
00167 dfft (C* c, nat stride, nat shift, nat steps) {
00168 if (steps <= thr) ffter->dfft (c, stride, shift, steps);
00169 else fft (true, c, stride, shift, steps);
00170 }
00171
00172 inline void
00173 ifft (C* c, nat stride, nat shift, nat steps) {
00174 if (steps <= thr) ffter->ifft (c, stride, shift, steps);
00175 else fft (false, c, stride, shift, steps);
00176 }
00177
00178 inline void
00179 dfft (C* c, nat stride, nat shift, nat steps, nat step1, nat step2) {
00180 if (step1 == 0 && step2 == steps && steps > thr)
00181 fft (true, c, stride, shift, steps);
00182 else ffter->dfft (c, stride, shift, steps, step1, step2);
00183 }
00184
00185 inline void
00186 ifft (C* c, nat stride, nat shift, nat steps, nat step1, nat step2) {
00187 if (step1 == 0 && step2 == steps && steps > thr)
00188 fft (false, c, stride, shift, steps);
00189 else ffter->ifft (c, stride, shift, steps, step1, step2);
00190 }
00191
00192 inline void
00193 direct_transform (C* c) {
00194 dfft (c, 1, 0, depth); }
00195
00196 inline void
00197 inverse_transform (C* c, bool divide=true) {
00198 ifft (c, 1, 0, depth);
00199 if (divide) {
00200 S x= binpow (S (2), depth);
00201 x= invert (x);
00202 Vec::template vec_unary_scalar<typename R::fft_mul_sc_op> (c, x, len);
00203 }
00204 }
00205 };
00206
00207 }
00208 #endif //__MMX__FFT_THREADS__HPP