00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015 #ifndef __MMX__POLYNOMIAL_SCHONHAGE_TRIADIC__HPP
00016 #define __MMX__POLYNOMIAL_SCHONHAGE_TRIADIC__HPP
00017 #include <algebramix/fft_triadic_naive.hpp>
00018 #include <algebramix/polynomial_dicho.hpp>
00019 #include <algebramix/polynomial_balanced.hpp>
00020
00021 namespace mmx {
00022 #define TMPL template<typename C>
00023
00024
00025
00026
00027
// Default value of the Th parameter of polynomial_schonhage_triadic_inc.
// mul_triadic queries Threshold(C,Th) with it to decide below which size
// the product is delegated to the inner multiplication.
struct schonhage_triadic_threshold {
};
00029
00030 template<typename V, typename Th= schonhage_triadic_threshold>
00031 struct polynomial_schonhage_triadic_inc: public V {
00032 typedef typename V::Vec Vec;
00033 typedef typename V::Naive Naive;
00034 typedef typename V::Positive Positive;
00035 typedef polynomial_schonhage_triadic_inc<typename V::No_simd,Th> No_simd;
00036 typedef polynomial_schonhage_triadic_inc<typename V::No_thread,Th> No_thread;
00037 typedef polynomial_schonhage_triadic_inc<typename V::No_scaled,Th> No_scaled;
00038 };
00039
00040 template<typename F, typename V, typename W, typename Th>
00041 struct implementation<F,V,polynomial_schonhage_triadic_inc<W,Th> >:
00042 public implementation<F,V,W> {};
00043
// Declares the variant polynomial_schonhage_triadic (parameterized by V):
// polynomial_balanced over polynomial_schonhage_triadic_inc (default
// threshold tag) over polynomial_karatsuba<V>.  Karatsuba is therefore the
// Inner multiplication used by mul_triadic below its threshold.  The exact
// expansion is defined by DEFINE_VARIANT_1, which is not visible here.
DEFINE_VARIANT_1 (typename V, V,
polynomial_schonhage_triadic,
polynomial_balanced<
polynomial_schonhage_triadic_inc<
polynomial_karatsuba<V> > >)
00049
00050
00051
00052
00053
00054 template<typename V, typename W, typename Th>
00055 struct implementation<polynomial_multiply,V,
00056 polynomial_schonhage_triadic_inc<W,Th> >:
00057 public implementation<polynomial_linear,V>
00058 {
00059 typedef implementation<vector_linear,V> Vec;
00060 typedef implementation<polynomial_linear,V> Pol;
00061 typedef implementation<polynomial_multiply,W> Inner;
00062
00063 private:
00064 TMPL static inline C**
00065 bivariate_new (nat m, nat t) {
00066 nat l = aligned_size<C,V> (m);
00067 C** dest= mmx_new<C*> (t);
00068 for (nat j = 0; j < t; j++)
00069 dest[j]= mmx_new<C> (l);
00070 return dest; }
00071
00072 TMPL static inline void
00073 bivariate_delete (C** dest, nat m, nat t) {
00074 nat l = aligned_size<C,V> (m);
00075 for (nat j = 0; j < t; j++)
00076 mmx_delete<C> (dest[j], l);
00077 mmx_delete<C*> (dest, t); }
00078
00079 TMPL static inline void
00080 bivariate_encode (C** dest, const C*s, nat n, nat m, nat t) {
00081
00082
00083
00084 nat i, j, n2= n << 1, m2= m << 1, t2= t << 1;
00085 for (i = 0, j = 0; i < n2; i += m, j++)
00086 if (n2-i >= m) {
00087 Vec::copy (dest[j], s + i, m);
00088 Vec::clear (dest[j] + m, m);
00089 }
00090 else {
00091 Vec::copy (dest[j], s + i, n2 - i);
00092 Vec::clear (dest[j] + n2 - i, m2 - (n2 - i));
00093 }
00094 for (; j < t2; j++)
00095 Vec::clear (dest[j], m2); }
00096
00097 TMPL static inline void
00098 bivariate_decode (C* dest, const C** s, nat n, nat m, nat t) {
00099
00100
00101
00102 nat i, n2= n << 1, m2= m << 1, t2= t << 1;
00103 Vec::clear (dest, n2);
00104 for (i = 0; i+1 < t2; i++)
00105 Pol::add (dest + i * m, s[i], m2);
00106 Pol::add (dest + i * m, s[i] , m);
00107 Pol::sub (dest , s[i] + m, m);
00108 Pol::sub (dest + n , s[i] + m, m); }
00109
00110 TMPL static inline void
00111 triadic_shift (C* dest, const C* src, nat i, nat m) {
00112
00113 nat m2= m << 1, m3= m2 + m;
00114 i= i % m3;
00115 if (i > m2) {
00116 Vec::copy (dest , src + m3 - i, i - m);
00117 Vec::neg (dest + i - m , src , m3 - i);
00118 Vec::sub (dest + i - m2, src , m3 - i);
00119 }
00120 else if (i > m) {
00121 Vec::neg (dest , src + m2 - i, m);
00122 Vec::neg (dest + m, src + m2 - i, m);
00123 Vec::add (dest , src + m3 - i, i - m);
00124 Vec::add (dest + i, src , m2 - i);
00125 }
00126 else {
00127 Vec::neg (dest , src + m2 - i, i);
00128 Vec::copy (dest + i, src , m2 - i);
00129 Vec::sub (dest + m, src + m2 - i, i); } }
00130
00131 template<typename C>
00132 struct unptr_helper {};
00133
00134 template<typename C>
00135 struct unptr_helper<C*> {
00136 typedef C type; };
00137
00138 template<typename Cp>
00139 struct triadic_roots_helper {
00140
00141 typedef Cp C;
00142 typedef nat U;
00143 typedef typename unptr_helper<Cp>::type CC;
00144 typedef CC S;
00145
00146 static inline nat&
00147 dyn_m () {
00148 static nat m= 0;
00149 return m; }
00150
00151 static inline C&
00152 dyn_temp (nat i) {
00153 static C temp0= NULL;
00154 static C temp1= NULL;
00155 static C temp2= NULL;
00156 static C temp3= NULL;
00157 switch (i) {
00158 case 0: return temp0;
00159 case 1: return temp1;
00160 case 2: return temp2;
00161 case 3: return temp3;
00162 default: ERROR ("index out of range"); return temp0; } }
00163
00164 static inline nat
00165 primitive_root (nat t, nat i) {
00166 if (t == 0) return 1;
00167 i = i % t;
00168 nat m3 = 3 * dyn_m ();
00169 ASSERT (m3 % t == 0, "primitive root out of range");
00170 return i * (m3 / t); }
00171
00172 static U*
00173 create_roots (nat t, const format<U>&) {
00174 nat m2= dyn_m () << 1;
00175 nat l= aligned_size<CC,V> (m2);
00176 for (nat i = 0; i <= 3; i++)
00177 dyn_temp (i)= mmx_new<CC> (l);
00178 U* roots= mmx_new<U> (t);
00179 for (nat i = 0; i < t; i++)
00180 roots[i] = primitive_root (t, digit_mirror_triadic (i, t));
00181 return roots; }
00182
00183 static U*
00184 create_stoor (nat t, const format<U>&) {
00185 U* stoor= mmx_new<U> (t);
00186 for (nat i = 0; i < t; i++)
00187 stoor[i]= primitive_root (t, i == 0 ?
00188 0 : t - digit_mirror_triadic (i, t));
00189 return stoor; }
00190
00191 static void
00192 destroy_roots (U* u, nat t) {
00193 nat m2= dyn_m () << 1;
00194 nat l= aligned_size<CC,V> (m2);
00195 for (nat i = 0; i <= 3; i++) {
00196 C& temp= dyn_temp (i);
00197 if (temp != NULL) {
00198 mmx_delete<CC> (temp, l);
00199 temp= NULL;
00200 }
00201 }
00202 mmx_delete<U> (u, t); }
00203
00204 static inline void
00205 dfft_cross (C* c1, C* c2, C* c3, const U* u1, const U* u2, const U* u3) {
00206 C temp0= dyn_temp (0);
00207 C temp1= dyn_temp (1);
00208 C temp2= dyn_temp (2);
00209 C temp3= dyn_temp (3);
00210 nat m= dyn_m (), m2= m << 1;
00211
00212 triadic_shift (temp0, *c3, *u3, m);
00213 Vec::add (temp0, *c2, m2);
00214 triadic_shift (temp3, temp0, *u3, m);
00215
00216 triadic_shift (temp0, *c3, *u2, m);
00217 Vec::add (temp0, *c2, m2);
00218 triadic_shift (temp2, temp0, *u2, m);
00219
00220 triadic_shift (temp0, *c3, *u1, m);
00221 Vec::add (temp0, *c2, m2);
00222 triadic_shift (temp1, temp0, *u1, m);
00223
00224 Vec::add (*c3, *c1, temp3, m2);
00225 Vec::add (*c2, *c1, temp2, m2);
00226 Vec::add (*c1, temp1, m2); }
00227
00228 static inline void
00229 ifft_cross (C* c1, C* c2, C* c3, const U* u1, const U* u2, const U* u3) {
00230 C temp0= dyn_temp (0);
00231 C temp1= dyn_temp (1);
00232 C temp2= dyn_temp (2);
00233 C temp3= dyn_temp (3);
00234 nat m= dyn_m (), m2= m << 1;
00235
00236 triadic_shift (temp1, *c1, *u1, m);
00237 Vec::add (*c1, *c2, m2);
00238 Vec::add (*c1, *c3, m2);
00239
00240 triadic_shift (temp2, *c2, *u2, m);
00241 triadic_shift (temp3, *c3, *u3, m);
00242 triadic_shift (temp0, *c2, *u3, m);
00243 Vec::add (*c2, temp1, temp2, m2);
00244 Vec::add (*c2, temp3, m2);
00245
00246 triadic_shift (temp3, *c3, *u2, m);
00247 Vec::add (temp0, temp3, m2);
00248 Vec::add (temp0, temp1, m2);
00249 triadic_shift (*c3, temp0, *u1, m); }
00250
00251 static inline void
00252 fft_shift (C* dest, S v, nat t) {
00253 nat m2= dyn_m () << 1;
00254 for (nat i = 0; i < t; i++)
00255 Vec::mul (dest[i], v, m2); }
00256 };
00257
00258 struct triadic_roots {
00259 template<typename C>
00260 struct helper {
00261 typedef triadic_roots_helper<C> roots_type; };
00262 };
00263
00264 TMPL static void
00265 direct_transform (C** b, nat m, nat t) {
00266 triadic_roots_helper<C*>::dyn_m ()= m;
00267 fft_triadic_naive_transformer<C*, triadic_roots>
00268 ffter (t, format<C*> ());
00269 ffter.direct_transform_triadic (b); }
00270
00271 TMPL static void
00272 inverse_transform (C** b, nat m, nat t, bool shift=true) {
00273 triadic_roots_helper<C*>::dyn_m ()= m;
00274 fft_triadic_naive_transformer<C*, triadic_roots>
00275 ffter (t, format<C*> ());
00276 ffter.inverse_transform_triadic (b,shift); }
00277
00278 TMPL static void
00279 variable_dilate (C** b, nat eta, nat m, nat t) {
00280
00281 nat m2= m << 1, m3= m2 + m;
00282 nat l= aligned_size<C,V> (m2);
00283 C* temp= mmx_new<C> (l);
00284 nat w= 0;
00285 for (nat i = 0; i < t; i++) {
00286 Vec::copy (temp, b[i], m2);
00287 triadic_shift (b[i], temp, w, m);
00288 w= (w + eta) % m3;
00289 }
00290 mmx_delete<C> (temp, l); }
00291
00292 TMPL static void
00293 bivariate_mod (C** dest, const C** h, nat w, nat m, nat t) {
00294
00295 nat m2= m << 1, m3= m2 + m;
00296 w= w % m3;
00297 for (nat i = 0; i < t; i++) {
00298 triadic_shift (dest[i], h[i+t], w, m);
00299 Vec::add (dest[i], h[i], m2); } }
00300
00301 TMPL static void
00302 bivariate_crt (C** h, const C** h1, const C** h2, nat w,
00303 nat m, nat t, bool shift=true) {
00304
00305 nat m2= m << 1, m3= m2 + m, t2= t << 1;
00306 w= w % m3;
00307 nat w2= (w << 1) % m3;
00308 nat l= aligned_size<C,V> (m2);
00309 C* temp= mmx_new<C> (l);
00310 for (nat i = 0; i < t; i++)
00311 Vec::sub (h[i+t], h2[i], h1[i], m2);
00312 for (nat i = 0; i < t; i++) {
00313 triadic_shift (h[i], h1[i], w2, m);
00314 triadic_shift (temp, h2[i], w , m);
00315 Vec::sub (h[i], temp, m2);
00316 }
00317 for (nat i = 0; i < t2; i++) {
00318 triadic_shift (temp, h[i], w, m);
00319 Vec::add (h[i], temp, m2);
00320 Vec::add (h[i], temp, m2);
00321 if (shift)
00322 Vec::mul (h[i], invert (C(3)), m2);
00323 }
00324 mmx_delete<C> (temp, l); }
00325
00326 public:
00327 TMPL static inline nat
00328 mul_triadic (C* dest, const C* src, nat n, bool shift=true) {
00329
00330 nat n2= n << 1;
00331 nat u = next_power_of_three (n);
00332 ASSERT (u != 0, "maximum size exceeded");
00333 ASSERT (n == u, "power of three expected");
00334 if (n <= max ((nat) 9, (nat) Threshold(C,Th))) {
00335 nat l= aligned_size<C,V> (2*n2-1);
00336 C* temp= mmx_new<C> (l);
00337 Inner::mul (temp, dest, src, n2, n2);
00338 Vec::copy (dest , temp , n2);
00339 Vec::add (dest , temp + n2 + n, n - 1);
00340 Vec::sub (dest , temp + n2 , n);
00341 Vec::sub (dest + n, temp + n2 , n);
00342 mmx_delete<C> (temp, l);
00343 return 0;
00344 }
00345 else {
00346 nat r = 0;
00347 nat k = log_3 (n);
00348 nat m = binpow (3, (k+1) >> 1), m2= m << 1, m3= m2 + m;
00349 nat t = n / m, t2= t << 1;
00350 nat eta= (t == m) ? 1 : 3;
00351
00352 C** hh= bivariate_new<C> (m2, t2);
00353 C** gg= bivariate_new<C> (m2, t2);
00354 C** h1= bivariate_new<C> (m2, t);
00355 C** h2= bivariate_new<C> (m2, t);
00356 C** g = bivariate_new<C> (m2, t);
00357 bivariate_encode (hh, src , n, m, t);
00358 bivariate_encode (gg, dest, n, m, t);
00359 for (nat j = 1; j <= 2; j++) {
00360 C** h= j == 1 ? h1 : h2;
00361 bivariate_mod (h, (const C**) hh, (j*eta*t) % m3, m, t);
00362 bivariate_mod (g, (const C**) gg, (j*eta*t) % m3, m, t);
00363 variable_dilate (h, (j*eta) % m3, m, t);
00364 variable_dilate (g, (j*eta) % m3, m, t);
00365 direct_transform (h, m, t);
00366 direct_transform (g, m, t);
00367 for (nat i = 0; i < t; i++)
00368 r = mul_triadic (h[i], g[i], m, shift);
00369 inverse_transform (h, m, t, shift);
00370 variable_dilate (h, m3 - ((j*eta) % m3), m, t);
00371 }
00372 bivariate_delete<C> (g , m2 , t);
00373 bivariate_delete<C> (gg, m2, t2);
00374 bivariate_crt (hh, (const C**) h1, (const C**) h2, (eta*t) % m3, m, t, shift);
00375 bivariate_delete<C> (h1, m2, t);
00376 bivariate_delete<C> (h2, m2, t);
00377 bivariate_decode (dest, (const C**) hh, n, m, t);
00378 bivariate_delete<C> (hh, m2, t2);
00379 return (shift) ? 0 : r + 1 + (k >> 1); }}
00380
00381 TMPL static inline nat
00382 mul (C* dest, const C* s1, const C* s2, nat n1, nat n2, bool shift=true) {
00383 nat len = n1 + n2 - 1;
00384 nat n = next_power_of_three ((len + 1) >> 1);
00385 nat nn= n << 1;
00386 nat l = aligned_size<C,V> (nn);
00387 C* t1 = mmx_new<C> (l);
00388 C* t2 = mmx_new<C> (l);
00389 Vec::copy (t1, s1, n1); Vec::clear (t1 + n1, nn - n1);
00390 Vec::copy (t2, s2, n2); Vec::clear (t2 + n2, nn - n2);
00391 nat k = mul_triadic (t1, t2, n, shift);
00392 Vec::copy (dest, t1, len);
00393 mmx_delete<C> (t1, l);
00394 mmx_delete<C> (t2, l);
00395 return k; }
00396
00397 };
00398
00399 #undef TMPL
00400 }
00401 #endif //__MMX__POLYNOMIAL_SCHONHAGE_TRIADIC__HPP