00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 #ifndef __MMX__POLYNOMIAL_TFT__HPP
00014 #define __MMX__POLYNOMIAL_TFT__HPP
00015 #include <algebramix/polynomial_dicho.hpp>
00016 #include <algebramix/polynomial_ring_dicho.hpp>
00017 #include <algebramix/fft_blocks.hpp>
00018 #include <algebramix/fft_simd.hpp>
00019 #include <algebramix/fft_truncated.hpp>
00020 
00021 namespace mmx {
// Shorthand for the coefficient-type template header of the member
// functions in this file; it is #undef'd again before the namespace closes.
#define TMPL template<typename C>
00023 
00024 
00025 
00026 
00027 
// Tag type used as the key of the threshold_helper specializations below;
// it is the default threshold parameter of polynomial_tft_inc.
struct tft_threshold {};
// Forward declaration: the variant is defined further down in this file,
// but is needed earlier by the definition of polynomial_tft.
template<typename V> struct polynomial_balanced_tft;
00030 
// Default TFT threshold (128) for a generic coefficient type C.
// Presumably read through Threshold(C,tft_threshold) in the multiplication
// of polynomial_tft_inc, where operands with fewer coefficients than this
// are handed to the fallback variant -- TODO confirm the macro's definition.
template<typename C>
struct threshold_helper<C,tft_threshold> {
  typedef fixed_value<nat,128> impl;
};
00035 
// Lower TFT threshold (32) when the coefficients are themselves polynomials.
// NOTE(review): presumably because coefficient operations are costly, which
// lets the transform pay off at smaller sizes -- confirm.
template<typename C, typename V>
struct threshold_helper<polynomial<C,V>,tft_threshold> {
  typedef fixed_value<nat,32> impl;
};
00040 
// Lower TFT threshold (32) for coefficients that are modular residues
// modulo a polynomial, which are treated like polynomial coefficients above.
template<typename C, typename U, typename V, typename W>
struct threshold_helper<modular<modulus<polynomial<C,U>,V>,W>,
                        tft_threshold> {
typedef fixed_value<nat,32> impl;
};
00046 
// Variant tag that adds truncated-Fourier-transform multiplication on top
// of the variant V.  T is the threshold key; operands below that threshold
// are multiplied by V itself (see the polynomial_multiply specialization).
// The No_simd / No_thread / No_scaled variants keep the TFT layer, whereas
// Positive is taken from V unchanged, so the Positive variant drops TFT.
// NOTE(review): the Positive asymmetry looks deliberate (polynomial_balanced_tft
// does wrap Positive) -- confirm before changing.
template<typename V, typename T= tft_threshold>
struct polynomial_tft_inc: public V {
  typedef typename V::Vec Vec;
  typedef typename V::Naive Naive;
  typedef typename V::Positive Positive;
  typedef polynomial_tft_inc<typename V::No_simd,T> No_simd;
  typedef polynomial_tft_inc<typename V::No_thread,T> No_thread;
  typedef polynomial_tft_inc<typename V::No_scaled,T> No_scaled;
};
00056 
// Every feature F not specialized below for polynomial_tft_inc is forwarded
// unchanged to the underlying variant W.
template<typename F, typename V, typename W, typename Th>
struct implementation<F,V,polynomial_tft_inc<W,Th> >:
  public implementation<F,V,W> {};
00060 
// polynomial_tft<V>: the composed variant, from outermost to innermost
//   polynomial_dicho -> polynomial_balanced_tft -> polynomial_tft_inc
//   -> polynomial_karatsuba<V>.
// Presumably DEFINE_VARIANT_1 declares polynomial_tft<V> as a named variant
// standing for this composition -- see the macro's definition to confirm.
DEFINE_VARIANT_1 (typename V, V, polynomial_tft,
                  polynomial_dicho<
                   polynomial_balanced_tft<
                     polynomial_tft_inc<
                       polynomial_karatsuba<V> > > >)
00066 
// Default polynomial variant for modular<nat> coefficients whose modular
// variant is modular_fixed<nat,0x0c0000001>: use polynomial_tft over
// polynomial_naive.  Note 0x0c0000001 == 3*2^30 + 1.
// NOTE(review): presumably a fixed FFT-friendly prime modulus (large
// power-of-two roots of unity) -- confirm.
template<typename MoV>
struct polynomial_variant_helper<modular<modulus<nat,MoV>,
                                         modular_fixed<nat,0x0c0000001> > > {
  typedef polynomial_tft<polynomial_naive> PV;
};
00072 
// Polynomials whose coefficients are polynomials in the polynomial_tft
// variant default to polynomial_gcd_ring_dicho built on polynomial_tft<V>.
template<typename C, typename V>
struct polynomial_variant_helper<polynomial<C,polynomial_tft<V> > > {
  typedef polynomial_gcd_ring_dicho<polynomial_tft<V> > PV;
};
00077 
00078 
00079 
00080 
00081 
00082 template<typename V, typename W, typename Th>
00083 struct implementation<polynomial_multiply,V,polynomial_tft_inc<W,Th> >:
00084   public implementation<polynomial_linear,V>
00085 {
00086   typedef implementation<vector_linear,V> Vec;
00087   typedef implementation<polynomial_linear,V> Pol;
00088   typedef implementation<polynomial_multiply,W> Fallback;
00089 
00090 TMPL static inline void
00091 mul (C* dest, const C* s1, const C* s2, nat n1, nat n2) {
00092   typedef fft_blocks_transformer<C, fft_simd_transformer<C>,
00093     8, 5, 10, 16> FFTer;
00094   typedef fft_truncated_transformer<C,FFTer> Tfter;
00095   
00096   
00097   if (n1 < Threshold(C,Th) || n2 < Threshold(C,Th))
00098     Fallback::mul (dest, s1, s2, n1, n2);
00099   else {
00100     format<C> fm= get_format (s1[0]);
00101     nat n= n1 + n2 - 1;
00102     nat m= next_power_of_two (n);
00103     nat l= aligned_size<C,V> (m);
00104     C* temp0= mmx_formatted_new<C> (l, fm);
00105     C* temp1= mmx_formatted_new<C> (l, fm);
00106     C* tempd= mmx_formatted_new<C> (l, fm);
00107     Tfter tfter (n, fm);
00108     Vec::copy (temp0, s1, n1);
00109     Vec::clear (temp0+n1, n-n1);
00110     Vec::copy (temp1, s2, n2);
00111     Vec::clear (temp1+n2, n-n2);
00112     tfter.direct_transform (temp0);
00113     tfter.direct_transform (temp1);
00114     Vec::mul (tempd, temp0, temp1, n);
00115     tfter.inverse_transform (tempd);
00116     Vec::copy (dest, tempd, n);
00117     mmx_delete<C> (temp0, l);
00118     mmx_delete<C> (temp1, l);
00119     mmx_delete<C> (tempd, l);
00120   }
00121 }
00122 
00123 TMPL static inline void
00124 square (C* dest, const C* s1, nat n1) {
00125   typedef fft_truncated_transformer<C> Tfter;
00126   if (n1 < Threshold(C,Th))
00127     Fallback::square (dest, s1, n1);
00128   else {
00129     format<C> fm= get_format (s1[0]);
00130     nat n= 2 * n1 - 1;
00131     nat m= next_power_of_two (n);
00132     nat l= aligned_size<C,V> (m);
00133     C* temp0= mmx_formatted_new<C> (l, fm);
00134     C* tempd= mmx_formatted_new<C> (l, fm);
00135     Tfter tfter (n, fm);
00136     Vec::copy (temp0, s1, n1);
00137     Vec::clear (temp0+n1, n-n1);
00138     tfter.direct_transform (temp0);
00139     Vec::mul (tempd, temp0, temp0, n);
00140     tfter.inverse_transform (tempd);
00141     Vec::copy (dest, tempd, n);
00142     mmx_delete<C> (temp0, l);
00143     mmx_delete<C> (tempd, l);
00144   }
00145 }
00146 
00147 }; 
00148 
00149 
00150 
00151 
00152 
// Variant tag that splits unbalanced products into balanced pieces before
// handing them to V (see the polynomial_multiply specialization below).
// All derived variants, including Positive, keep this layer.
template<typename V>
struct polynomial_balanced_tft: public V {
  typedef typename V::Vec Vec;
  typedef typename V::Naive Naive;
  typedef polynomial_balanced_tft<typename V::Positive> Positive;
  typedef polynomial_balanced_tft<typename V::No_simd> No_simd;
  typedef polynomial_balanced_tft<typename V::No_thread> No_thread;
  typedef polynomial_balanced_tft<typename V::No_scaled> No_scaled;
};
00162 
// Every feature F not specialized below for polynomial_balanced_tft is
// forwarded unchanged to the underlying variant W.
template<typename F, typename V, typename W>
struct implementation<F,V,polynomial_balanced_tft<W> >:
  public implementation<F,V,W> {};
00166 
00167 template<typename V, typename W>
00168 struct implementation<polynomial_multiply,V,polynomial_balanced_tft<W> >:
00169   public implementation<polynomial_linear,V>
00170 {
00171   typedef implementation<vector_linear,V> Vec;
00172   typedef implementation<polynomial_linear,W> Pol;
00173   typedef implementation<polynomial_multiply,W> Fallback;
00174 
00175 TMPL static inline void
00176 mul (C* dest, const C* s1, const C* s2, nat n1, nat n2) {
00177   if (n1 == 0 && n2 == 0)
00178     return;
00179   if (n1 == 0 || n2 == 0) {
00180     Pol::clear (dest, n1+n2-1);
00181     return;
00182   }
00183   if (n2 < n1) {
00184     mul (dest, s2, s1, n2, n1);
00185     return;
00186   }
00187   nat n= n1+n2-1;
00188   nat m= next_power_of_two (2*n1 - 1);
00189   if (n1 <= n2 && n >= m) {
00190     Pol::clear (dest, n);
00191     nat l= aligned_size<C,V> (m);
00192     C* temp= mmx_new<C> (l);
00193     while (n1+n2-1 >= m) {
00194       Fallback::mul (temp, s1, s2, n1, m-n1+1);
00195       Vec::add (dest, temp, m);
00196       n2 -= m-n1+1; s2 += m-n1+1; dest += m-n1+1;
00197     }
00198     if (n2 > 0) {
00199       mul (temp, s1, s2, n1, n2);
00200       Vec::add (dest, temp, n1+n2-1);
00201     }
00202     mmx_delete<C> (temp, l);
00203     return;
00204   }
00205   Fallback::mul (dest, s1, s2, n1, n2);
00206 }
00207   
00208 TMPL static inline void
00209 square (C* dest, const C* s1, nat n1) {
00210   Fallback::square (dest, s1, n1);
00211 }
00212 
00213 }; 
00214 
00215 #undef TMPL
00216 } 
00217 #endif //__MMX__POLYNOMIAL_TFT__HPP