00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX__SERIES_FAST__HPP
00014 #define __MMX__SERIES_FAST__HPP
00015 #include <algebramix/series.hpp>
00016 #include <algebramix/series_vector.hpp>
00017 #include <algebramix/fkt_transform.hpp>
00018 #include <algebramix/fft_naive.hpp>
00019
00020 namespace mmx {
00021 #define TMPL template<typename C,typename V>
00022 #define Series series<C,V>
00023 #define Series_rep series_rep<C,V>
00024 #define Vector vector<C>
00025 #define Series_vector series<Vector,V>
00026 #define Series_vector_rep series_rep<Vector,V>
00027
00028
00029
00030
00031
00032 struct series_fast {};
00033
00034 template<typename F, typename V>
00035 struct implementation<F,V,series_fast>:
00036 public implementation<F,V,series_naive> {};
00037
00038
00039
00040
00041
00042 template<typename U>
00043 struct implementation<series_multiply,U,series_fast>:
00044 public implementation<series_abstractions,U>
00045 {
00046
00047
00048 #define TRANSFORM_NAIVE 0
00049 #define TRANSFORM_KARATSUBA 1
00050 #define TRANSFORM_FFT 2
00051
00052 template <typename C>
00053 struct level_info {
00054 MMX_ALLOCATORS
00055 typedef C* C_ptr;
00056
00057 nat n;
00058 nat k;
00059 nat tsz;
00060 nat type;
00061 C_ptr head[2];
00062 C_ptr tail[2];
00063
00064 level_info () {
00065 head[0]= head[1]= NULL;
00066 tail[0]= tail[1]= NULL;
00067 }
00068 ~level_info () {
00069 if (head[0] != NULL) mmx_classical_delete<C> (head[0]);
00070 if (head[1] != NULL) mmx_classical_delete<C> (head[1]);
00071 if (tail[0] != NULL) mmx_classical_delete<C> (tail[0]);
00072 if (tail[1] != NULL) mmx_classical_delete<C> (tail[1]);
00073 }
00074 };
00075
00076 static nat
00077 get_inside_multiplier (nat n) {
00078 if (n <= (1 << 7)) return 16;
00079 if (n <= (1 << 9)) return 4;
00080 if (n <= (1 << 10)) return 8;
00081 if (n <= (1 << 15)) return 16;
00082 if (n <= (1 << 23)) return 32;
00083 return 64;
00084 }
00085
00086 static nat
00087 get_border_multiplier (nat n) {
00088 if (n <= (1 << 4)) return 2;
00089 if (n <= (1 << 8)) return 4;
00090 if (n <= (1 << 13)) return 8;
00091 if (n <= (1 << 21)) return 16;
00092 return 32;
00093 }
00094
00095
00096
00097
00098
00099
00100
00101
00102
00103
00104
00105
00106
00107
00108
00109
00110
00111
00112
00113
00114
00115
00116
00117 static vector<nat>
00118 determine_sizes (nat n) {
00119 if (n <= 16) return vec<nat> (1);
00120 vector<nat> v;
00121 nat k= get_inside_multiplier (n);
00122
00123 n /= k;
00124 while (n != 1) {
00125 k= get_border_multiplier (n);
00126 v << k;
00127 n /= k;
00128 }
00129 v << 1;
00130 for (nat i=0; i<N(v)/2; i++)
00131 swap (v[i], v[N(v) - 1 - i]);
00132 return v;
00133 }
00134
00135
00136
00137
00138
00139 TMPL
00140 class nrelax_mul_series_rep: public Series_rep {
00141 public:
00142 typedef typename series_polynomial_helper<C,V >::PV PV;
00143 typedef implementation<polynomial_linear,PV> Pol;
00144 typedef fkt_package<polynomial_naive> Fkt;
00145
00146 protected:
00147 Series f[2];
00148 nat sh[2];
00149 nat capacity[2];
00150 nat nr_levels;
00151 level_info<C>* info;
00152 nat xnr_levels;
00153 level_info<C>* xinfo;
00154
00155 public:
00156 nrelax_mul_series_rep (const Series& f2, const Series& g2, nat n):
00157 Series_rep (CF(f2))
00158 {
00159 f[0]= f2; f[1]= g2;
00160 sh[0]= 1; sh[1]= 1;
00161
00162 const vector<nat> v= determine_sizes (n);
00163 nr_levels= N(v);
00164 info= mmx_new<level_info<C> > (nr_levels);
00165 nat sz= 1;
00166 for (nat level=0; level< nr_levels; level++) {
00167 bool last= (level == nr_levels - 1);
00168 nat n= sz * v[level];
00169 nat k= (last? 1: v[level+1]);
00170 info[level].n = n;
00171 info[level].k = k;
00172 nat tsz, type;
00173 if (v[level] == 1) {
00174 tsz = 1;
00175 type= TRANSFORM_NAIVE;
00176 }
00177 else if (v[level] == 2) {
00178 tsz = 3 * info[level-1].tsz;
00179 type= TRANSFORM_KARATSUBA;
00180 }
00181 else {
00182 tsz= 2 * n;
00183 type= TRANSFORM_FFT;
00184 }
00185 info[level].tsz = tsz;
00186 info[level].type= type;
00187
00188
00189
00190
00191
00192 if (last) {
00193 info[level].tail[0]= mmx_classical_new<C> (2 * tsz);
00194 info[level].tail[1]= mmx_classical_new<C> (2 * tsz);
00195 capacity[0]= 2;
00196 capacity[1]= 2;
00197 }
00198 else {
00199 if (sh[0] != 0) {
00200 info[level].head[0]= mmx_classical_new<C> (k * tsz);
00201 info[level].tail[1]= mmx_classical_new<C> (k * tsz);
00202 }
00203 if (sh[1] != 0) {
00204 info[level].head[1]= mmx_classical_new<C> (k * tsz);
00205 info[level].tail[0]= mmx_classical_new<C> (k * tsz);
00206 }
00207 }
00208 sz *= v[level];
00209
00210
00211 }
00212 xinfo= info;
00213 xnr_levels= nr_levels;
00214 if (sh[0] == 0 && sh[1] == 0) {
00215 info += nr_levels-1;
00216 nr_levels= 1;
00217 }
00218 }
00219
00220 ~nrelax_mul_series_rep () {
00221 this->l= allocated (this->l);
00222 mmx_delete<level_info<C> > (xinfo, xnr_levels); }
00223
00224 syntactic expression (const syntactic& z) const {
00225 return flatten (f[0], z) * flatten (f[1], z); }
00226
00227 nat allocated (nat l) {
00228 if (l == 0) return 0;
00229 else return l + 2 * info[nr_levels-1].n; }
00230 void Set_order (nat l2) {
00231 if (l2 <= this->l) return;
00232 nat old_allocated= allocated (this->l);
00233 nat new_allocated= allocated (l2);
00234 C* b= mmx_new<C> (new_allocated);
00235 Pol::copy (b, this->a, old_allocated);
00236 Pol::clear (b + old_allocated, new_allocated - old_allocated);
00237 mmx_delete<C> (this->a, old_allocated);
00238 this->a= b;
00239 this->l= l2;
00240 }
00241
00242 void Increase_order (nat l) {
00243 Series_rep::Increase_order (l);
00244 increase_order (f[0], l);
00245 increase_order (f[1], l); }
00246
00247 void direct_transform (C* dest, nat ld, const C* src, nat ls, nat type) {
00248
00249 if (type == TRANSFORM_NAIVE) dest[0]= src[0];
00250 else if (type == TRANSFORM_KARATSUBA) {
00251 Pol::copy (dest, src, ls);
00252 Fkt::direct_fkt (dest, ls, ld);
00253 }
00254 else {
00255 Pol::copy (dest, src, ls);
00256 Pol::clear (dest + ls, ls);
00257 fft_naive_transformer<C> ffter (ld, get_format (*dest));
00258 ffter.direct_transform (dest);
00259 }
00260 }
00261
00262 void inverse_transform (C* dest, nat n, nat tsz, nat type) {
00263 if (type == TRANSFORM_NAIVE);
00264 else if (type == TRANSFORM_KARATSUBA) Fkt::inverse_fkt (dest, tsz, n);
00265 else {
00266 fft_naive_transformer<C> ffter (tsz, get_format (*dest));
00267 ffter.inverse_transform (dest);
00268 }
00269 }
00270
00271 void direct_transform (nat which) {
00272 for (nat level= 0; level < nr_levels; level++) {
00273 typedef C* C_ptr;
00274 nat n = info[level].n;
00275 nat tsh = sh[0] + sh[1];
00276 nat cur = this->n + tsh - n * sh[1-which];
00277 if (cur % n != 0) break;
00278 if (sh[1-which] != 0 && this->n + 1 < n) break;
00279
00280 bool last= (level == nr_levels - 1);
00281 nat k = info[level].k;
00282 nat tsz = info[level].tsz;
00283 nat type= info[level].type;
00284 C_ptr& head= info[level].head[which];
00285 C_ptr& tail= info[level].tail[which];
00286 nat Cur= cur / n;
00287 nat Mod= Cur % k;
00288
00289 if (f[which]->n < this->n + 1 && which == 1) {
00290 mmerr << "\n>>> n= " << this->n
00291 << " and only " << f[which]->n << " terms\n";
00292 ASSERT (f[which]->n >= this->n + 1, "insufficient number of terms");
00293 }
00294 const C* seg= (sh[1-which] == 0?
00295 f[which] (this->n, this->n + n):
00296 f[which] (this->n + 1 - n, this->n + 1));
00297
00298 if (last) {
00299 nat Cap= capacity[which];
00300
00301 if (Cur >= Cap) {
00302 nat old_cap= Cap * tsz;
00303 nat new_cap= old_cap << 1;
00304 C* a= mmx_classical_new<C> (new_cap);
00305 Pol::copy (a, tail, old_cap);
00306 Pol::clear (a + old_cap, new_cap - old_cap);
00307 mmx_classical_delete<C> (tail);
00308 tail= a;
00309 capacity[which]= Cap << 1;
00310
00311 }
00312
00313 direct_transform (tail + Cur * tsz, tsz, seg, n, type);
00314
00315 }
00316 else if (Cur < k * sh[which]) {
00317
00318 direct_transform (head + Mod * tsz, tsz, seg, n, type);
00319
00320 }
00321 else if (tail != NULL) {
00322
00323 direct_transform (tail + Mod * tsz, tsz, seg, n, type);
00324
00325 }
00326 }
00327 }
00328
00329 inline void accumulate (C* dest, const C* s1, const C* s2, nat len) {
00330
00331
00332
00333 Pol::mul_add (dest, s1, s2, len);
00334
00335 }
00336
00337 C next () {
00338
00339
00340 (void) f[0][this->n];
00341 (void) f[1][this->n];
00342 direct_transform (0);
00343 direct_transform (1);
00344 for (nat level= 0; level < nr_levels; level++) {
00345 nat tsh = sh[0] + sh[1];
00346 nat cur = this->n + tsh;
00347 nat n = info[level].n;
00348 if (cur % n != 0) break;
00349 nat Cur = cur/n;
00350 if (Cur < tsh) break;
00351
00352 bool last= (level == nr_levels - 1);
00353 nat msh = min (sh[0], sh[1]);
00354 nat k = info[level].k;
00355 nat tsz = info[level].tsz;
00356 nat type= info[level].type;
00357 const C* h0 = info[level].head[0];
00358 const C* h1 = info[level].head[1];
00359 const C* t0 = info[level].tail[0];
00360 const C* t1 = info[level].tail[1];
00361 nat Mod = Cur % k;
00362
00363
00364 C* acc= mmx_new<C> (tsz);
00365 Pol::clear (acc, tsz);
00366 if (last)
00367 for (nat i=sh[0]; i<=Cur-sh[1]; i++)
00368 accumulate (acc, t0 + i*tsz, t1 + (Cur-i)*tsz, tsz);
00369 else {
00370 if (Cur < 2 * msh * k) {
00371 nat start= 1 ; if (Cur > k) start= Cur - k + 1;
00372 nat end = Cur - 1; if (Cur > k) end = k - 1;
00373 for (nat i=start; i<=end; i++)
00374 accumulate (acc, h0 + i*tsz, h1 + (Cur-i)*tsz, tsz);
00375 }
00376 if (Cur >= msh * k && Mod != 0) {
00377
00378 if (sh[0] != 0)
00379 for (nat i=1; i<=Mod; i++)
00380 accumulate (acc, h0 + i*tsz, t1 + (Mod-i)*tsz, tsz);
00381 if (sh[1] != 0)
00382 for (nat i=0; i<=Mod-1; i++)
00383 accumulate (acc, t0 + i*tsz, h1 + (Mod-i)*tsz, tsz);
00384 }
00385 }
00386 inverse_transform (acc, n, tsz, type);
00387
00388 Pol::add (this->a + this->n, acc, 2*n - 1);
00389
00390 if (Cur >= msh * k && Mod == 0 && !last) {
00391
00392 for (nat Mod= 0; Mod<k-1; Mod++) {
00393 Pol::clear (acc, tsz);
00394 for (nat i=Mod+1; i<k; i++) {
00395 if (sh[0] != 0)
00396 accumulate (acc, h0 + i*tsz, t1 + (k+Mod-i)*tsz, tsz);
00397 if (sh[1] != 0)
00398 accumulate (acc, t0 + i*tsz, h1 + (k+Mod-i)*tsz, tsz);
00399 }
00400 inverse_transform (acc, n, tsz, type);
00401 Pol::add (this->a + this->n + Mod*n, acc, 2*n - 1);
00402 }
00403 }
00404 mmx_delete<C> (acc, tsz);
00405 }
00406 return this->a [this->n];
00407 }
00408 };
00409
00410 TMPL static Series
00411 nrelax_mul (const Series& f, const Series& g, nat n) {
00412 return (Series_rep*) new nrelax_mul_series_rep<C,V> (f, g, n);
00413 }
00414
00415
00416
00417 TMPL
00418 class mul_series_rep: public
00419 implementation<series_abstractions,V>
00420 ::template binary_series_rep<mul_op,C,V> {
00421 protected:
00422 Series prod;
00423 nat N;
00424 public:
00425 inline mul_series_rep (const Series& f, const Series& g):
00426 implementation<series_abstractions,V>
00427 ::template binary_series_rep<mul_op,C,V > (f, g),
00428 prod (nrelax_mul (f, g, 1)), N (1) {}
00429 C next () { return prod [this->n]; }
00430 void Increase_order (nat l) {
00431 Series_rep::Increase_order (l);
00432 increase_order (prod, l);
00433 if (l < 2*N) return;
00434 while (l >= 2*N) N= 2*N;
00435
00436 prod= nrelax_mul (this->f, this->g, N);
00437 }
00438 };
00439
00440 TMPL static inline Series
00441 ser_mul (const Series& f, const Series& g) {
00442 typedef mul_series_rep<C,V> Mul_rep;
00443 if (is_exact_zero (f) || is_exact_zero (g))
00444 return Series (CF(f));
00445 return (Series_rep*) new Mul_rep (f, g); }
00446
00447 TMPL static inline Series
00448 ser_truncate_mul (const Series& f, const Series& g, nat nf, nat ng) {
00449 typedef mul_series_rep<C,V> Mul_rep;
00450 if (is_exact_zero (f) || is_exact_zero (g) || nf == 0 || ng == 0)
00451 return Series (CF(f));
00452 return (Series_rep*)
00453 new Mul_rep (piecewise (f, Series (CF(f)), nf),
00454 piecewise (g, Series (CF(g)), ng)); }
00455
00456 };
00457
00458 #undef TMPL
00459 #undef Series
00460 #undef Series_rep
00461 #undef Vector
00462 #undef Series_vector
00463 #undef Series_vector_rep
00464 }
00465 #endif // __MMX__SERIES_FAST__HPP