00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX__POLYNOMIAL_DICHO__HPP
00014 #define __MMX__POLYNOMIAL_DICHO__HPP
00015 #include <basix/vector_sort.hpp>
00016 #include <algebramix/polynomial_naive.hpp>
00017 #include <algebramix/crt_polynomial.hpp>
00018
00019 namespace mmx {
00020 #define TMPL template<typename C>
00021 #define TMPLP template <typename Polynomial>
00022 #define Vector vector<C>
00023
00024
00025
00026
00027
// Implementation variant tag: wraps an underlying variant V and selects the
// Karatsuba multiplication routines specialized further down in this file.
// Every derived variant of V is wrapped again so that the tag is preserved.
template<typename V>
struct polynomial_karatsuba: public V {
  // Types taken over unchanged from the wrapped variant.
  typedef typename V::Naive Naive;
  typedef typename V::Vec   Vec;

  // Derived variants, re-wrapped in polynomial_karatsuba.
  typedef polynomial_karatsuba<typename V::No_scaled> No_scaled;
  typedef polynomial_karatsuba<typename V::No_thread> No_thread;
  typedef polynomial_karatsuba<typename V::No_simd>   No_simd;
  typedef polynomial_karatsuba<typename V::Positive>  Positive;
};
00037
00038 template<typename F, typename V, typename W>
00039 struct implementation<F,V,polynomial_karatsuba<W> >:
00040 public implementation<F,V,W> {};
00041
// Implementation variant tag: wraps an underlying variant V and selects the
// divide-and-conquer ("dicho") routines for division, Euclidean algorithms
// and evaluation that are specialized further down in this file.
// Every derived variant of V is wrapped again so that the tag is preserved.
template<typename V>
struct polynomial_dicho: public V {
  // Types taken over unchanged from the wrapped variant.
  typedef typename V::Naive Naive;
  typedef typename V::Vec   Vec;

  // Derived variants, re-wrapped in polynomial_dicho.
  typedef polynomial_dicho<typename V::No_scaled> No_scaled;
  typedef polynomial_dicho<typename V::No_thread> No_thread;
  typedef polynomial_dicho<typename V::No_simd>   No_simd;
  typedef polynomial_dicho<typename V::Positive>  Positive;
};
00051
00052 template<typename F, typename V, typename W>
00053 struct implementation<F,V,polynomial_dicho<W> >:
00054 public implementation<F,V,W> {};
00055
00056
00057
00058
00059
// Empty tag type.  It is instantiated below as `Th` and handed to the
// Threshold(C,Th) macro to pick the size under which the fallback
// multiplication is used.
template<typename V> struct polynomial_multiply_threshold {};
00062
00063 template<typename V, typename W>
00064 struct implementation<polynomial_multiply,V,polynomial_karatsuba<W> >:
00065 public implementation<polynomial_linear,V>
00066 {
00067 typedef polynomial_multiply_threshold<polynomial_karatsuba<W> > Th;
00068 typedef implementation<vector_linear,V> Vec;
00069 typedef implementation<polynomial_linear,W> Pol;
00070 typedef implementation<polynomial_multiply,W> Fallback;
00071
00072 TMPL static void
00073 multiply (C* dest, const C* s1, const C* s2, nat n1, nat n2) {
00074 if (n1 < Threshold(C,Th) || n2 < Threshold(C,Th))
00075 Fallback::mul (dest, s1, s2, n1, n2);
00076 else {
00077 nat p1= n1 >> 1, p2= n2 >> 1, P= p1+p2-1;
00078 nat spc= aligned_size<C,V> (3*(p1+p2));
00079 C* buf= mmx_new<C> (spc);
00080 C* low1= dest;
00081 C* low2= low1 + p1;
00082 C* mid1= buf;
00083 C* mid2= buf + p1;
00084 C* hi1 = mid2 + p2;
00085 C* hi2 = hi1 + p1;
00086 C* Low = mid1;
00087 C* Mid = hi1;
00088 C* Hi = Mid + p1 + p2;
00089 Vec::half_copy (low1 , s1 , p1);
00090 Vec::half_copy (hi1 , s1+1, p1);
00091 Pol::add (mid1 , low1, hi1 , p1);
00092 Vec::half_copy (low2 , s2 , p2);
00093 Vec::half_copy (hi2 , s2+1, p2);
00094 Pol::add (mid2 , low2, hi2 , p2);
00095 multiply (Hi , hi1 , hi2 , p1, p2);
00096 multiply (Mid , mid1, mid2, p1, p2);
00097 multiply (Low , low1, low2, p1, p2);
00098 Pol::sub (Mid , Low , P);
00099 Pol::sub (Mid , Hi , P);
00100 Pol::add (Low+1, Hi , P-1);
00101 Low[P]= Hi[P-1];
00102 Vec::double_copy (dest , Low, P+1);
00103 Vec::double_copy (dest+1, Mid, P);
00104 mmx_delete<C> (buf, spc);
00105
00106 if ((n1 & 1) != 0) {
00107 dest[(P<<1)+1]= C(0);
00108 Pol::mul_add (dest + (p1<<1), s2, s1[p1<<1], p2<<1);
00109 }
00110 if ((n2 & 1) != 0) {
00111 dest[n1+n2-2]= C(0);
00112 Pol::mul_add (dest + (p2<<1), s1, s2[p2<<1], n1);
00113 }
00114 }
00115 }
00116
00117 TMPL static inline void
00118 mul (C* dest, const C* s1, const C* s2, nat n1, nat n2) {
00119 if (n1 == 0 && n2 == 0) return;
00120 if (n1 == 0 || n2 == 0)
00121 Pol::clear (dest, n1 + n2 - 1);
00122 else
00123 multiply (dest, s1, s2, n1, n2);
00124 }
00125
00126 TMPL static void
00127 square (C* dest, const C* s, nat n) {
00128 if (n == 0) return;
00129 if (n < Threshold(C,Th))
00130 Fallback::square (dest, s, n);
00131 else {
00132 nat p= n >> 1, P= 2*p-1, spc= aligned_size<C,V> (6*p);
00133 C* buf= mmx_new<C> (spc);
00134 C* low= dest;
00135 C* mid= buf;
00136 C* hi = buf + 2*p;
00137 C* Low= buf;
00138 C* Mid= hi;
00139 C* Hi = Mid + 2*p;
00140 Vec::half_copy (low , s , p);
00141 Vec::half_copy (hi , s+1, p);
00142 Pol::add (mid , low, hi, p);
00143 square (Hi , hi , p);
00144 square (Mid , mid, p);
00145 square (Low , low, p);
00146 Pol::sub (Mid , Low, P);
00147 Pol::sub (Mid , Hi , P);
00148 Pol::add (Low+1, Hi , P-1);
00149 Low[P]= Hi[P-1];
00150 Vec::double_copy (dest , Low, P+1);
00151 Vec::double_copy (dest+1, Mid, P);
00152 mmx_delete<C> (buf, spc);
00153
00154 if ((n & 1) != 0) {
00155 dest[(P<<1)+1]= dest[2*n-2]= C(0);
00156 Pol::mul_add (dest + (p<<1), s, s[p<<1], p<<1);
00157 Pol::mul_add (dest + (p<<1), s, s[p<<1], n);
00158 }
00159 }
00160 }
00161
00162 TMPL static void
00163 tmultiply (C* dest, const C* s1, const C* s2, nat n1, nat n2) {
00164 if (n1 < Threshold(C,Th) || n2 < Threshold(C,Th))
00165 Fallback::tmul (dest, s1, s2, n1, n2);
00166 else {
00167 nat p1= n1 >> 1, p2= n2 >> 1, P= p1+p2-1;
00168 nat spc= aligned_size<C,V> (3*(p1+P)+1);
00169 C* buf= mmx_new<C> (spc);
00170 C* low1= buf;
00171 C* low2= low1 + p1;
00172 C* mid1= low2 + P+1;
00173 C* mid2= mid1 + p1;
00174 C* hi1 = mid2 + P;
00175 C* hi2 = hi1 + p1;
00176 C* Low = mid1;
00177 C* Mid = hi1;
00178 C* Hi = dest;
00179 Vec::half_copy (low1, s1 , p1);
00180 Vec::half_copy (hi1 , s1+1, p1);
00181 Pol::add (mid1, low1, hi1, p1);
00182 Vec::half_copy (low2, s2 , P+1);
00183 Vec::half_copy (hi2 , s2+1, P);
00184 Pol::add (mid2, low2+1, hi2, P);
00185 tmultiply (Hi , hi1 , hi2 , p1, p2);
00186 tmultiply (Mid, mid1, mid2, p1, p2);
00187 tmultiply (Low, low1, low2, p1, p2+1);
00188 Pol::sub (Mid, Hi , p2);
00189 Pol::sub (Mid, Low+1, p2);
00190 Pol::add (Low, Hi , p2);
00191 Vec::double_copy (dest , Low, p2);
00192 Vec::double_copy (dest+1, Mid, p2);
00193 mmx_delete<C> (buf, spc);
00194
00195 if ((n1 & 1) != 0)
00196 Pol::mul_add (dest, s2+n1-1, s1[n1-1], n2);
00197
00198 if ((n2 & 1) != 0)
00199 dest[n2-1]= Vec::inn_prod (s1, s2+n2-1, n1);
00200 }
00201 }
00202
00203 TMPL static inline void
00204 tmul (C* dest, const C* s1, const C* s2, nat n1, nat n2) {
00205
00206
00207 Pol::clear (dest, n2);
00208 if (n1 == 0 || n2 == 0) return;
00209 tmultiply (dest, s1, s2, n1, n2);
00210 }
00211
00212 };
00213
00214
00215
00216
00217
// Empty tag type.  It is instantiated below as `Th` and handed to the
// Threshold(C,Th) macro to pick the size under which the fallback division
// is used.
template<typename V> struct polynomial_divide_threshold {};
00220
00221 template<typename V, typename BV>
00222 struct implementation<polynomial_divide,V,polynomial_dicho<BV> >:
00223 public implementation<polynomial_multiply,V>
00224 {
00225 typedef polynomial_divide_threshold<polynomial_dicho<BV> > Th;
00226 typedef implementation<polynomial_multiply,V> Pol;
00227 typedef implementation<polynomial_divide,V,BV> Fallback;
00228
00229 TMPL static void
00230 invert_lo (C* dest, const C* src, nat n) {
00231 if (n == 0) return;
00232 if (n == 1) *dest= C(1) / *src;
00233 else {
00234 nat h= (n+1) >> 1;
00235 nat l= n - h;
00236 invert_lo (dest, src, h);
00237 nat buf_size= aligned_size<C,V> (n << 1);
00238 C* buf= mmx_new<C> (buf_size);
00239 C* aux= buf + n;
00240 Pol::mul (buf, src, dest, n, h);
00241
00242 Pol::mul (aux, dest, buf + h, l, l);
00243
00244 Pol::neg (dest + h, aux, l);
00245 mmx_delete<C> (buf, buf_size);
00246 }
00247 }
00248
00249 TMPL static void
00250 invert_hi (C* dest, const C* src, nat n) {
00251 if (n == 1) *dest= C(1) / *src;
00252 else {
00253 nat h= (n+1) >> 1;
00254 nat l= n - h;
00255 invert_hi (dest + l, src + l, h);
00256 nat buf_size= aligned_size<C,V> (n << 1);
00257 C* buf= mmx_new<C> (buf_size);
00258 C* aux= buf + l;
00259 Pol::mul (aux, src, dest + l, n, h);
00260
00261 Pol::mul (buf, dest + h, aux + h - 1, l, l);
00262
00263 Pol::neg (dest, buf + l - 1, l);
00264 mmx_delete<C> (buf, buf_size);
00265 }
00266 }
00267
00268 TMPL static void
00269 quo_rem (C* dest, C* s1, const C* s2, nat n1, nat n2) {
00270
00271 if (n1 < n2);
00272 else if (n1 < Threshold(C,Th))
00273 Fallback::quo_rem (dest, s1, s2, n1, n2);
00274 else {
00275 nat tot= aligned_size<C,V> ((n2 << 1) + n2);
00276 C* buf= mmx_new<C> (tot);
00277 C* inv= buf + (n2 << 1);
00278 nat nq = n1 + 1 - n2;
00279 nat l = min (n2, nq);
00280 invert_hi (inv + n2 - l, s2 + n2 - l, l);
00281 while (n1 >= n2) {
00282 nat nq= n1 + 1 - n2;
00283 nat l = min (n2, nq);
00284 C* dest_hi= dest + nq - l;
00285 Pol::mul (buf, s1 + n1 - l, inv + n2 - l, l, l);
00286 Pol::copy (dest_hi, buf + l - 1, l);
00287 Pol::mul (buf, dest_hi, s2, l, n2);
00288 Pol::sub (s1 + n1 - (n2 + l - 1), buf, n2 + l - 1);
00289 n1 -= l;
00290 }
00291 mmx_delete<C> (buf, tot);
00292 }
00293 }
00294
00295 TMPL static void
00296 tquo_rem (C* dest, const C* s1, const C* s2, nat n1, nat n2) {
00297
00298
00299
00300
00301
00302 if (n1 < n2)
00303 Pol::copy (dest, s1, n1);
00304 else if (n1 < Threshold(C,Th))
00305 Fallback::tquo_rem (dest, s1, s2, n1, n2);
00306 else {
00307 nat tot= aligned_size<C,V> (3 * n2);
00308 C* buf= mmx_new<C> (tot);
00309 C* inv= buf + 2*n2;
00310 nat nq = n1 + 1 - n2;
00311 nat l = min (n2, nq);
00312 invert_hi (inv + n2 - l, s2 + n2 - l, l);
00313 Pol::neg (inv + n2 - l, inv + n2 - l, l);
00314 Pol::copy (dest, s1, n2-1);
00315 Pol::clear (dest+n2-1, n1-n2+1);
00316 nat m = n2-1;
00317 while (m < n1) {
00318 nat nq= n1 - m;
00319 nat l = min (n2, nq);
00320 Pol::clear (buf , l-1);
00321 Pol::tmul (buf+l-1 , s2 , dest + m - (n2 - 1), n2, l);
00322 Pol::sub (buf+l-1 , s1 + m , l);
00323 Pol::tmul (dest + m, inv + n2 - l, buf, l , l);
00324 m += l;
00325 }
00326 mmx_delete<C> (buf, tot);
00327 }
00328 }
00329
00330
00331 TMPL static void
00332 pinvert_hi (C* dest, const C* src, nat n) {
00333 if (n == 1) *dest= C(1);
00334 else {
00335 nat h= (n+1) >> 1;
00336 nat l= n - h;
00337 nat tmp_size= aligned_size<C,V> (l);
00338 C* tmp= mmx_new<C> (tmp_size);
00339 pinvert_hi (dest + l, src + l, h);
00340 pinvert_hi (tmp, src + h, l);
00341 nat buf_size= aligned_size<C,V> (n << 1);
00342 C* buf= mmx_new<C> (buf_size);
00343 C* aux= buf + l;
00344 Pol::mul (aux, src, dest + l, n, h);
00345
00346 Pol::mul_sc (dest + l, binpow (src[n-1], l), h);
00347 Pol::mul (buf, tmp, aux + h - 1, l, l);
00348
00349 Pol::neg (dest, buf + l - 1, l);
00350 mmx_delete<C> (buf, buf_size);
00351 mmx_delete<C> (tmp, tmp_size);
00352 }
00353 }
00354
00355 TMPL static void
00356 pquo_rem (C* dest, C* s1, const C* s2, nat n1, nat n2) {
00357 if (n1 < n2);
00358 else if (n1 < Threshold(C,Th))
00359 Fallback::pquo_rem (dest, s1, s2, n1, n2);
00360 else {
00361 nat tot= aligned_size<C,V> ((n2 << 1) + n2);
00362 C* buf= mmx_new<C> (tot);
00363 C* inv= buf + (n2 << 1);
00364 nat tmp_size= aligned_size<C,V> (n2);
00365 C* tmp= mmx_new<C> (tmp_size);
00366 nat nq= n1 + 1 - n2;
00367 nat l = min (n2, nq);
00368 nat l_end= nq % l;
00369 pinvert_hi (inv + n2 - l, s2 + n2 - l, l);
00370 if (l_end != 0)
00371 pinvert_hi (tmp + n2 - l_end, s2 + n2 - l_end, l_end);
00372 while (n1 >= n2) {
00373 nat nq= n1 + 1 - n2;
00374 nat l = min (n2, nq);
00375 C* dest_hi= dest + nq - l;
00376 if (l == l_end)
00377 Pol::mul (buf, s1 + n1 - l, tmp + n2 - l, l, l);
00378 else
00379 Pol::mul (buf, s1 + n1 - l, inv + n2 - l, l, l);
00380 Pol::copy (dest_hi, buf + l - 1, l);
00381 Pol::mul (buf, dest_hi, s2, l, n2);
00382 Pol::mul_sc (s1, binpow (s2[n2-1], l), n1);
00383 Pol::sub (s1 + n1 - (n2 + l - 1), buf, n2 + l - 1);
00384 Pol::mul_sc (dest_hi, binpow (s2[n2-1], nq - l), l);
00385 n1 -= l;
00386 }
00387 mmx_delete<C> (buf, tot);
00388 mmx_delete<C> (tmp, tmp_size);
00389 }
00390 }
00391
00392 };
00393
00394
00395
00396
00397
00398
00399
// Empty tag type.  It is instantiated below as `Th` and handed to the
// Threshold(C,Th) macro to pick the size under which the fallback Euclidean
// routines are used.
template<typename V> struct polynomial_euclidean_threshold {};
00402
00403 template<typename V, typename BV>
00404 struct implementation<polynomial_euclidean,V,polynomial_dicho<BV> >:
00405 public implementation<polynomial_divide,V>
00406 {
00407 typedef polynomial_euclidean_threshold<polynomial_dicho<BV> > Th;
00408 typedef implementation<vector_linear,V> Vec;
00409 typedef implementation<polynomial_divide,V> Pol;
00410 typedef implementation<polynomial_euclidean,V,BV> Fallback;
00411
00412 private:
00413
00414 TMPL static void
00415 dot_product (C* d, nat& nd,
00416 const C* r0, const C* r1, nat nr0, nat nr1,
00417 const C* s0, const C* s1, nat ns0, nat ns1,
00418 C* t) {
00419
00420 if ((nr0 == 0 || ns0 == 0) && (nr1 == 0 || ns1 == 0)) {
00421 nd = 0; return;
00422 }
00423 if (nr0 == 0 || ns0 == 0) {
00424 nd = nr1 + ns1 - 1;
00425 Pol::mul (d, r1, s1, nr1, ns1);
00426 Pol::trim (d, nd);
00427 return;
00428 }
00429 if (nr1 == 0 || ns1 == 0) {
00430 nd = nr0 + ns0 - 1;
00431 Pol::mul (d, r0, s0, nr0, ns0);
00432 Pol::trim (d, nd);
00433 }
00434 if (nr0 + ns0 - 1 < nr1 + ns1 - 1) {
00435 nd = nr1 + ns1 - 1;
00436 Pol::mul (d, r1, s1, nr1, ns1);
00437 Pol::mul (t, r0, s0, nr0, ns0);
00438 Pol::add (d, t , nr0 + ns0 - 1);
00439 }
00440 else {
00441 nd = nr0 + ns0 - 1;
00442 Pol::mul (d, r0, s0, nr0, ns0);
00443 Pol::mul (t, r1, s1, nr1, ns1);
00444 Pol::add (d, t , nr1 + ns1 - 1);
00445 }
00446 Pol::trim (d, nd);
00447 }
00448
00449 TMPL static void
00450 matrix_vector_product (C* r0, C* r1, nat& nr0, nat& nr1,
00451 const C* R00, const C* R01, const C* R10, const C* R11,
00452 nat nR00, nat nR01, nat nR10, nat nR11,
00453 const C* s0, const C* s1, nat n0, nat n1,
00454 C* tp) {
00455
00456 dot_product (r0, nr0, R00, R01, nR00, nR01, s0, s1, n0, n1, tp);
00457 dot_product (r1, nr1, R10, R11, nR10, nR11, s0, s1, n0, n1, tp);
00458 }
00459
00460 TMPL static void
00461 matrix_product (C* Q00, C* Q01, C* Q10, C* Q11,
00462 nat& nQ00, nat& nQ01, nat& nQ10, nat& nQ11,
00463 const C* S00, const C* S01, const C* S10, const C* S11,
00464 nat nS00, nat nS01, nat nS10, nat nS11,
00465 const C* R00, const C* R01, const C* R10, const C* R11,
00466 nat nR00, nat nR01, nat nR10, nat nR11, C* tp) {
00467
00468 matrix_vector_product (Q00, Q10, nQ00, nQ10,
00469 S00, S01, S10, S11, nS00, nS01, nS10, nS11,
00470 R00, R10, nR00, nR10, tp);
00471 matrix_vector_product (Q01, Q11, nQ01, nQ11,
00472 S00, S01, S10, S11, nS00, nS01, nS10, nS11,
00473 R01, R11, nR01, nR11, tp);
00474 }
00475
00476 TMPL static void
00477 new_matrix (C*& M00, C*& M01, C*& M10, C*& M11, nat l) {
00478 M00= mmx_new<C> (l); M01= mmx_new<C> (l);
00479 M10= mmx_new<C> (l); M11= mmx_new<C> (l);
00480 }
00481
00482 TMPL static void
00483 delete_matrix (C* M00, C* M01, C* M10, C* M11, nat l) {
00484 mmx_delete<C> (M00, l); mmx_delete<C> (M01, l);
00485 mmx_delete<C> (M10, l); mmx_delete<C> (M11, l);
00486 }
00487
00488 TMPL static void
00489 half_gcd (C* Q00, C* Q01, C* Q10, C* Q11,
00490 nat& nQ00, nat& nQ01, nat& nQ10, nat& nQ11,
00491 const C* r0, const C* r1, nat n0, nat n1, nat k,
00492 C* rho, C* tp) {
00493
00494
00495
00496
00497
00498 VERIFY (n0 >= n1, "bad input sizes");
00499 VERIFY (k <= n0, "index k out of range");
00500 if (n1 == 0 || k < n0 - n1 + 1) {
00501 Q00[0] = 1; nQ00 = 1; nQ01 = 0;
00502 Q11[0] = 1; nQ10 = 0; nQ11 = 1;
00503 return;
00504 }
00505 if (k == 1) {
00506 Q00[0] = 1; nQ00 = 1; nQ01 = 0;
00507 Q10[0] = 1; Q11[0] = - r0[n0-1] / r1[n1-1]; nQ10 = 1; nQ11 = 1;
00508 return;
00509 }
00510 nat h = k >> 1, h2 = (h << 1) - 1;
00511 nat len_R= aligned_size<C,V> (h2);
00512 nat nR00, nR01, nR10, nR11;
00513 C* R00, * R01, * R10, * R11;
00514 new_matrix (R00, R01, R10, R11, len_R);
00515 if (h2 < n0 - n1)
00516 half_gcd (R00, R01, R10, R11, nR00, nR01, nR10, nR11,
00517 r0, r1, n0, n1, h, rho, tp);
00518 else
00519 half_gcd (R00, R01, R10, R11, nR00, nR01, nR10, nR11,
00520 r0 + n0 - h2, r1 + n0 - h2,
00521 h2, h2 - (n0 - n1), h,
00522 rho == NULL ? rho : (rho + (n0 - h2)), tp);
00523 nat len_r= aligned_size<C,V> (n0 + h);
00524 C* rjm1= mmx_new<C> (len_r), * rj = mmx_new<C> (len_r);
00525 nat nj, njm1;
00526
00527 matrix_vector_product (rjm1, rj, njm1, nj,
00528 R00, R01, R10, R11, nR00, nR01, nR10, nR11,
00529 r0, r1, n0, n1, tp);
00530 if (nj == 0 || k < n0 - nj + 1) {
00531 Pol::copy (Q00, R00, nR00); nQ00= nR00;
00532 Pol::copy (Q01, R01, nR01); nQ01= nR01;
00533 Pol::copy (Q10, R10, nR10); nQ10= nR10;
00534 Pol::copy (Q11, R11, nR11); nQ11= nR11;
00535 return;
00536 }
00537 nat len_S= aligned_size<C,V> (max (k - (n0 - nj), njm1 - nj + 1));
00538 nat nS00, nS01, nS10, nS11;
00539 C* S00, * S01, * S10, * S11;
00540 new_matrix (S00, S01, S10, S11, len_S);
00541
00542 nat len_T= aligned_size<C,V> (njm1 - nj + h);
00543 nat nT00, nT01, nT10, nT11;
00544 C* T00, * T01, * T10, * T11;
00545 new_matrix (T00, T01, T10, T11, len_T);
00546
00547 S01[0]= 1; S10[0]= 1;
00548 nS00= 0; nS01= 1; nS10= 1; nS11= njm1 - nj + 1;
00549 Pol::quo_rem (S11, rjm1, rj, njm1, nj);
00550 Vec::neg (S11, nS11);
00551 Pol::trim (rj, nj); Pol::trim (rjm1, njm1);
00552
00553 if (njm1 != 0) {
00554 C c= 1 / rjm1[njm1-1];
00555 if (rho != NULL) rho[njm1-1] = rjm1[njm1-1];
00556 Pol::mul_sc (rjm1, c, njm1);
00557 Pol::mul_sc (S11, c, nS11);
00558 Pol::mul_sc (S10, c, nS10);
00559 }
00560 matrix_product (T00, T01, T10, T11, nT00, nT01, nT10, nT11,
00561 S00, S01, S10, S11, nS00, nS01, nS10, nS11,
00562 R00, R01, R10, R11, nR00, nR01, nR10, nR11, tp);
00563
00564 if (njm1 == 0 || k < n0 - njm1 + 1) {
00565 Pol::copy (Q00, T00, nT00); nQ00= nT00;
00566 Pol::copy (Q01, T01, nT01); nQ01= nT01;
00567 Pol::copy (Q10, T10, nT10); nQ10= nT10;
00568 Pol::copy (Q11, T11, nT11); nQ11= nT11;
00569 return;
00570 }
00571 h = k - (n0 - nj); h2 = (h << 1) - 1;
00572 if (h2 < nj - njm1 || h2 > nj)
00573 half_gcd (S00, S01, S10, S11, nS00, nS01, nS10, nS11,
00574 rj, rjm1, nj, njm1, h, rho, tp);
00575 else
00576 half_gcd (S00, S01, S10, S11, nS00, nS01, nS10, nS11,
00577 rj + nj - h2, rjm1 + nj - h2, h2, h2 - (nj - njm1), h,
00578 rho == NULL ? rho : (rho + (nj - h2)), tp);
00579 mmx_delete<C> (rjm1, len_r); mmx_delete<C> (rj, len_r);
00580 matrix_product (Q00, Q01, Q10, Q11, nQ00, nQ01, nQ10, nQ11,
00581 S00, S01, S10, S11, nS00, nS01, nS10, nS11,
00582 T00, T01, T10, T11, nT00, nT01, nT10, nT11, tp);
00583 delete_matrix (R00, R01, R10, R11, len_R);
00584 delete_matrix (S00, S01, S10, S11, len_S);
00585 delete_matrix (T00, T01, T10, T11, len_T);
00586 }
00587
00588 public:
00589
00590 TMPL static inline void
00591 euclidean_sequence (const C* s1, const C* s2, nat n1, nat n2,
00592 C* d1, C* d2, nat& m1 , nat& m2,
00593 C* u1, C* u2, nat& nu1, nat& nu2,
00594 C* v1, C* v2, nat& nv1, nat& nv2,
00595 nat* n, C* rho, C* q, C** r, C** co1, C** co2, nat k= 0) {
00596 Fallback::euclidean_sequence (s1, s2, n1, n2, d1, d2, m1, m2,
00597 u1, u2, nu1, nu2, v1, v2, nv1, nv2,
00598 n, rho, q, r, co1, co2, k);
00599 }
00600
00601 TMPL static void
00602 gcd (C* g, nat& n, const C* s1, const C* s2, nat n1, nat n2,
00603 C* uu1, C* uu2, nat& nuu1, nat& nuu2) {
00604
00605 if (n1 < Threshold(C,Th) || n2 < Threshold(C,Th)) {
00606 Fallback::gcd (g, n, s1, s2, n1, n2, uu1, uu2, nuu1, nuu2); return; }
00607 VERIFY (n1>0 && n2>0 && s1[n1-1] != 0 && s2[n2-1] != 0,
00608 "invalid hypothesis for gcd computation");
00609 nat nu1, nu2, nv1, nv2;
00610 C c1= 1 / s1[n1-1], c2= 1 / s2[n2-1];
00611 nat l1= aligned_size<C,V> (n1), l2= aligned_size<C,V> (n2);
00612 C* z1= mmx_new<C> (l1), * z2= mmx_new<C> (l2);
00613 Pol::mul_sc (z1, s1, c1, n1); Pol::mul_sc (z2, s2, c2, n2);
00614 C* u1= mmx_new<C> (l2), * u2= mmx_new<C> (l1);
00615 C* v1= mmx_new<C> (l2), * v2= mmx_new<C> (l1);
00616 nat len_tp= aligned_size<C,V> (n1 + n2);
00617 C* tp1= mmx_new<C> (len_tp), * tp2= mmx_new<C> (len_tp);
00618 C* rho= NULL;
00619 if (n1 >= n2)
00620 half_gcd (u1, u2, v1, v2, nu1, nu2, nv1, nv2,
00621 z1, z2, n1, n2, n1, rho, tp1);
00622 else
00623 half_gcd (u2, u1, v2, v1, nu2, nu1, nv2, nv1,
00624 z2, z1, n2, n1, n2, rho, tp1);
00625 VERIFY (nu1 <= n2 && nv1 <= n2 && nu2 <= n1 && nv2 <= n1, "bug");
00626 dot_product (tp2, n, u1, u2, nu1, nu2, z1, z2, n1, n2, tp1);
00627 VERIFY (n <= min (n1, n2), "bug");
00628 C c= n == 0 ? 0 : (1 / tp2[n-1]);
00629 Pol::mul_sc (g, tp2, c, n); Pol::clear (g + n, min (n1, n2) - n);
00630 if (uu1 != NULL) {
00631 Pol::mul_sc (uu1, u1, c1 * c, nu1);
00632 Pol::clear (uu1 + nu1, n2 - nu1); nuu1= nu1; }
00633 if (uu2 != NULL) {
00634 Pol::mul_sc (uu2, u2, c2 * c, nu2);
00635 Pol::clear (uu2 + nu2, n1 - nu2); nuu2= nu2; }
00636 mmx_delete<C> (tp1, len_tp); mmx_delete<C> (tp2, len_tp);
00637 mmx_delete<C> (u1, l2); mmx_delete<C> (u2, l1);
00638 mmx_delete<C> (v1, l2); mmx_delete<C> (v2, l1);
00639 }
00640
00641 TMPL static void
00642 gcd (C* g, nat& n, const C* s1, const C* s2, nat n1, nat n2) {
00643 nat nuu1, nuu2;
00644 C* uu1= NULL, * uu2= NULL;
00645 gcd (g, n, s1, s2, n1, n2, uu1, uu2, nuu1, nuu2);
00646 }
00647
00648 TMPL static void
00649 gcd (C* g, nat& n, const C* s1, const C* s2, nat n1, nat n2,
00650 C* uu1, nat& nuu1) {
00651 C* uu2= NULL; nat nuu2;
00652 gcd (g, n, s1, s2, n1, n2, uu1, uu2, nuu1, nuu2);
00653 }
00654
00655 TMPL static void
00656 reconstruct (C* r, C* t, const C* s, nat m, const C* p, nat n, nat k) {
00657
00658
00659
00660 if (n < Threshold(C,Th) || k == n || k == 0) {
00661 Fallback::reconstruct (r, t, s, m, p, n, k); return; }
00662 VERIFY (n > k && n > 0 && m > 0 && m <= n && s[m-1] != 0,
00663 "invalid hypothesis for gcd computation");
00664 nat nu1, nu2, nv1, nv2;
00665 nat n1= n+1, n2= m;
00666 nat l1= aligned_size<C,V> (n1), l2= aligned_size<C,V> (n2);
00667 C* u1= mmx_new<C> (l2), * u2= mmx_new<C> (l1);
00668 C* v1= mmx_new<C> (l2), * v2= mmx_new<C> (l1);
00669 C* s1= mmx_new<C> (l1); Pol::copy (s1, p, n1); const C* s2= s;
00670 nat len_tp= aligned_size<C,V> (n1 + n2);
00671 C* tp= mmx_new<C> (len_tp);
00672 C* rho= NULL;
00673 half_gcd (u1, u2, v1, v2, nu1, nu2, nv1, nv2,
00674 s1, s2, n1, n2, n-k, rho, tp);
00675
00676 nat len_r= aligned_size<C,V> (2*n1-k);
00677 C* rjm1= mmx_new<C> (len_r), * rj = mmx_new<C> (len_r);
00678 nat nj, njm1;
00679 matrix_vector_product (rjm1, rj, njm1, nj,
00680 u1, u2, v1, v2, nu1, nu2, nv1, nv2,
00681 s1, s2, n1, n2, tp);
00682 if (nj <= k) {
00683 VERIFY (nv2 <= n-k+1, "bug");
00684 Pol::copy (r, rj, nj); Pol::clear (r + nj, k - nj);
00685 Pol::copy (t, v2, nv2); Pol::clear (t + nv2, n-k+1 - nv2);
00686 }
00687 else {
00688 VERIFY (nj > k, "bug");
00689 nat nq= njm1 - nj + 1, len_q= aligned_size<C,V> (nq);
00690 C* q= mmx_new<C> (len_q);
00691 Pol::quo_rem (q, rjm1, rj, njm1, nj);
00692 Pol::trim (rjm1, njm1);
00693 Pol::mul (tp, q, v2, nq, nv2);
00694 Pol::clear (u2 + nu2, n1 - nu2); nu2= n1;
00695 Pol::sub (u2, tp, nq + nv2 - 1);
00696 Pol::trim (u2, nu2);
00697 VERIFY (njm1 <= k, "bug");
00698 VERIFY (nu2 <= n-k+1, "bug");
00699 Pol::copy (r, rjm1, njm1); Pol::clear (r + njm1, k - njm1);
00700 Pol::copy (t, u2, nu2); Pol::clear (t + nu2, n-k+1 - nu2);
00701 mmx_delete<C> (q, len_q);
00702 }
00703 mmx_delete<C> (rjm1, len_r); mmx_delete<C> (rj, len_r);
00704 mmx_delete<C> (s1, l1); mmx_delete<C> (tp, len_tp);
00705 mmx_delete<C> (u1, l2); mmx_delete<C> (u2, l1);
00706 mmx_delete<C> (v1, l2); mmx_delete<C> (v2, l1);
00707 }
00708
00709 };
00710
00711
00712
00713
00714
// Empty tag type, instantiated below as the `Th` typedef of the evaluation
// implementation.
template<typename V> struct polynomial_evaluate_threshold {};
00717
00718 template<typename V, typename BV>
00719 struct implementation<polynomial_evaluate,V,polynomial_dicho<BV> >:
00720 public implementation<polynomial_divide,V>
00721 {
00722 typedef polynomial_evaluate_threshold<polynomial_dicho<BV> > Th;
00723 typedef implementation<vector_linear,V> Vec;
00724 typedef implementation<polynomial_divide,V> Pol;
00725 typedef implementation<polynomial_evaluate,V,BV> Fallback;
00726
00727 TMPL static inline void
00728 factorials (C* dest, nat n) {
00729 if (n > 0) dest[0]= C(1);
00730 for (nat i=1; i<n; i++)
00731 dest[i]= C(i) * dest[i-1];
00732 }
00733
00734 TMPL static void
00735 shift (C* dest, const C* s, const C& sh, nat n) {
00736 if (n <= 1 || sh == 0) Pol::copy (dest, s, n);
00737 else {
00738 nat l= aligned_size<C,V> (5 * n);
00739 C* u = mmx_new<C> (l);
00740 C* v = u + n;
00741 C* w = v + n;
00742 C* facts= w + n + n;
00743 factorials (facts, n);
00744 Vec::mul (u, s, facts, n);
00745 Vec::vec_reverse (u, n);
00746 Vec::set (v, 1, n);
00747 if (sh != 0) Pol::q_difference (v, v, sh, n);
00748 Vec::div (v, facts, n);
00749 Pol::mul (w, u, v, n, n);
00750 Vec::vec_reverse (w, n);
00751 Vec::div (dest, w, facts, n);
00752 mmx_delete<C> (u, l);
00753 }
00754 }
00755
00756 TMPL static inline C
00757 evaluate (const C* p, const C& x, nat l) {
00758 return Fallback::evaluate (p, x, l);
00759 }
00760
00761 TMPL static void
00762 q_binomial (C* dest, const C& q, nat mu) {
00763 dest[mu]= 1;
00764 for (nat i=1; i<=mu; i++)
00765 dest[mu-i]= (C (mu+1-i) * dest[mu+1-i] * q) / C(i);
00766 }
00767
00768 TMPL static void
00769 expand (C** v, const C* p, const C* x, const nat* nu, nat n, nat k) {
00770 nat tot= 0;
00771 for (nat i=0; i<k; i++)
00772 tot += nu[i] + 1;
00773 nat* d= mmx_new<nat> (k);
00774 nat l= aligned_size<C,V> (tot << 1);
00775 C* q= mmx_new<C> (l);
00776 C* r= q + tot;
00777 nat off= 0;
00778 for (nat i=0; i<k; i++) {
00779 q_binomial (q + off, -x[i], nu[i]);
00780 d[i]= nu[i] + 1;
00781 off += d[i];
00782 }
00783 multi_mod (r, p, q, d, n, k);
00784 off= 0;
00785 for (nat i=0; i<k; i++) {
00786 shift (v[i], r + off, x[i], nu[i]);
00787 off += d[i];
00788 }
00789 mmx_delete<C> (q, l);
00790 mmx_delete<nat> (d, k);
00791 }
00792
00793 #define C typename scalar_type_helper<Polynomial >::val
00794
00795 struct _vector_sort_by_increasing_degree_op {
00796 TMPLP static bool
00797 op (const Polynomial& p, const Polynomial& q) {
00798 return deg (p) < deg (q); }
00799 TMPLP static bool
00800 not_op (const Polynomial& p, const Polynomial& q) {
00801 return deg (p) >= deg (q); }
00802 };
00803
00804 template<typename Op, typename Polynomial> static inline vector<Polynomial>
00805 _multi_rem (const Polynomial& p, const vector<Polynomial>& q) {
00806 if (p == 0) return vector<Polynomial> (Polynomial(C(0)), N(q));
00807 nat n= degree (p);
00808 vector<Polynomial> sorted_q (q), r (Polynomial(C(0)), N(q));
00809 vector<nat> sigma;
00810 sort_leq<_vector_sort_by_increasing_degree_op> (sorted_q, sigma);
00811 nat start= 0, sum= 0;
00812 for (nat i= 0; i < N(q); i++) {
00813 sum += degree (sorted_q[i]);
00814 if (sum > n / 2 || i+1 == N(q)) {
00815 Crt_polynomial_transformer(Polynomial)
00816 crter (range (sorted_q, start, i+1));
00817 vector<Polynomial> tmp; direct_crt (tmp, p, crter);
00818 for (nat j= start; j < i+1; j++) r[sigma[j]]= tmp[j-start];
00819 start= i+1; sum= 0;
00820 }
00821 }
00822 return r; }
00823
00824 TMPLP static inline vector<Polynomial>
00825 multi_rem (const Polynomial& p, const vector<Polynomial>& q) {
00826 return _multi_rem<rem_op> (p, q); }
00827
00828 TMPLP static inline vector<Polynomial>
00829 multi_prem (const Polynomial& p, const vector<Polynomial>& q) {
00830 return _multi_rem<prem_op> (p, q); }
00831
00832 TMPLP static vector<Polynomial>
00833 multi_gcd (const Polynomial& P, const vector<Polynomial>& Q) {
00834 return binary_map<gcd_op> (rem (P, Q), Q); }
00835
00836 TMPLP static Polynomial
00837 annulator (const Vector& x) {
00838 ASSERT (is_non_scalar (x), "non-scalar xector expected");
00839 if (N(x) == 0) return 1;
00840 vector<Polynomial> q (Polynomial (), N(x));
00841 Polynomial z (C(1), 1);
00842 for (nat i= 0; i < N(x); i++) q[i]= z - x[i];
00843 Crt_polynomial_transformer(Polynomial) crter (q);
00844 return * moduli_product (crter); }
00845
00846 TMPLP static inline Vector
00847 evaluate (const Polynomial& p, const Vector& x) {
00848 ASSERT (is_non_scalar (x), "non-scalar vector expected");
00849 if (N(x) == 0) return Vector (C(0), 0);
00850 vector<Polynomial> q (Polynomial (), N(x));
00851 Polynomial z (C(1), 1);
00852 for (nat i= 0; i < N(x); i++) q[i]= z - x[i];
00853 vector<Polynomial> tmp= multi_rem (p, q);
00854 Vector r (C(0), N(x));
00855 for (nat i= 0; i < N(x); i++) r[i]= tmp[i][0];
00856 return r;
00857 }
00858
00859 TMPLP static inline Polynomial
00860 tevaluate (const Vector& v, const Vector& x, nat l) {
00861 ASSERT (is_non_scalar (x), "non-scalar vector expected");
00862 if (l == 0) return Polynomial (0);
00863 nat n= N(x), ll= aligned_size<C,V> (l);
00864 vector<Polynomial> q (Polynomial (), n);
00865 Polynomial z (C(1), 1);
00866 for (nat i= 0; i < N(v); i++) q[i]= Polynomial (1) - x[i] * z;
00867 Crt_polynomial_transformer(Polynomial) crter (q);
00868 Polynomial num (combine_crt (v, crter)), den (* moduli_product (crter));
00869
00870 C* tmp= mmx_new<C> (ll), * inv= mmx_new<C> (ll);
00871 Pol::clear (tmp, l); Pol::copy (tmp, seg(den), min (N(den), l));
00872 Pol::invert_lo (inv, tmp, l);
00873 Polynomial b (inv, l, ll), c (num * b);
00874 for (nat i= 0; i < l; i++) tmp[i]= c[i];
00875 return Polynomial (tmp, l, ll);
00876 }
00877
00878 TMPLP static Polynomial
00879 interpolate (const Vector& v, const Vector& x) {
00880 ASSERT (is_non_scalar (x), "non-scalar vector expected");
00881 ASSERT (N(v) == N(x), "dimensions don't match");
00882 if (N(x) == 0) return Polynomial (0);
00883 vector<Polynomial> q (Polynomial (), N(x));
00884 Polynomial z (C(1), 1);
00885 for (nat i= 0; i < N(x); i++) q[i]= z - x[i];
00886 Crt_polynomial_transformer(Polynomial) crter (q);
00887 Polynomial ans; inverse_crt (ans, as<vector<Polynomial> > (v), crter);
00888 return ans;
00889 }
00890
00891 TMPLP static Vector
00892 tinterpolate (const Polynomial& p, const Vector& x) {
00893 ASSERT (is_non_scalar (x), "non-scalar vector expected");
00894 ASSERT (N(p) <= N(x), "dimensions don't match");
00895 nat n= N(x);
00896 if (n == 0) return Vector (C(0), 0);
00897 vector<Polynomial> q (Polynomial (), n);
00898 Polynomial z (C(1), 1);
00899 for (nat i= 0; i < n; i++) q[i]= z - x[i];
00900 Crt_polynomial_transformer(Polynomial) crter (q);
00901 Polynomial rp (lshiftz (reverse (p), (int) (n - N(p))));
00902 Polynomial den (* moduli_product (crter));
00903 vector<Polynomial> a; direct_crt (a, range (den * rp, n, 2*n), crter);
00904 vector<Polynomial> b; direct_crt (b, derive (den), crter);
00905 Vector v (C(0), n); for (nat i= 0; i < n; i++) v[i]= a[i][0] / b[i][0];
00906 return v;
00907 }
00908
00909 #undef C
00910 };
00911
00912 #undef TMPL
00913 #undef TMPLP
00914 #undef Vector
00915 }
00916 #endif //__MMX__POLYNOMIAL_DICHO__HPP