00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015 #ifndef __MMX__POLYNOMIAL_SCHONHAGE_STRASSEN__HPP
00016 #define __MMX__POLYNOMIAL_SCHONHAGE_STRASSEN__HPP
00017 #include <algebramix/polynomial.hpp>
00018 #include <algebramix/polynomial_dicho.hpp>
00019 #include <algebramix/polynomial_tft.hpp>
00020
00021 namespace mmx {
00022 #define TMPL template<typename C>
00023
00024
00025
00026
00027
// Default threshold tag of polynomial_schonhage_strassen_inc; it is what
// reaches Threshold (C, Th) when the multiplier decides whether to fall
// back on its naive base case.
struct schonhage_strassen_threshold {};

// Variant layer adding Schonhage-Strassen multiplication on top of V.
// The plain helper types of V are re-exported as is, while every derived
// variant of V (No_scaled, No_simd, No_thread) is wrapped again with the
// same threshold tag T, so that the layer survives variant rebinding.
template<typename V, typename T= schonhage_strassen_threshold>
struct polynomial_schonhage_strassen_inc: public V {
  // Types forwarded unchanged from the underlying variant.
  typedef typename V::Naive    Naive;
  typedef typename V::Positive Positive;
  typedef typename V::Vec      Vec;

  // Rebound variants keep this layer and its threshold tag.
  typedef polynomial_schonhage_strassen_inc<typename V::No_scaled,T>
    No_scaled;
  typedef polynomial_schonhage_strassen_inc<typename V::No_simd,T>
    No_simd;
  typedef polynomial_schonhage_strassen_inc<typename V::No_thread,T>
    No_thread;
};
00039
00040 template<typename F, typename V, typename W, typename Th>
00041 struct implementation<F,V,polynomial_schonhage_strassen_inc<W,Th> >:
00042 public implementation<F,V,W> {};
00043
00044 DEFINE_VARIANT_1 (typename V, V,
00045 polynomial_schonhage_strassen,
00046 polynomial_balanced_tft<
00047 polynomial_schonhage_strassen_inc<
00048 polynomial_karatsuba<V> > >)
00049
00050
00051
00052
00053
00054 template<typename V, typename W, typename Th>
00055 struct implementation<polynomial_multiply,V,
00056 polynomial_schonhage_strassen_inc<W,Th> >:
00057 public implementation<polynomial_linear,V>
00058 {
00059 typedef implementation<vector_linear,V> Vec;
00060 typedef implementation<polynomial_linear,V> Pol;
00061 typedef implementation<polynomial_multiply,W> Inner;
00062
00063 private:
00064 TMPL static inline C**
00065 bivariate_new (nat m2, nat t) {
00066 nat l = aligned_size<C ,V> (m2);
00067 nat ly= aligned_size<C*,V> (t);
00068 C** dest= mmx_new<C*> (ly);
00069 for (nat j = 0; j < t; j++)
00070 dest[j]= mmx_new<C> (l);
00071 return dest; }
00072
00073 TMPL static inline void
00074 bivariate_delete (C** dest, nat m2, nat t) {
00075 nat l = aligned_size<C ,V> (m2);
00076 nat ly= aligned_size<C*,V> (t);
00077 for (nat j = 0; j < t; j++)
00078 mmx_delete<C> (dest[j], l);
00079 mmx_delete<C*> (dest, ly); }
00080
00081 TMPL static inline void
00082 bivariate_encode (C** dest, const C*s, nat n, nat m, nat t) {
00083
00084
00085
00086 nat i, j, m2= m << 1;
00087 for (i = 0, j = 0; i < n; i += m, j++)
00088 if (n-i >= m) {
00089 Vec::copy (dest[j], s + i, m);
00090 Vec::clear (dest[j] + m, m);
00091 }
00092 else {
00093 Vec::copy (dest[j], s + i, n - i);
00094 Vec::clear (dest[j] + n - i, m2 - (n - i));
00095 }
00096 for (; j < t; j++)
00097 Vec::clear (dest[j], m2); }
00098
00099 TMPL static inline void
00100 bivariate_decode (C* dest, const C** s, nat n, nat m, nat t) {
00101
00102
00103 nat i, m2= m << 1;
00104 Vec::clear (dest, n);
00105 for (i = 0; i+1 < t; i++)
00106 Pol::add (dest + i * m, s[i], m2);
00107 Pol::add (dest + i * m, s[i] , m);
00108 Pol::sub (dest , s[i] + m, m); }
00109
00110 TMPL static inline void
00111 negative_cyclic_shift (C* dest, const C* src, nat m2, nat i) {
00112
00113 bool negate= ((i / m2) & 1);
00114 i= i % m2;
00115 if (negate) {
00116 Vec::neg (dest + i, src , m2 - i);
00117 Vec::copy (dest , src + m2 - i, i);
00118 }
00119 else {
00120 Vec::copy (dest + i, src , m2 - i);
00121 Vec::neg (dest , src + m2 - i, i); } }
00122
00123 TMPL static inline void
00124 negative_cyclic_shift (C* dest, nat m2, nat i) {
00125 nat l= aligned_size<C,V> (m2);
00126 C* temp= mmx_new<C> (l);
00127 Vec::copy (temp, dest, m2);
00128 negative_cyclic_shift (dest, temp, m2, i);
00129 mmx_delete<C> (temp, l); }
00130
00131 template<typename C>
00132 struct unptr_helper {};
00133
00134 template<typename C>
00135 struct unptr_helper<C*> {
00136 typedef C type; };
00137
00138 template<typename Cp>
00139 struct negative_cyclic_roots_helper {
00140
00141 typedef Cp C;
00142 typedef nat U;
00143 typedef typename unptr_helper<Cp>::type CC;
00144 typedef CC S;
00145
00146 static inline nat&
00147 dyn_modulus () {
00148 static nat m2= 0;
00149 return m2; }
00150
00151 static inline C&
00152 get_temp () {
00153 static C temp= NULL;
00154 return temp; }
00155
00156 static inline nat
00157 primitive_root (nat t, nat i) {
00158 ASSERT (t != 0, "unexpected zero root order");
00159 i = i % t;
00160 nat m2 = dyn_modulus ();
00161 return i * (m2 << 1) / t; }
00162
00163 static U*
00164 create_roots (nat t, const format<U>&) {
00165 nat m2= dyn_modulus ();
00166 nat l= aligned_size<CC,V> (m2);
00167 get_temp ()= mmx_new<CC> (l);
00168 U* roots= mmx_new<U> (t);
00169 for (nat i = 0; i < t; i += 2) {
00170 roots[i] = primitive_root (t, bit_mirror (i, t));
00171 roots[i+1]= primitive_root (t, i == 0 ? 0 : t - bit_mirror (i, t));
00172 }
00173 return roots; }
00174
00175 static void
00176 destroy_roots (U* u, nat t) {
00177 C& temp= get_temp ();
00178 if (temp != NULL) {
00179 nat m2= dyn_modulus ();
00180 nat l= aligned_size<CC,V> (m2);
00181 mmx_delete<CC> (temp, l);
00182 temp= NULL;
00183 }
00184 mmx_delete<U> (u, t); }
00185
00186 static inline void
00187 fft_cross (C* c1, C* c2) {
00188 C temp= get_temp ();
00189 nat m2= dyn_modulus ();
00190 Vec::copy (temp, *c2, m2);
00191 Vec::sub (*c2, *c1, temp, m2);
00192 Vec::add (*c1, temp, m2); }
00193
00194 static inline void
00195 dfft_cross (C* c1, C* c2, const U* u) {
00196 C temp= get_temp ();
00197 nat m2= dyn_modulus ();
00198 negative_cyclic_shift (temp, *c2, m2, *u);
00199 Vec::sub (*c2, *c1, temp, m2);
00200 Vec::add (*c1, temp, m2); }
00201
00202 static inline void
00203 ifft_cross (C* c1, C* c2, const U* u) {
00204 C temp= get_temp ();
00205 nat m2= dyn_modulus ();
00206 Vec::copy (temp, *c1, m2);
00207 Vec::add (*c1, *c2, m2);
00208 Vec::sub (*c2, temp, m2);
00209 negative_cyclic_shift (*c2, m2, (*u) + m2); }
00210
00211 static inline void
00212 dtft_cross (C* c1, C* c2) {
00213 static CC h= invert (CC(2));
00214 nat m2= dyn_modulus ();
00215 fft_cross (c1, c2);
00216 Vec::mul (*c1, h, m2);
00217 Vec::mul (*c2, h, m2); }
00218
00219 static inline void
00220 dtft_cross (C* c1, C* c2, const U* u) {
00221 static CC h= invert (CC(2));
00222 nat m2= dyn_modulus ();
00223 dfft_cross (c1, c2, u);
00224 Vec::mul (*c1, h, m2);
00225 Vec::mul (*c2, h, m2); }
00226
00227 static inline void
00228 itft_flip (C* c1, C* c2, const U* u) {
00229 static CC h= invert (CC(2));
00230 C temp= get_temp ();
00231 nat m2= dyn_modulus ();
00232 negative_cyclic_shift (temp, *c2, m2, *u);
00233 Vec::add (*c1, *c1, m2);
00234 Vec::sub (*c1, temp, m2);
00235 Vec::sub (*c2, *c1, temp, m2);
00236 Vec::mul (*c2, h, m2); }
00237
00238 static inline void
00239 itft_flip (C* c1, C* c2) {
00240 static CC h= invert (CC(2));
00241 nat m2= dyn_modulus ();
00242 Vec::add (*c1, *c1, m2);
00243 Vec::sub (*c1, *c2, m2);
00244 Vec::sub (*c2, *c1, *c2, m2);
00245 Vec::mul (*c2, h, m2); }
00246
00247 struct fft_mul_sc_op : mul_op {
00248 static inline void
00249 set_op (C& x, const S& y) {
00250 Vec::mul (x, y, dyn_modulus ()); }
00251 };
00252 };
00253
00254 TMPL
00255 struct negative_cyclic_roots {
00256 typedef negative_cyclic_roots_helper<C> roots_type;
00257 };
00258
00259 TMPL static void
00260 direct_transform (C** b, nat m2, nat t) {
00261 negative_cyclic_roots_helper<C*>::dyn_modulus ()= m2;
00262 fft_naive_transformer<C*, negative_cyclic_roots<C*> >
00263 ffter (t, format<C*> ());
00264 ffter.direct_transform (b); }
00265
00266 TMPL static void
00267 inverse_transform (C** b, nat m2, nat t, bool divide=true) {
00268 negative_cyclic_roots_helper<C*>::dyn_modulus ()= m2;
00269 fft_naive_transformer<C*, negative_cyclic_roots<C*> >
00270 ffter (t, format<C*> ());
00271 ffter.inverse_transform (b, divide); }
00272
00273 TMPL static void
00274 direct_transform_truncated (C** b, nat m2, nat len) {
00275 negative_cyclic_roots_helper<C*>::dyn_modulus ()= m2;
00276 fft_truncated_transformer<C*,
00277 fft_naive_transformer<C*, negative_cyclic_roots<C*> > >
00278 ffter (len, format<C*> ());
00279 ffter.dtft (b, 1, len, 0); }
00280
00281 TMPL static void
00282 inverse_transform_truncated (C** b, nat m2, nat len) {
00283 nat i;
00284 for (i = len; i < next_power_of_two (len); i++)
00285 Vec::clear (b[i], m2);
00286 negative_cyclic_roots_helper<C*>::dyn_modulus ()= m2;
00287 fft_truncated_transformer<C*,
00288 fft_naive_transformer<C*, negative_cyclic_roots<C*> > >
00289 ffter (len, format<C*> ());
00290 ffter.itft (b, 1, len, 0);
00291 C h= invert (C(next_power_of_two (len)));
00292 for (i = 0; i < len; i++)
00293 Vec::mul (b[i], h, m2);
00294 for (; i < next_power_of_two (len); i++)
00295 Vec::clear (b[i], m2); }
00296
00297 TMPL static void
00298 variable_dilate (C** b, nat eta, nat m2, nat t) {
00299
00300 nat m4= m2 << 1;
00301 nat w= 0;
00302 for (nat i = 0; i < t; i++) {
00303 negative_cyclic_shift (b[i], m2, w);
00304 w= (w + eta) % m4; } }
00305
00306 public:
00307 TMPL static inline nat
00308 mul_negative_cyclic (C* dest, const C* src, nat n, bool shift=true) {
00309
00310 nat u= next_power_of_two (n);
00311 ASSERT (u != 0, "maximum size exceeded");
00312 ASSERT (n == u, "power of two expected");
00313 if (n <= max ((nat) 4, (nat) Threshold(C,Th))) {
00314 nat l= aligned_size<C,V> (2*n-1);
00315 C* temp= mmx_new<C> (l);
00316 Inner::mul (temp, dest, src, n, n);
00317 Vec::copy (dest, temp, n);
00318 Vec::sub (dest, temp + n, n - 1);
00319 mmx_delete<C> (temp, l);
00320 return 0;
00321 }
00322 else {
00323 nat r = 0;
00324 nat k = log_2 (n);
00325 nat m = (nat) 1 << (k >> 1);
00326 nat t = u / m;
00327 nat m2= m << 1;
00328 C** b1= bivariate_new<C> (m2, t);
00329 C** b2= bivariate_new<C> (m2, t);
00330 bivariate_encode (b1, src , n, m, t);
00331 bivariate_encode (b2, dest, n, m, t);
00332 nat eta= (t == m2) ? 1 : 2;
00333 variable_dilate (b1, eta, m2, t);
00334 variable_dilate (b2, eta, m2, t);
00335 direct_transform (b1, m2, t);
00336 direct_transform (b2, m2, t);
00337 for (nat i = 0; i < t; i++)
00338 r= mul_negative_cyclic (b1[i], b2[i], m2, shift);
00339 bivariate_delete<C> (b2, m2, t);
00340 inverse_transform (b1, m2, t, shift);
00341 variable_dilate (b1, (m2 << 1) - eta, m2, t);
00342 bivariate_decode (dest, (const C**) b1, n, m, t);
00343 bivariate_delete<C> (b1, m2, t);
00344 return r + ((k+1) >> 1); } }
00345
00346 TMPL static inline void
00347 mul_negative_cyclic_truncated (C* dest, const C* src, nat len) {
00348
00349
00350 nat n= next_power_of_two (len);
00351 ASSERT (n != 0, "maximum size exceeded");
00352 if (n <= max ((nat) 4, (nat) Threshold(C,Th))) {
00353 nat l= aligned_size<C,V> (2*n-1);
00354 C* temp= mmx_new<C> (l);
00355 Inner::mul (temp, dest, src, n, n);
00356 Vec::copy (dest, temp, n);
00357 Vec::sub (dest, temp + n, n - 1);
00358 mmx_delete<C> (temp, l);
00359 return;
00360 }
00361 else {
00362 nat k = log_2 (n);
00363 nat m = (nat) 1 << (k >> 1);
00364 nat t = n / m;
00365 nat r = (len + m - 1) / m;
00366 nat m2= m << 1;
00367 C** b1= bivariate_new<C> (m2, t);
00368 C** b2= bivariate_new<C> (m2, t);
00369 bivariate_encode (b1, src , n, m, t);
00370 bivariate_encode (b2, dest, n, m, t);
00371 nat eta= (t == m2) ? 1 : 2;
00372 variable_dilate (b1, eta, m2, r);
00373 variable_dilate (b2, eta, m2, r);
00374 direct_transform_truncated (b1, m2, r);
00375 direct_transform_truncated (b2, m2, r);
00376 for (nat i = 0; i < r; i++)
00377 mul_negative_cyclic (b1[i], b2[i], m2);
00378 bivariate_delete<C> (b2, m2, t);
00379 inverse_transform_truncated (b1, m2, r);
00380 variable_dilate (b1, (m2 << 1) - eta, m2, r);
00381 bivariate_decode (dest, (const C**) b1, n, m, t);
00382 bivariate_delete<C> (b1, m2, t); } }
00383
00384 TMPL static inline nat
00385 mul (C* dest, const C* s1, const C* s2, nat n1, nat n2, bool shift=true) {
00386 nat ret = 0;
00387 nat len = n1 + n2 - 1;
00388 nat n = next_power_of_two (len);
00389 nat l = aligned_size<C,V> (n);
00390 C* t1 = mmx_new<C> (l);
00391 C* t2 = mmx_new<C> (l);
00392 Vec::copy (t1, s1, n1); Vec::clear (t1 + n1, n - n1);
00393 Vec::copy (t2, s2, n2); Vec::clear (t2 + n2, n - n2);
00394 if (shift)
00395 mul_negative_cyclic_truncated (t1, t2, len);
00396 else
00397 ret= mul_negative_cyclic (t1, t2, n, shift);
00398 Vec::copy (dest, t1, len);
00399 mmx_delete<C> (t1, l);
00400 mmx_delete<C> (t2, l);
00401 return ret; }
00402
00403 };
00404
00405 #undef TMPL
00406 }
00407 #endif //__MMX__POLYNOMIAL_SCHONHAGE_STRASSEN__HPP