00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 #ifndef __MMX__POLYNOMIAL_SCHONHAGE_TRIADIC__HPP
00016 #define __MMX__POLYNOMIAL_SCHONHAGE_TRIADIC__HPP
00017 #include <algebramix/fft_triadic_naive.hpp>
00018 #include <algebramix/polynomial_dicho.hpp>
00019 #include <algebramix/polynomial_balanced.hpp>
00020 
00021 namespace mmx {
00022 #define TMPL template<typename C>
00023 
00024 
00025 
00026 
00027 
// Tag type; it is the default threshold parameter Th of
// polynomial_schonhage_triadic_inc and carries no data.
struct schonhage_triadic_threshold {
};
00029 
00030 template<typename V, typename Th= schonhage_triadic_threshold>
00031 struct polynomial_schonhage_triadic_inc: public V {
00032   typedef typename V::Vec Vec;
00033   typedef typename V::Naive Naive;
00034   typedef typename V::Positive Positive;
00035   typedef polynomial_schonhage_triadic_inc<typename V::No_simd,Th> No_simd;
00036   typedef polynomial_schonhage_triadic_inc<typename V::No_thread,Th> No_thread;
00037   typedef polynomial_schonhage_triadic_inc<typename V::No_scaled,Th> No_scaled;
00038 };
00039 
00040 template<typename F, typename V, typename W, typename Th>
00041 struct implementation<F,V,polynomial_schonhage_triadic_inc<W,Th> >:
00042   public implementation<F,V,W> {};
00043 
// Presumably declares polynomial_schonhage_triadic<V> as a name for the
// stacked variant below (DEFINE_VARIANT_1 is defined elsewhere; TODO
// confirm its expansion).  The triadic layer's inner variant W is
// polynomial_karatsuba<V>, which handles the base-case products of the
// triadic multiplication.
DEFINE_VARIANT_1 (typename V, V,
                  polynomial_schonhage_triadic,
                  polynomial_balanced<
                    polynomial_schonhage_triadic_inc<
                      polynomial_karatsuba<V> > >)
00049 
00050 
00051 
00052 
00053 
00054 template<typename V, typename W, typename Th>
00055 struct implementation<polynomial_multiply,V,
00056                       polynomial_schonhage_triadic_inc<W,Th> >:
00057   public implementation<polynomial_linear,V>
00058 {
00059   typedef implementation<vector_linear,V> Vec;
00060   typedef implementation<polynomial_linear,V> Pol;
00061   typedef implementation<polynomial_multiply,W> Inner;
00062 
00063 private:  
00064   TMPL static inline C**
00065   bivariate_new (nat m, nat t) {
00066     nat l = aligned_size<C,V> (m);
00067     C** dest= mmx_new<C*> (t);
00068     for (nat j = 0; j < t; j++)
00069       dest[j]= mmx_new<C> (l);
00070     return dest; }
00071 
00072   TMPL static inline void
00073   bivariate_delete (C** dest, nat m, nat t) {
00074     nat l = aligned_size<C,V> (m);
00075     for (nat j = 0; j < t; j++)
00076       mmx_delete<C> (dest[j], l);
00077     mmx_delete<C*> (dest, t); }
00078 
00079   TMPL static inline void
00080   bivariate_encode (C** dest, const C*s, nat n, nat m, nat t) {
00081     
00082     
00083     
00084     nat i, j, n2= n << 1, m2= m << 1, t2= t << 1;
00085     for (i = 0, j = 0; i < n2; i += m, j++)
00086       if (n2-i >= m) {
00087         Vec::copy  (dest[j], s + i, m);
00088         Vec::clear (dest[j] + m, m);
00089       }
00090       else {
00091         Vec::copy  (dest[j], s + i, n2 - i);
00092         Vec::clear (dest[j] + n2 - i, m2 - (n2 - i));
00093       }
00094     for (; j < t2; j++)
00095       Vec::clear (dest[j], m2); }
00096   
00097   TMPL static inline void
00098   bivariate_decode (C* dest, const C** s, nat n, nat m, nat t) {
00099     
00100     
00101     
00102     nat i, n2= n << 1, m2= m << 1, t2= t << 1;
00103     Vec::clear (dest, n2);
00104     for (i = 0; i+1 < t2; i++)
00105       Pol::add (dest + i * m, s[i], m2);
00106     Pol::add (dest + i * m, s[i]    , m);
00107     Pol::sub (dest        , s[i] + m, m);
00108     Pol::sub (dest + n    , s[i] + m, m); }
00109 
00110   TMPL static inline void
00111   triadic_shift (C* dest, const C* src, nat i, nat m) {
00112     
00113     nat m2= m << 1, m3= m2 + m;
00114     i= i % m3;
00115     if (i > m2) {
00116       Vec::copy (dest         , src + m3 - i, i - m);
00117       Vec::neg  (dest + i - m , src         , m3 - i);
00118       Vec::sub  (dest + i - m2, src         , m3 - i);
00119     }
00120     else if (i > m) {
00121       Vec::neg  (dest    , src + m2 - i, m);
00122       Vec::neg  (dest + m, src + m2 - i, m);
00123       Vec::add  (dest    , src + m3 - i, i - m);
00124       Vec::add  (dest + i, src         , m2 - i);
00125     }
00126     else {
00127       Vec::neg  (dest    , src + m2 - i, i);
00128       Vec::copy (dest + i, src         , m2 - i);
00129       Vec::sub  (dest + m, src + m2 - i, i); } }
00130   
00131   template<typename C>
00132   struct unptr_helper {};
00133 
00134   template<typename C>
00135   struct unptr_helper<C*> {
00136     typedef C type; };
00137 
00138   template<typename Cp>
00139   struct triadic_roots_helper {
00140     
00141     typedef Cp  C;
00142     typedef nat U;
00143     typedef typename unptr_helper<Cp>::type CC;
00144     typedef CC  S;
00145 
00146     static inline nat&
00147     dyn_m () {
00148       static nat m= 0;
00149       return m; }
00150 
00151     static inline C&
00152     dyn_temp (nat i) {
00153       static C temp0= NULL;
00154       static C temp1= NULL;
00155       static C temp2= NULL;
00156       static C temp3= NULL;
00157       switch (i) {
00158       case 0: return temp0;
00159       case 1: return temp1;
00160       case 2: return temp2;
00161       case 3: return temp3;
00162       default: ERROR ("index out of range"); return temp0; } }
00163 
00164     static inline nat
00165     primitive_root (nat t, nat i) {
00166       if (t == 0) return 1;
00167       i = i % t;
00168       nat m3 = 3 * dyn_m ();
00169       ASSERT (m3 % t == 0, "primitive root out of range");
00170       return i * (m3 / t); }
00171 
00172     static U*
00173     create_roots (nat t, const format<U>&) {
00174       nat m2= dyn_m () << 1;
00175       nat l= aligned_size<CC,V> (m2);
00176       for (nat i = 0; i <= 3; i++)
00177         dyn_temp (i)= mmx_new<CC> (l);
00178       U* roots= mmx_new<U> (t);
00179       for (nat i = 0; i < t; i++)
00180         roots[i]  = primitive_root (t, digit_mirror_triadic (i, t));
00181       return roots; }
00182 
00183     static U*
00184     create_stoor (nat t, const format<U>&) {
00185       U* stoor= mmx_new<U> (t);
00186       for (nat i = 0; i < t; i++)
00187         stoor[i]= primitive_root (t, i == 0 ?
00188                                   0 : t - digit_mirror_triadic (i, t));
00189       return stoor; }
00190 
00191     static void
00192     destroy_roots (U* u, nat t) {
00193       nat m2= dyn_m () << 1;
00194       nat l= aligned_size<CC,V> (m2);
00195       for (nat i = 0; i <= 3; i++) {
00196         C& temp= dyn_temp (i);
00197         if (temp != NULL) {
00198           mmx_delete<CC> (temp, l);
00199           temp= NULL;
00200         }
00201       }
00202       mmx_delete<U> (u, t); }
00203 
00204     static inline void
00205     dfft_cross (C* c1, C* c2, C* c3, const U* u1, const U* u2, const U* u3) {
00206       C temp0= dyn_temp (0);
00207       C temp1= dyn_temp (1);
00208       C temp2= dyn_temp (2);
00209       C temp3= dyn_temp (3);
00210       nat m= dyn_m (), m2= m << 1;
00211 
00212       triadic_shift (temp0, *c3, *u3, m);
00213       Vec::add (temp0, *c2, m2);
00214       triadic_shift (temp3, temp0, *u3, m);
00215 
00216       triadic_shift (temp0, *c3, *u2, m);
00217       Vec::add (temp0, *c2, m2);
00218       triadic_shift (temp2, temp0, *u2, m);
00219 
00220       triadic_shift (temp0, *c3, *u1, m);
00221       Vec::add (temp0, *c2, m2);
00222       triadic_shift (temp1, temp0, *u1, m);
00223 
00224       Vec::add (*c3, *c1, temp3, m2);
00225       Vec::add (*c2, *c1, temp2, m2);
00226       Vec::add (*c1, temp1, m2); }
00227 
00228     static inline void
00229     ifft_cross (C* c1, C* c2, C* c3, const U* u1, const U* u2, const U* u3) {
00230       C temp0= dyn_temp (0);
00231       C temp1= dyn_temp (1);
00232       C temp2= dyn_temp (2);
00233       C temp3= dyn_temp (3);
00234       nat m= dyn_m (), m2= m << 1;
00235 
00236       triadic_shift (temp1, *c1, *u1, m);
00237       Vec::add (*c1, *c2, m2);
00238       Vec::add (*c1, *c3, m2);
00239 
00240       triadic_shift (temp2, *c2, *u2, m);
00241       triadic_shift (temp3, *c3, *u3, m);
00242       triadic_shift (temp0, *c2, *u3, m);
00243       Vec::add (*c2, temp1, temp2, m2);
00244       Vec::add (*c2, temp3, m2);
00245 
00246       triadic_shift (temp3, *c3, *u2, m);
00247       Vec::add (temp0, temp3, m2);
00248       Vec::add (temp0, temp1, m2);
00249       triadic_shift (*c3, temp0, *u1, m); }
00250 
00251     static inline void
00252     fft_shift (C* dest, S v, nat t) {
00253       nat m2= dyn_m () << 1;
00254       for (nat i = 0; i < t; i++)
00255         Vec::mul (dest[i], v, m2); }
00256   };
00257 
00258   struct triadic_roots {
00259     template<typename C>
00260     struct helper {
00261       typedef triadic_roots_helper<C> roots_type; };
00262   };
00263 
00264   TMPL static void
00265   direct_transform (C** b, nat m, nat t) {
00266     triadic_roots_helper<C*>::dyn_m ()= m;
00267     fft_triadic_naive_transformer<C*, triadic_roots>
00268       ffter (t, format<C*> ());
00269     ffter.direct_transform_triadic (b); } 
00270 
00271   TMPL static void
00272   inverse_transform (C** b, nat m, nat t, bool shift=true) {
00273     triadic_roots_helper<C*>::dyn_m ()= m;
00274     fft_triadic_naive_transformer<C*, triadic_roots>
00275       ffter (t, format<C*> ());
00276     ffter.inverse_transform_triadic (b,shift); } 
00277 
00278   TMPL static void
00279   variable_dilate (C** b, nat eta, nat m, nat t) {
00280     
00281     nat m2= m << 1, m3= m2 + m;
00282     nat l= aligned_size<C,V> (m2);
00283     C* temp= mmx_new<C> (l);
00284     nat w= 0;
00285     for (nat i = 0; i < t; i++) {
00286       Vec::copy (temp, b[i], m2);
00287       triadic_shift (b[i], temp, w, m);
00288       w= (w + eta) % m3;
00289     }
00290     mmx_delete<C> (temp, l); }
00291 
00292   TMPL static void
00293   bivariate_mod (C** dest, const C** h, nat w, nat m, nat t) {
00294     
00295     nat m2= m << 1, m3= m2 + m;
00296     w= w % m3;
00297     for (nat i = 0; i < t; i++) {
00298       triadic_shift (dest[i], h[i+t], w, m);
00299       Vec::add (dest[i], h[i], m2); } }
00300 
00301   TMPL static void
00302   bivariate_crt (C** h, const C** h1, const C** h2, nat w,
00303                  nat m, nat t, bool shift=true) {
00304     
00305     nat m2= m << 1, m3= m2 + m, t2= t << 1;
00306     w= w % m3;
00307     nat w2= (w << 1) % m3;
00308     nat l= aligned_size<C,V> (m2);
00309     C* temp= mmx_new<C> (l);
00310     for (nat i = 0; i < t; i++) 
00311       Vec::sub (h[i+t], h2[i], h1[i], m2);
00312     for (nat i = 0; i < t; i++) {
00313       triadic_shift (h[i], h1[i], w2, m);
00314       triadic_shift (temp, h2[i], w , m);
00315       Vec::sub (h[i], temp, m2);
00316     }
00317     for (nat i = 0; i < t2; i++) {
00318       triadic_shift (temp, h[i], w, m);
00319       Vec::add (h[i], temp, m2);
00320       Vec::add (h[i], temp, m2);
00321       if (shift)
00322         Vec::mul (h[i], invert (C(3)), m2);
00323     }
00324     mmx_delete<C> (temp, l); }
00325   
00326 public:
00327   TMPL static inline nat
00328   mul_triadic (C* dest, const C* src, nat n, bool shift=true) {
00329     
00330     nat n2= n << 1;
00331     nat u = next_power_of_three (n);
00332     ASSERT (u != 0, "maximum size exceeded");
00333     ASSERT (n == u, "power of three expected");
00334     if (n <= max ((nat) 9, (nat) Threshold(C,Th))) {
00335       nat l= aligned_size<C,V> (2*n2-1);
00336       C* temp= mmx_new<C> (l);
00337       Inner::mul (temp, dest, src, n2, n2);
00338       Vec::copy (dest    , temp         , n2);
00339       Vec::add  (dest    , temp + n2 + n, n - 1);
00340       Vec::sub  (dest    , temp + n2    , n);
00341       Vec::sub  (dest + n, temp + n2    , n);
00342       mmx_delete<C> (temp, l);
00343       return 0;
00344     }
00345     else {
00346       nat r = 0;
00347       nat k = log_3 (n);
00348       nat m = binpow (3, (k+1) >> 1), m2= m << 1, m3= m2 + m;
00349       nat t = n / m, t2= t << 1;
00350       nat eta= (t == m) ? 1 : 3;
00351       
00352       C** hh= bivariate_new<C> (m2, t2);
00353       C** gg= bivariate_new<C> (m2, t2);
00354       C** h1= bivariate_new<C> (m2, t);
00355       C** h2= bivariate_new<C> (m2, t);
00356       C** g = bivariate_new<C> (m2, t);
00357       bivariate_encode (hh, src , n, m, t);
00358       bivariate_encode (gg, dest, n, m, t);
00359       for (nat j = 1; j <= 2; j++) {
00360         C** h= j == 1 ? h1 : h2;
00361         bivariate_mod (h, (const C**) hh, (j*eta*t) % m3, m, t);
00362         bivariate_mod (g, (const C**) gg, (j*eta*t) % m3, m, t);
00363         variable_dilate (h, (j*eta) % m3, m, t);
00364         variable_dilate (g, (j*eta) % m3, m, t);
00365         direct_transform (h, m, t);
00366         direct_transform (g, m, t);
00367         for (nat i = 0; i < t; i++)
00368           r = mul_triadic (h[i], g[i], m, shift);
00369         inverse_transform (h, m, t, shift);
00370         variable_dilate (h, m3 - ((j*eta) % m3), m, t);
00371       }
00372       bivariate_delete<C> (g , m2 , t);
00373       bivariate_delete<C> (gg, m2, t2);
00374       bivariate_crt (hh, (const C**) h1, (const C**) h2, (eta*t) % m3, m, t, shift);
00375       bivariate_delete<C> (h1, m2, t);
00376       bivariate_delete<C> (h2, m2, t); 
00377       bivariate_decode (dest, (const C**) hh, n, m, t);
00378       bivariate_delete<C> (hh, m2, t2);
00379       return (shift) ? 0 : r + 1 + (k >> 1); }}
00380 
00381   TMPL static inline nat
00382   mul (C* dest, const C* s1, const C* s2, nat n1, nat n2, bool shift=true) {
00383     nat len = n1 + n2 - 1;
00384     nat n = next_power_of_three ((len + 1) >> 1);
00385     nat nn= n << 1;
00386     nat l = aligned_size<C,V> (nn);
00387     C* t1 = mmx_new<C> (l);
00388     C* t2 = mmx_new<C> (l);
00389     Vec::copy (t1, s1, n1); Vec::clear (t1 + n1, nn - n1);
00390     Vec::copy (t2, s2, n2); Vec::clear (t2 + n2, nn - n2);
00391     nat k = mul_triadic (t1, t2, n, shift);
00392     Vec::copy (dest, t1, len);
00393     mmx_delete<C> (t1, l);
00394     mmx_delete<C> (t2, l);
00395     return k; }
00396 
00397 }; 
00398 
00399 #undef TMPL
00400 } 
00401 #endif //__MMX__POLYNOMIAL_SCHONHAGE_TRIADIC__HPP