00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
// Implicit solving of systems of equations for vector power series.
// Not-yet-computed coefficients of the solution are manipulated through
// "unknown" coefficients (affine combinations b + sum s[i] * x_i of the
// unknowns x_i) and are determined incrementally by solver_series_rep::next.
00013 #ifndef __MMX_SERIES_IMPLICIT_HPP
00014 #define __MMX_SERIES_IMPLICIT_HPP
00015 #include <algebramix/series_vector.hpp>
00016 namespace mmx {
// Shorthand macros local to this header (see the #undef list at its end).
00017 #define TMPL_DEF template<typename C, typename V=typename Series_variant(C)>
00018 #define TMPL template<typename C, typename V>
00019 #define Series series<C,V>
00020 #define Series_rep series_rep<C,V>
00021 #define VC vector<C>
00022 #define VSeries series<vector<C>,V>
00023 #define VSeries_rep series_rep<vector<C>,V>
// UC: a coefficient that is affine in the unknowns of a solver.
00024 #define UC unknown<C,V>
00025 #define UC_rep unknown_rep<C,V>
// Series of UC coefficients; note that V is not passed, so these use the
// default series variant for UC.
00026 #define USeries series<UC >
00027 #define USeries_rep series_rep<UC >
00028 #define Solver_rep solver_series_rep<C,V>
// Forward declaration: unknown_rep stores a pointer to its solver.
00029 TMPL class solver_series_rep;
00030
00031
00032
00033
00034
00035 TMPL_DEF
00036 class unknown_rep REP_STRUCT {
00037 public:
00038 Solver_rep* f;
00039 C b;
00040 C* s;
00041 nat l;
00042 nat i1;
00043 nat i2;
00044
00045 inline void normalize () {
00046 nat k=i1;
00047 while (i1 < i2 && is_exact_zero (s[i2-k-1])) i2--;
00048 while (i1 < i2 && is_exact_zero (s[i1-k])) i1++;
00049 if (i1 != k) {
00050
00051 nat d= i1-k;
00052 C* s2= mmx_new<C> (i2-i1);
00053 for (nat i=0; i<i2-i1; i++) s2[i]= s[i+d];
00054 mmx_delete<C> (s, l);
00055 s= s2;
00056 l= i2-i1;
00057 }
00058 }
00059
00060 inline unknown_rep ():
00061 f (NULL), b (), s (mmx_new<C> (0)), l (0), i1 (0), i2 (0) {}
00062 inline unknown_rep (const C& b2):
00063 f (NULL), b (b2), s (mmx_new<C> (0)), l (0), i1 (0), i2 (0) {}
00064 unknown_rep (Solver_rep* f2, const C& b2, C* ss, nat s2, nat e2):
00065 f (f2), b (b2), s (ss), l (e2 - s2), i1 (s2), i2 (e2) { normalize (); }
00066 inline virtual ~unknown_rep () {
00067 mmx_delete<C> (s, l); }
00068 };
00069
00070 TMPL_DEF
00071 class unknown {
00072 INDIRECT_PROTO_2 (unknown, unknown_rep, C, V)
00073 public:
00074 inline unknown ():
00075 rep (new UC_rep ()) {}
00076 template<typename T> inline unknown (const T& b):
00077 rep (new UC_rep (as<C> (b))) {}
00078 inline unknown (Solver_rep* f, const C& b, C* s, nat i1, nat i2):
00079 rep (new UC_rep (f, b, s, i1, i2)) {}
00080 };
00081 INDIRECT_IMPL_2 (unknown, unknown_rep, typename C, C, typename V, V)
00082
00083 TMPL inline bool
00084 is_known (const UC& c) {
00085 return c->i2 == c->i1;
00086 }
00087
00088 TMPL inline C
00089 known (const UC& c) {
00090 ASSERT (c->i2 == c->i1, "cast failed");
00091 return c->b;
00092 }
00093
00094
00095
00096
00097
// Abstract solver for a vector power series with m components per
// coefficient.  The first coefficients are given at construction; later
// ones are computed by next() from the equation series eqs, whose
// coefficients must all vanish.  Subclasses implement initialize().
00098 TMPL_DEF
00099 class solver_series_rep: public VSeries_rep {
00100 public:
// Number of components of each coefficient vector.
00101 nat m;
// Equation series in the unknowns; installed from initialize() by
// solver_container_series_rep.  Each of their coefficients must vanish.
00102 vector<USeries > eqs;
// Index of the next coefficient of the equations to feed into sys.
00103 nat cur_n;
// Pending reduced linear equations in the unknowns (see insert_and_reduce).
00104 vector<UC> sys;
00105
00106 public:
// Solver with m components whose first N(init) coefficients are init.
00107 inline solver_series_rep (nat m, const vector<vector<C> >& init);
// Also raises the order of every equation series.
00108 virtual void Increase_order (nat l);
// Display name of component i, used when printing unknowns.
00109 virtual syntactic name_component (nat i);
// Must return the equation series.
// NOTE(review): presumably built from me(); confirm in subclasses.
00110 virtual vector<USeries > initialize () = 0;
// The m series of unknowns, one per component.
00111 vector<USeries > me ();
// Solves for the next coefficient vector.
00112 VC next ();
00113 };
00114
00115
00116
00117
00118
00119 TMPL bool
00120 is_exact_zero (const UC c) {
00121 return c->i1 == c->i2 && is_exact_zero (c->b);
00122 }
00123
00124 template<typename Op, typename C, typename V> nat
00125 unary_hash (const UC& c) {
00126 register nat i, h= 78460;
00127 if (c->i1 == c->i2) return Op::op (c->b) ^ h;
00128 h += (c->i1 << 3) ^ Op::op (c->b);
00129 for (i=0; i<c->i2 - c->i1; i++)
00130 h= (h<<1) ^ (h<<5) ^ (h>>27) ^ Op::op (c->s[i]);
00131 return h;
00132 }
00133
00134 template<typename Op,typename C,typename V> inline bool
00135 binary_test (const UC& c1, const UC& c2) {
00136 if (Op::not_op (c1->b, c2->b)) return false;
00137 if (c1->i1 == c1->i2 || c2->i1 == c2->i2)
00138 return c1->i1 == c1->i2 && c2->i1 == c2->i2;
00139 if (Op::not_op (c1->f, c2->f)) return false;
00140 if (Op::not_op (c1->i1, c2->i1)) return false;
00141 if (Op::not_op (c1->i2, c2->i2)) return false;
00142 for (nat i= c1->i1; i<c1->i2; i++)
00143 if (Op::not_op (c1->s[i - c1->i1], c2->s[i - c1->i1]))
00144 return false;
00145 return true;
00146 }
00147
00148 TMPL syntactic
00149 flatten (const UC& c) {
00150 syntactic sum= flatten (c->b);
00151 for (nat i=c->i1; i<c->i2; i++) {
00152 nat k= i / c->f->m;
00153 nat j= i % c->f->m;
00154 sum= sum + flatten (c->s[i-c->i1]) *
00155 access (c->f->name_component (j), flatten (k));
00156 }
00157 return sum;
00158 }
00159
// Generic comparison/identity helpers for UC.
// NOTE(review): these macros are defined elsewhere (basix); presumably they
// build hard/exact equality and hashing on binary_test and unary_hash above.
00160 TRUE_IDENTITY_OP_SUGAR(TMPL,UC)
00161 EXACT_IDENTITY_OP_SUGAR(TMPL,UC)
00162
00163
00164
00165
00166
00167 TMPL UC
00168 operator - (const UC& c) {
00169
00170 if (is_exact_zero (c)) return c;
00171 nat n= c->i2 - c->i1;
00172 C* s= mmx_new<C> (n);
00173 for (nat i=0; i<n; i++)
00174 s[i]= -c->s[i];
00175 return UC (c->f, -c->b, s, c->i1, c->i2);
00176 }
00177
00178 TMPL UC
00179 operator + (const UC& c1, const C& c2) {
00180
00181 if (is_exact_zero (c1))
00182 return UC (c1->f, c2, mmx_new<C> (0), c1->i1, c1->i1);
00183 if (is_exact_zero (c2)) return c1;
00184 nat n= c1->i2 - c1->i1;
00185 C* s= mmx_new<C> (n);
00186 for (nat i=0; i<n; i++)
00187 s[i]= c1->s[i];
00188 return UC (c1->f, c1->b + c2, s, c1->i1, c1->i2);
00189 }
00190
00191 TMPL UC
00192 operator * (const UC& c1, const C& c2) {
00193
00194 if (is_exact_zero (c1)) return c1;
00195 if (is_exact_zero (c2))
00196 return UC (c1->f, C (0), mmx_new<C> (0), c1->i1, c1->i1);
00197 nat n= c1->i2 - c1->i1;
00198 C* s= mmx_new<C> (n);
00199 for (nat i=0; i<n; i++)
00200 s[i]= c1->s[i] * c2;
00201 return UC (c1->f, c1->b * c2, s, c1->i1, c1->i2);
00202 }
00203
00204 TMPL UC
00205 operator * (const C& c1, const UC& c2) {
00206
00207 if (is_exact_zero (c2)) return c2;
00208 if (is_exact_zero (c1))
00209 return UC (c2->f, C (0), mmx_new<C> (0), c2->i1, c2->i1);
00210 nat n= c2->i2 - c2->i1;
00211 C* s= mmx_new<C> (n);
00212 for (nat i=0; i<n; i++)
00213 s[i]= c1 * c2->s[i];
00214 return UC (c2->f, c1 * c2->b, s, c2->i1, c2->i2);
00215 }
00216
00217 TMPL UC
00218 operator + (const UC& c1, const UC& c2) {
00219
00220 if (c1->i1 == c1->i2) return c2 + known (c1);
00221 if (c2->i1 == c2->i2) return c1 + known (c2);
00222 ASSERT (c1->f == c2->f, "incompatible unknown coefficients");
00223 nat i1= min (c1->i1, c2->i1);
00224 nat i2= max (c1->i2, c2->i2);
00225 C* s= mmx_new<C> (i2-i1);
00226 for (nat i= i1; i<i2; i++)
00227 s[i-i1]=
00228 (i >= c1->i1 && i < c1->i2? c1->s[i - c1->i1]: C(0)) +
00229 (i >= c2->i1 && i < c2->i2? c2->s[i - c2->i1]: C(0));
00230 return UC (c1->f, c1->b + c2->b, s, i1, i2);
00231 }
00232
00233 TMPL UC
00234 operator - (const UC& c1, const UC& c2) {
00235
00236 if (c1->i1 == c1->i2) return (-c2) + known (c1);
00237 if (c2->i1 == c2->i2) return c1 + (-known (c2));
00238 ASSERT (c1->f == c2->f, "incompatible unknown coefficients");
00239 nat i1= min (c1->i1, c2->i1);
00240 nat i2= max (c1->i2, c2->i2);
00241 C* s= mmx_new<C> (i2-i1);
00242 for (nat i=i1; i<i2; i++)
00243 s[i-i1]=
00244 (i >= c1->i1 && i < c1->i2? c1->s[i - c1->i1]: C(0)) -
00245 (i >= c2->i1 && i < c2->i2? c2->s[i - c2->i1]: C(0));
00246 return UC (c1->f, c1->b - c2->b, s, i1, i2);
00247 }
00248
00249 TMPL UC
00250 substitute (const UC& c) {
00251
00252 if (c->i1 == c->i2 || c->i1 >= c->f->n * c->f->m) return c;
00253 nat i1= min (c->f->n * c->f->m, c->i2);
00254 nat d= i1 - c->i1;
00255 nat n= c->i2 - i1;
00256 C* s= mmx_new<C> (n);
00257 C b= c->b;
00258 for (nat i=0; i<d; i++) {
00259 nat k= (i + c->i1) / c->f->m;
00260 nat j= (i + c->i1) % c->f->m;
00261 b += c->s[i] * c->f->a[k][j];
00262 }
00263 for (nat i=0; i<n; i++)
00264 s[i]= c->s[i + d];
00265 return UC (c->f, b, s, i1, c->i2);
00266 }
00267
00268 TMPL UC
00269 operator * (const UC& c1, const UC& c2) {
00270
00271 if (is_exact_zero (c1)) return c1;
00272 if (is_exact_zero (c2)) return c2;
00273 if (c1->i1 == c1->i2) return known (c1) * c2;
00274 if (c2->i1 == c2->i2) return c1 * known (c2);
00275 UC c1b= substitute (c1);
00276 UC c2b= substitute (c2);
00277 if (c1b->i1 == c1b->i2) return known (c1b) * c2b;
00278 if (c2b->i1 == c2b->i2) return c1b * known (c2b);
00279 ERROR ("invalid product of unknown coefficients");
00280 }
00281
00282
00283
00284
00285
00286 TMPL void
00287 reduce (UC& c1, UC& c2) {
00288
00289
00290
00291 if (is_exact_zero (c1)) return;
00292 if (is_exact_zero (c2)) { swap (c1, c2); return; }
00293 ASSERT (c1->f == c2->f, "incompatible unknown coefficients");
00294 if (c1->i2 < c2->i2) return;
00295 if (c2->i2 < c1->i2) { swap (c1, c2); return; }
00296 if (better_pivot (c1->s[c1->i2 - 1 - c1->i1],
00297 c2->s[c2->i2 - 1 - c2->i1])) swap (c1, c2);
00298 C lambda= c1->s[c1->i2 - 1 - c1->i1] / c2->s[c2->i2 - 1 - c2->i1];
00299 nat i1= min (c1->i1, c2->i1);
00300 nat i2= c1->i2;
00301 C* s= mmx_new<C> (i2-i1-1);
00302 for (nat i= i1; i<i2-1; i++)
00303 s[i-i1]=
00304 (i >= c1->i1? c1->s[i - c1->i1]: C(0)) -
00305 lambda * (i >= c2->i1? c2->s[i - c2->i1]: C(0));
00306 c1= UC (c1->f, c1->b - lambda * c2->b, s, i1, i2-1);
00307 }
00308
00309 TMPL void
00310 insert_and_reduce (vector<UC >& sys, UC& c) {
00311 for (nat i=0; i<N(sys); i++)
00312 reduce (c, sys[i]);
00313 if (is_exact_zero (c)) return;
00314 ASSERT (c->i1 != c->i2, "contradictory equations");
00315 sys << c;
00316 }
00317
00318
00319
00320
00321
00322 TMPL_DEF
00323 class unknown_series_rep: public USeries_rep {
00324 Solver_rep* f;
00325 nat k;
00326 public:
00327 unknown_series_rep (Solver_rep* f2, nat k2):
00328 USeries_rep (format<UC > ()), f (f2), k (k2)
00329 {
00330 this->n= f->n;
00331 this->Set_order (this->n);
00332 for (nat i=0; i<f->n; i++) {
00333 C* s= mmx_new<C> (0);
00334 this->a[i]= UC (f, f->a[i][k], s, i, i);
00335 }
00336 }
00337 syntactic expression (const syntactic& z) const {
00338 return syn (f->name_component (k), z); }
00339 UC next () {
00340 C* s= mmx_new<C> (1);
00341 s[0]= C (1);
00342 return UC (f, C(0), s, this->n * f->m + k, this->n * f->m + k + 1); }
00343 };
00344
00345 TMPL USeries
00346 unknown_series (Solver_rep* f, nat k) {
00347 return (USeries_rep*) new unknown_series_rep<C> (f, k);
00348 }
00349
00350 template<typename C, typename V, typename UV>
00351 class known_series_rep: public Series_rep {
00352 public:
00353 series<UC,UV> f;
00354 inline known_series_rep (const series<UC,UV>& f2):
00355 Series_rep (format<C> ()), f (f2) {}
00356 inline syntactic expression (const syntactic& z) const {
00357 return syn ("known", flatten (f, z)); }
00358 inline void Increase_order (nat l) {
00359 Series_rep::Increase_order (l);
00360 increase_order (f, l); }
00361 inline C next () { return known (substitute (f[this->n])); }
00362 };
00363
00364 template<typename C, typename V, typename UV> Series
00365 known (const series<UC,UV>& f) {
00366 return (Series_rep*) new known_series_rep<C,V,UV> (f);
00367 }
00368
00369
00370
00371
00372
00373 template<typename C, typename V, typename UV>
00374 class subst_mul_series_rep:
00375 public implementation<series_abstractions,UV>
00376 ::template binary_series_rep<mul_op,UC,UV> {
00377 public:
00378 nat f_sh, g_sh;
00379 Series f_kn, g_kn, inner;
00380 inline subst_mul_series_rep (const USeries& f, const USeries& g):
00381 implementation<series_abstractions,UV>
00382 ::template binary_series_rep<mul_op,UC,UV> (f, g),
00383 f_sh (1), g_sh (1), f_kn (known (f)), g_kn (known (g)),
00384 inner (rshiftz (f_kn, (int) f_sh) * rshiftz (g_kn, (int) g_sh)) {}
00385 UC next () {
00386 if (this->n >= g_sh &&
00387 !is_known (substitute (this->f [this->n - g_sh]))) {
00388 g_sh = g_sh << 1;
00389 inner= rshiftz (f_kn, (int) f_sh) * rshiftz (g_kn, (int) g_sh);
00390 }
00391 if (this->n >= f_sh &&
00392 !is_known (substitute (this->g [this->n - f_sh]))) {
00393 f_sh = f_sh << 1;
00394 inner= rshiftz (f_kn, (int) f_sh) * rshiftz (g_kn, (int) g_sh);
00395 }
00396 if (this->n < f_sh + g_sh) {
00397 UC acc= this->f[0] * this->g[this->n];
00398 for (nat i=1; i<=this->n; i++)
00399 acc += this->f[i] * this->g[this->n-i];
00400 return acc;
00401 }
00402 else {
00403 UC acc= f_kn[0] * this->g[this->n] + this->f[this->n] * g_kn[0];
00404 for (nat i=1; i<f_sh; i++)
00405 acc += f_kn[i] * this->g[this->n-i];
00406 for (nat i=1; i<g_sh; i++)
00407 acc += this->f[this->n-i] * g_kn[i];
00408 return acc + inner[this->n - f_sh - g_sh];
00409 }
00410 }
00411 };
00412
00413 template<typename C, typename V, typename UV> series<UC,UV>
00414 operator * (const series<UC,UV>& f, const series<UC,UV>& g) {
00415 if (is_exact_zero (f) || is_exact_zero (g))
00416 return series<UC,UV> (CF(f));
00417 return (series_rep<UC,UV>*) new subst_mul_series_rep<C,V,UV> (f, g);
00418 }
00419
00420
00421
00422
00423
00424 TMPL inline
00425 Solver_rep::solver_series_rep (nat m2, const vector<vector<C> >& init):
00426 VSeries_rep (format<VC > ()), m (m2), cur_n (0)
00427 {
00428 this->n= N(init);
00429 this->Set_order (this->n);
00430 for (nat i=0; i<this->n; i++)
00431 this->a[i]= init[i];
00432 }
00433
00434 TMPL void
00435 Solver_rep::Increase_order (nat l) {
00436 VSeries_rep::Increase_order (l);
00437 for (nat i=0; i<N(eqs); i++)
00438 increase_order (eqs[i], l);
00439 }
00440
00441 TMPL syntactic
00442 Solver_rep::name_component (nat i) {
00443 if (m == 1) return syntactic ("f");
00444 else return access (syntactic ("f"), flatten (i+1));
00445 }
00446
00447 TMPL vector<USeries >
00448 Solver_rep::me () {
00449 vector<USeries > r= fill<USeries > (m);
00450 for (nat i=0; i<m; i++)
00451 r[i]= unknown_series (this, i);
00452 return r;
00453 }
00454
// Computes coefficient this->n of the solution.  Coefficients of the
// equation series are fed into the triangular system sys until all m
// components of the new coefficient vector have been determined.
// Unknown index this->n * m + t stands for component t (i.e. ret[t]).
00455 TMPL VC
00456 Solver_rep::next () {
00457
// Fold the coefficients computed so far into the pending equations.
00458 for (nat i=0; i<N(sys); i++)
00459 sys[i]= substitute (sys[i]);
// ret collects the m components; done counts how many are solved.
00460 VC ret = fill<C> (m);
00461 nat done= 0;
00462 while (true) {
// Give up once the equations have been consumed 100 terms beyond n.
00463 ASSERT (cur_n < this->n + 100, "too large delay in implicit solve");
00464
// Add coefficient cur_n of every equation to the system.
00465 for (nat j=0; j<N(eqs); j++) {
00466 UC c= eqs[j][cur_n];
00467
00468 insert_and_reduce (sys, c);
00469 }
00470
00471 cur_n++;
// Pending equations may only involve unknowns of coefficient n or later.
00472 for (nat i=0; i<N(sys); i++)
00473 ASSERT (sys[i]->f == ((Solver_rep*) this) &&
00474 sys[i]->i1 >= this->n * m && sys[i]->i2 > this->n * m,
00475 "invalid situation during implicit solving");
// Consume equations from the end of sys (presumably the smallest leading
// index i2, given how reduce orders the system).
00476 while (N(sys) > 0 && done < m) {
00477 UC c= sys[N(sys)-1];
// Usable only if its highest unknown is at most ret[done].
00478 if (c->i2 <= this->n * m + done + 1) {
00479 nat j1 = c->i1 - this->n * m;
00480 nat j2 = min (done, c->i2 - this->n * m);
// rhs: constant term plus contributions of already solved components.
00481 C rhs= c->b;
00482 for (nat j=j1; j<j2; j++)
00483 rhs += c->s[j-j1] * ret[j];
// All of its unknowns are already solved: it must hold identically.
00484 if (c->i2 <= this->n * m + done) {
00485 ASSERT (is_exact_zero (rhs), "contradictory equations"); }
// Otherwise it is linear in ret[done]; its coefficient is the last entry
// of s, which normalize keeps non-zero.
00486 else {
00487 ret[done]= -rhs / c->s[done-j1];
00488
00489 done++;
00490 }
// Drop the consumed equation (secure() presumably unshares sys's
// storage before resizing it in place).
00491 sys.secure ();
00492 inside (sys) -> resize (N(sys) - 1);
00493 }
00494 else break;
00495 }
00496 if (done == m) return ret;
00497 }
00498 }
00499
00500 TMPL
00501 class solver_container_series_rep: public VSeries_rep {
00502 VSeries f;
00503 public:
00504 solver_container_series_rep (const VSeries& f2):
00505 VSeries_rep (CF(f2)), f (f2) {
00506 Solver_rep* rep= (Solver_rep*) f.operator -> ();
00507 rep->eqs= rep->initialize (); }
00508 ~solver_container_series_rep () {
00509 Solver_rep* rep= (Solver_rep*) f.operator -> ();
00510 rep->eqs= vector<USeries > (); }
00511 syntactic expression (const syntactic& z) const {
00512 return flatten (f, z); }
00513 virtual void Increase_order (nat l) {
00514 VSeries_rep::Increase_order (l);
00515 increase_order (f, l); }
00516 VC next () { return f[this->n]; }
00517 };
00518
00519 TMPL VSeries
00520 solver (const VSeries& f) {
00521 return (VSeries_rep*) new solver_container_series_rep<C,V> (f);
00522 }
00523
00524 #undef TMPL_DEF
00525 #undef TMPL
00526 #undef Series
00527 #undef Series_rep
00528 #undef VC
00529 #undef VSeries
00530 #undef VSeries_rep
00531 #undef UC
00532 #undef UC_rep
00533 #undef Iseries
00534 #undef Iseries_rep
00535 #undef Solver_rep
00536 }
00537 #endif // __MMX_SERIES_IMPLICIT_HPP