00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX__POLYNOMIAL_TFT__HPP
00014 #define __MMX__POLYNOMIAL_TFT__HPP
00015 #include <algebramix/polynomial_dicho.hpp>
00016 #include <algebramix/polynomial_ring_dicho.hpp>
00017 #include <algebramix/fft_blocks.hpp>
00018 #include <algebramix/fft_simd.hpp>
00019 #include <algebramix/fft_truncated.hpp>
00020
00021 namespace mmx {
00022 #define TMPL template<typename C>
00023
00024
00025
00026
00027
// Tag type naming the TFT multiplication threshold.  It appears as the
// second template argument of the threshold_helper specializations in this
// file and as the default threshold parameter of polynomial_tft_inc.
struct tft_threshold {
};

// Forward declaration: polynomial_balanced_tft is referenced by the
// polynomial_tft variant before its definition later in this file.
template<typename V>
struct polynomial_balanced_tft;
00030
00031 template<typename C>
00032 struct threshold_helper<C,tft_threshold> {
00033 typedef fixed_value<nat,128> impl;
00034 };
00035
00036 template<typename C, typename V>
00037 struct threshold_helper<polynomial<C,V>,tft_threshold> {
00038 typedef fixed_value<nat,32> impl;
00039 };
00040
00041 template<typename C, typename U, typename V, typename W>
00042 struct threshold_helper<modular<modulus<polynomial<C,U>,V>,W>,
00043 tft_threshold> {
00044 typedef fixed_value<nat,32> impl;
00045 };
00046
00047 template<typename V, typename T= tft_threshold>
00048 struct polynomial_tft_inc: public V {
00049 typedef typename V::Vec Vec;
00050 typedef typename V::Naive Naive;
00051 typedef typename V::Positive Positive;
00052 typedef polynomial_tft_inc<typename V::No_simd,T> No_simd;
00053 typedef polynomial_tft_inc<typename V::No_thread,T> No_thread;
00054 typedef polynomial_tft_inc<typename V::No_scaled,T> No_scaled;
00055 };
00056
00057 template<typename F, typename V, typename W, typename Th>
00058 struct implementation<F,V,polynomial_tft_inc<W,Th> >:
00059 public implementation<F,V,W> {};
00060
00061 DEFINE_VARIANT_1 (typename V, V, polynomial_tft,
00062 polynomial_dicho<
00063 polynomial_balanced_tft<
00064 polynomial_tft_inc<
00065 polynomial_karatsuba<V> > > >)
00066
00067 STMPL
00068 struct polynomial_variant_helper<modular<modulus<nat>,
00069 modular_fixed<nat,0x0c0000001> > > {
00070 typedef polynomial_tft<polynomial_naive> PV;
00071 };
00072
00073 template<typename C, typename V>
00074 struct polynomial_variant_helper<polynomial<C,polynomial_tft<V> > > {
00075 typedef polynomial_gcd_ring_dicho<polynomial_tft<V> > PV;
00076 };
00077
00078
00079
00080
00081
00082 template<typename V, typename W, typename Th>
00083 struct implementation<polynomial_multiply,V,polynomial_tft_inc<W,Th> >:
00084 public implementation<polynomial_linear,V>
00085 {
00086 typedef implementation<vector_linear,V> Vec;
00087 typedef implementation<polynomial_linear,V> Pol;
00088 typedef implementation<polynomial_multiply,W> Fallback;
00089
00090 TMPL static inline void
00091 mul (C* dest, const C* s1, const C* s2, nat n1, nat n2) {
00092 typedef fft_blocks_transformer<C, fft_simd_transformer<C>,
00093 8, 5, 10, 16> FFTer;
00094 typedef fft_truncated_transformer<C,FFTer> Tfter;
00095
00096
00097 if (n1 < Threshold(C,Th) || n2 < Threshold(C,Th))
00098 Fallback::mul (dest, s1, s2, n1, n2);
00099 else {
00100 format<C> fm= get_format (s1[0]);
00101 nat n= n1 + n2 - 1;
00102 nat m= next_power_of_two (n);
00103 nat l= aligned_size<C,V> (m);
00104 C* temp0= mmx_formatted_new<C> (l, fm);
00105 C* temp1= mmx_formatted_new<C> (l, fm);
00106 C* tempd= mmx_formatted_new<C> (l, fm);
00107 Tfter tfter (n, fm);
00108 Vec::copy (temp0, s1, n1);
00109 Vec::clear (temp0+n1, n-n1);
00110 Vec::copy (temp1, s2, n2);
00111 Vec::clear (temp1+n2, n-n2);
00112 tfter.direct_transform (temp0);
00113 tfter.direct_transform (temp1);
00114 Vec::mul (tempd, temp0, temp1, n);
00115 tfter.inverse_transform (tempd);
00116 Vec::copy (dest, tempd, n);
00117 mmx_delete<C> (temp0, l);
00118 mmx_delete<C> (temp1, l);
00119 mmx_delete<C> (tempd, l);
00120 }
00121 }
00122
00123 TMPL static inline void
00124 square (C* dest, const C* s1, nat n1) {
00125 typedef fft_truncated_transformer<C> Tfter;
00126 if (n1 < Threshold(C,Th))
00127 Fallback::square (dest, s1, n1);
00128 else {
00129 format<C> fm= get_format (s1[0]);
00130 nat n= 2 * n1 - 1;
00131 nat m= next_power_of_two (n);
00132 nat l= aligned_size<C,V> (m);
00133 C* temp0= mmx_formatted_new<C> (l, fm);
00134 C* tempd= mmx_formatted_new<C> (l, fm);
00135 Tfter tfter (n, fm);
00136 Vec::copy (temp0, s1, n1);
00137 Vec::clear (temp0+n1, n-n1);
00138 tfter.direct_transform (temp0);
00139 Vec::mul (tempd, temp0, temp0, n);
00140 tfter.inverse_transform (tempd);
00141 Vec::copy (dest, tempd, n);
00142 mmx_delete<C> (temp0, l);
00143 mmx_delete<C> (tempd, l);
00144 }
00145 }
00146
00147 };
00148
00149
00150
00151
00152
// Variant layer built on top of a variant V.
//   - Vec and Naive are taken over from V unchanged.
//   - Positive, No_simd, No_thread and No_scaled wrap the corresponding
//     variant of V in polynomial_balanced_tft again.
template<typename V>
struct polynomial_balanced_tft: V {
  // Re-exported from the underlying variant.
  typedef typename V::Vec   Vec;
  typedef typename V::Naive Naive;

  // Derived variants that stay inside the polynomial_balanced_tft layer.
  typedef polynomial_balanced_tft<typename V::Positive>  Positive;
  typedef polynomial_balanced_tft<typename V::No_simd>   No_simd;
  typedef polynomial_balanced_tft<typename V::No_thread> No_thread;
  typedef polynomial_balanced_tft<typename V::No_scaled> No_scaled;
};
00162
00163 template<typename F, typename V, typename W>
00164 struct implementation<F,V,polynomial_balanced_tft<W> >:
00165 public implementation<F,V,W> {};
00166
00167 template<typename V, typename W>
00168 struct implementation<polynomial_multiply,V,polynomial_balanced_tft<W> >:
00169 public implementation<polynomial_linear,V>
00170 {
00171 typedef implementation<vector_linear,V> Vec;
00172 typedef implementation<polynomial_linear,W> Pol;
00173 typedef implementation<polynomial_multiply,W> Fallback;
00174
00175 TMPL static inline void
00176 mul (C* dest, const C* s1, const C* s2, nat n1, nat n2) {
00177 if (n1 == 0 && n2 == 0)
00178 return;
00179 if (n1 == 0 || n2 == 0) {
00180 Pol::clear (dest, n1+n2-1);
00181 return;
00182 }
00183 if (n2 < n1) {
00184 mul (dest, s2, s1, n2, n1);
00185 return;
00186 }
00187 nat n= n1+n2-1;
00188 nat m= next_power_of_two (2*n1 - 1);
00189 if (n1 <= n2 && n >= m) {
00190 Pol::clear (dest, n);
00191 nat l= aligned_size<C,V> (m);
00192 C* temp= mmx_new<C> (l);
00193 while (n1+n2-1 >= m) {
00194 Fallback::mul (temp, s1, s2, n1, m-n1+1);
00195 Vec::add (dest, temp, m);
00196 n2 -= m-n1+1; s2 += m-n1+1; dest += m-n1+1;
00197 }
00198 if (n2 > 0) {
00199 mul (temp, s1, s2, n1, n2);
00200 Vec::add (dest, temp, n1+n2-1);
00201 }
00202 mmx_delete<C> (temp, l);
00203 return;
00204 }
00205 Fallback::mul (dest, s1, s2, n1, n2);
00206 }
00207
00208 TMPL static inline void
00209 square (C* dest, const C* s1, nat n1) {
00210 Fallback::square (dest, s1, n1);
00211 }
00212
00213 };
00214
00215 #undef TMPL
00216 }
00217 #endif //__MMX__POLYNOMIAL_TFT__HPP