00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX__FFT_TRIADIC_THREADS__HPP
00014 #define __MMX__FFT_TRIADIC_THREADS__HPP
00015 #include <algebramix/fft_triadic_naive.hpp>
00016 #include <basix/threads.hpp>
00017
00018 namespace mmx {
00019
00020
00021
00022
00023
00024 template<typename C, typename FFTER= fft_triadic_naive_transformer<C>,
00025 nat thr= 9>
00026 class fft_triadic_threads_transformer {
00027 public:
00028 typedef implementation<vector_linear,vector_naive> NVec;
00029 typedef typename FFTER::V V;
00030 typedef typename V::template helper<C>::roots_type R;
00031 typedef typename R::U U;
00032 static const nat min_reps= 27;
00033
00034 FFTER* ffter;
00035 nat depth;
00036 nat len;
00037 U* roots;
00038 U* stoor;
00039
00040 public:
00041 inline fft_triadic_threads_transformer (nat n, const format<C>& fm):
00042 ffter (new FFTER (n, fm)),
00043 depth (ffter->depth), len (ffter->len),
00044 roots (ffter->roots), stoor (ffter->stoor) {}
00045 inline ~fft_triadic_threads_transformer () { delete ffter; }
00046
00047 template<typename CC> static inline void
00048 copy (CC* d, nat drs, nat dcs, CC* s, nat srs, nat scs, nat r, nat c) {
00049 for (nat j=0; j<c; j++, d+=dcs, s+=scs) {
00050 CC* dd= d; CC* ss= s;
00051 for (nat i=0; i<r; i++, dd+=drs, ss+=srs)
00052 *dd= *ss;
00053 }
00054 }
00055
00056 template<typename CC>
00057 struct outer_fft_triadic_task_rep: public task_rep {
00058 FFTER* ffter;
00059 bool direct;
00060 CC *buf;
00061 nat thr_reps, tot_reps;
00062 CC *c;
00063 nat stride, shift, steps;
00064 nat l;
00065 public:
00066 inline outer_fft_triadic_task_rep
00067 (FFTER* ffter2, bool direct2,
00068 CC* buf2, nat thr_reps2, nat tot_reps2,
00069 CC* c2, nat stride2, nat shift2, nat steps2):
00070 ffter (ffter2), direct (direct2),
00071 buf (buf2), thr_reps (thr_reps2), tot_reps (tot_reps2),
00072 c (c2), stride (stride2), shift (shift2), steps (steps2) {
00073 l= binpow ((nat) 3, steps); }
00074 inline ~outer_fft_triadic_task_rep () {}
00075 void execute () {
00076 for (nat i=0; i*min_reps < thr_reps; i++, c += stride * min_reps) {
00077 CC* aux= buf;
00078 copy (buf, 1, min_reps, c, stride, stride * tot_reps,
00079 min_reps, l);
00080 if (direct)
00081 for (nat j=0; j<min_reps; j++, aux++)
00082 ffter->dfft_triadic (aux, min_reps, shift, steps);
00083 else
00084 for (nat j=0; j<min_reps; j++, aux++)
00085 ffter->ifft_triadic (aux, min_reps, shift, steps);
00086 copy (c, stride, stride * tot_reps, buf, 1, min_reps,
00087 min_reps, l);
00088 }
00089 }
00090 };
00091
00092 template<typename CC>
00093 struct inner_fft_triadic_task_rep: public task_rep {
00094 FFTER* ffter;
00095 bool direct;
00096 nat start, inc, total;
00097 CC *c;
00098 nat stride, shift, steps;
00099 nat l;
00100 public:
00101 inline inner_fft_triadic_task_rep
00102 (FFTER* ffter2, bool direct2, nat start2, nat inc2, nat total2,
00103 CC* c2, nat stride2, nat shift2, nat steps2):
00104 ffter (ffter2), direct (direct2),
00105 start (start2), inc (inc2), total (total2),
00106 c (c2), stride (stride2), shift (shift2), steps (steps2) {
00107 l= binpow ((nat) 3, steps); }
00108 inline ~inner_fft_triadic_task_rep () {}
00109 void execute () {
00110 if (direct)
00111 for (nat i=start; i<total; i+=inc)
00112 ffter->dfft_triadic (c + i * l * stride, stride,
00113 shift + i * l, steps);
00114 else
00115 for (nat i=start; i<total; i+=inc)
00116 ffter->ifft_triadic (c + i * l * stride, stride,
00117 shift + i * l, steps);
00118 }
00119 };
00120
00121 template<typename CC> void
00122 fft_triadic (bool direct, CC* c, nat stride, nat shift, nat steps) {
00123 nat nt= threads_number;
00124 nat half1= steps >> 1;
00125 nat half2= steps - half1;
00126 nat len1= binpow ((nat) 3, half1);
00127 nat len2= binpow ((nat) 3, half2);
00128 if (steps <= thr || len2 <= min_reps || nt == 1) {
00129 if (direct) ffter->dfft_triadic (c, stride, shift, steps);
00130 else ffter->ifft_triadic (c, stride, shift, steps);
00131 }
00132 else {
00133 for (nat stage=0; stage<2; stage++) {
00134 if ((stage == 0) ^ (!direct)) {
00135
00136
00137
00138
00139
00140
00141
00142 nat bsz= min_reps * len1;
00143 CC* buf= mmx_new<CC> (nt * bsz);
00144 task tasks[nt];
00145 for (nat i=0; i<nt; i++) {
00146 nat sta= min_reps * (( i * len2) / (min_reps * nt));
00147 nat end= min_reps * (((i+1) * len2) / (min_reps * nt));
00148 tasks[i]= new outer_fft_triadic_task_rep<CC>
00149 (ffter, direct, buf + i*bsz, end-sta, len2,
00150 c + sta*stride, stride, shift / len2, half1);
00151 }
00152 threads_execute (tasks, nt);
00153 mmx_delete<CC> (buf, nt * bsz);
00154 }
00155 else {
00156
00157
00158
00159
00160
00161
00162
00163 task tasks[nt];
00164 for (nat i=0; i<nt; i++)
00165 tasks[i]= new inner_fft_triadic_task_rep<CC>
00166 (ffter, direct, i, nt, len1, c, stride, shift, half2);
00167 threads_execute (tasks, nt);
00168 }
00169 }
00170 }
00171 }
00172
00173 template<typename CC> inline void
00174 dfft_triadic (CC* c, nat stride, nat shift, nat steps) {
00175 if (steps <= thr) ffter->dfft_triadic (c, stride, shift, steps);
00176 else fft_triadic (true, c, stride, shift, steps);
00177 }
00178
00179 template<typename CC> inline void
00180 ifft_triadic (CC* c, nat stride, nat shift, nat steps) {
00181 if (steps <= thr) ffter->ifft_triadic (c, stride, shift, steps);
00182 else fft_triadic (false, c, stride, shift, steps);
00183 }
00184
00185 template<typename CC> inline void
00186 dfft_triadic (CC* c, nat stride, nat shift, nat steps,
00187 nat step1, nat step2) {
00188 if (step1 == 0 && step2 == steps && steps > thr)
00189 fft_triadic (true, c, stride, shift, steps);
00190 else ffter->dfft_triadic (c, stride, shift, steps, step1, step2);
00191 }
00192
00193 template<typename CC> inline void
00194 ifft_triadic (CC* c, nat stride, nat shift, nat steps,
00195 nat step1, nat step2) {
00196 if (step1 == 0 && step2 == steps && steps > thr)
00197 fft_triadic (false, c, stride, shift, steps);
00198 else ffter->ifft_triadic (c, stride, shift, steps, step1, step2);
00199 }
00200
00201 inline void
00202 direct_transform_triadic (C* c) {
00203 dfft_triadic (c, 1, 0, depth);
00204 }
00205
00206 inline void
00207 inverse_transform_triadic (C* c, bool shift=true) {
00208 ifft_triadic (c, 1, 0, depth);
00209 if (shift) NVec::mul (c, invert (C (len)), len);
00210 }
00211 };
00212
00213 }
00214 #endif //__MMX__FFT_TRIADIC_THREADS__HPP