00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX__CRT_NAIVE__HPP
00014 #define __MMX__CRT_NAIVE__HPP
00015 #include <basix/vector.hpp>
00016 #include <numerix/modulus.hpp>
00017
00018 namespace mmx {
00019
00020
00021
00022
00023
00024 struct coprime_moduli_sequence_naive {
00025 template<typename M> static bool
00026 extend (vector<M>& v, nat i) { return false; }
00027 };
00028
00029 template<typename M>
00030 struct coprime_moduli_helper {
00031 typedef coprime_moduli_sequence_naive sequence;
00032 };
00033
00034 #define Coprime_moduli_variant(M) coprime_moduli_helper<M>::sequence
00035
00036 template<typename M, typename V= typename Coprime_moduli_variant(M)>
00037 struct coprime_moduli_sequence {
00038 vector<M> v;
00039
00040 coprime_moduli_sequence () : v () {}
00041 coprime_moduli_sequence (const coprime_moduli_sequence& x) : v (x.v) {}
00042
00043 M operator [] (nat i) {
00044 if (i >= N(v)) {
00045 bool b= V::extend (v, i+1);
00046 if (!b) return M(0);
00047 }
00048 return v[i]; }
00049 };
00050
00051 template<typename M, typename V> vector<M>
00052 range (coprime_moduli_sequence<M, V>& seq, nat beg, nat end) {
00053 if (end <= beg) return vector<M> (M (), 0);
00054 vector<M> v (M (), end - beg);
00055 for (nat i= beg; i < end; i++) {
00056 if (seq[i] == M(0)) return vector<M> ();
00057 v[i-beg]= seq[i];
00058 }
00059 return v;
00060 }
00061
00062
00063
00064
00065
00066 template<typename C, typename M=modulus<C>,
00067 typename W= typename Coprime_moduli_variant(M)>
00068 struct moduli_helper {
00069 template<typename V> static bool
00070 covering (vector<M, V>& v, nat s) {
00071 static coprime_moduli_sequence<M,W> seq;
00072 v= vector<M,V> (M (0), s);
00073 for (nat i= 0; i < s; i++) {
00074 v[i]= seq[i];
00075 if (v[i] == 0) return false;
00076 }
00077 return true; }
00078 };
00079
00080
00081
00082
00083
00084 template<typename C>
00085 struct std_crt_naive {
00086 typedef C base;
00087 typedef C modulus_base;
00088 typedef typename Modulus_variant(C) modulus_base_variant;
00089 typedef typename Modulus_variant(C) modulus_variant;
00090 };
00091
00092
00093
00094
00095
// Tag type that names the naive remaindering implementation.
struct crt_naive {};

// Maps a type C to the remaindering variant used for it; this generic
// version yields crt_naive for every C.
template<typename C>
struct crt_naive_variant_helper {
  typedef crt_naive CV;
};

// Shorthand for the variant attached to C; the user has to prepend
// `typename` where the language requires it.
#define Crt_naive_variant(C) crt_naive_variant_helper<C>::CV
00104
00105
00106
00107
00108
00109 struct crt_project {};
00110
00111 template<typename V>
00112 struct implementation<crt_project,V,crt_naive> {
00113
00114 template<typename M> static inline typename M::base
00115 half (const M& P) {
00116
00117 return typename M::base (0); }
00118
00119 template<typename C, typename M> static inline typename M::base
00120 encode (const C& a, const M& p) {
00121 (void) p;
00122 return as<typename M::base> (a); }
00123
00124 template<typename C, typename M> static inline C
00125 decode (const C& a, const M& P, const C& H) {
00126 (void) P; (void) H;
00127 return a; }
00128
00129 template<typename C, typename M> static inline typename M::base
00130 mod (const C& a, const M& p) {
00131 return as<typename M::base> (rem (a, *p)); }
00132
00133 };
00134
00135
00136
00137
00138
00139 template<typename V> struct crt_signed {};
00140
00141 template<typename F, typename V, typename W>
00142 struct implementation<F,V,crt_signed<W> >:
00143 public implementation<F,V,W> {};
00144
00145 template<typename V,typename W>
00146 struct implementation<crt_project,V,crt_signed<W> > :
00147 public implementation<crt_project,V,W> {
00148
00149 template<typename M> static inline typename M::base
00150 half (const M& P) {
00151 return *P >> 1; }
00152
00153 template<typename C, typename M> static inline typename M::base
00154 encode (const C& a, const M& p) {
00155 typename M::base b;
00156 if (a < 0) {
00157 b= as<typename M::base> (-a); neg_mod (b, p);
00158 }
00159 else
00160 b= as<typename M::base> (a);
00161 return b; }
00162
00163 template<typename C, typename M> static inline C
00164 decode (const C& a, const M& P, const C& H) {
00165 return a > H ? a - *P : a; }
00166
00167 };
00168
00169
00170
00171
00172
00173 struct crt_transform {};
00174
00175 template<typename V>
00176 struct implementation<crt_transform,V,crt_naive> :
00177 public implementation<crt_project,V> {
00178 typedef implementation<crt_project,V> Crt;
00179
00180 template<typename C, typename M, typename I> static inline void
00181 direct (I* c, const C& a, const M* p, nat n) {
00182 for (nat i= 0; i < n; i++)
00183 c[i]= Crt::mod (a, p[i]); }
00184
00185 template<typename C, typename I> static inline void
00186 combine (C& a, const I* c, const C* q, nat n) {
00187 a= 0;
00188 for (nat i= 0; i < n; i++)
00189 mul_add (a, c[i], q[i]); }
00190
00191 template<typename C, typename M, typename I, typename K> static inline void
00192 inverse (C& a, const I* c, const M* p, const K& P,
00193 const C* q, const I* m, nat n) {
00194 a= 0; I t;
00195 for (nat i= 0; i < n; i++) {
00196 mul_mod (t, m[i], c[i], p[i]);
00197 mul_add (a, t, q[i]);
00198 }
00199 a= Crt::mod (a, P); }
00200 };
00201
00202
00203
00204
00205
00206 template<typename C, typename S=std_crt_naive<C>,
00207 typename V=typename Crt_naive_variant(C) >
00208 struct crt_naive_transformer : public S {
00209
00210 typedef typename S::modulus_base I;
00211 typedef modulus<I, typename S::modulus_base_variant> M;
00212 typedef modulus<C, typename S::modulus_variant> Modulus;
00213 typedef implementation<vector_linear, vector_naive> Vec;
00214 typedef implementation<crt_transform,V> Crt;
00215
00216 nat n;
00217 nat l_p, l_m, l_q;
00218 M* p;
00219 Modulus P;
00220 C H;
00221 C* q;
00222 I* m;
00223 bool product_done, comoduli_done, inverse_done;
00224
00225 inline void setup_product () {
00226 if (product_done) return;
00227 q= mmx_formatted_new<C> (l_q, format<C> ());
00228 for (nat i= 0; i < n; i++) q[i]= * p[i];
00229 P= Vec::template vec_unary_big<mul_op> (q, n);
00230 product_done= true; }
00231
00232 inline void setup_comoduli () {
00233 if (comoduli_done) return;
00234 setup_product ();
00235 for (nat i= 0; i < n; i++) q[i]= * P / * p[i];
00236 comoduli_done= true; }
00237
00238 inline void setup_inverse () {
00239 if (inverse_done) return;
00240 setup_comoduli ();
00241 H= Crt::half (P);
00242 m= mmx_formatted_new<I> (l_m, format<I> ());
00243 for (nat i= 0; i < n; i++) {
00244 I a (1), b;
00245 for (nat j= 0; j < n; j++) {
00246 reduce_mod (b, * p[j], p[i]);
00247 if (j != i) mul_mod (a, b, p[i]);
00248 }
00249 inv_mod (m[i], a, p[i]);
00250 ASSERT (m[i] != 0, "moduli must be pairwise coprime"); }
00251 inverse_done= true; }
00252
00253 template<typename NV,typename VV>
00254 inline crt_naive_transformer (const vector<NV,VV>& v, bool lazy= true)
00255 : n(N(v)) {
00256 l_p= default_aligned_size<M> (n);
00257 l_m= default_aligned_size<I> (n);
00258 l_q= default_aligned_size<C> (n);
00259 p= mmx_formatted_new<M> (l_p, format<M> ());
00260 for (nat i= 0; i < n; i++) p[i]= as<M> (v[i]);
00261 q= NULL; m= NULL; P= C(1);
00262 product_done= false; comoduli_done= false; inverse_done= false;
00263 if (! lazy) setup_inverse (); }
00264
00265 inline ~crt_naive_transformer () {
00266 mmx_delete<M> (p, l_p);
00267 if (q != NULL) mmx_delete<C> (q, l_q);
00268 if (m != NULL) mmx_delete<I> (m, l_m); }
00269
00270 inline M operator[] (nat i) const {
00271 VERIFY (i < n, "index out of range");
00272 return p[i]; }
00273
00274 inline nat size () const {
00275 return n; }
00276
00277 inline Modulus product () {
00278 setup_product (); return P; }
00279
00280 inline C comodulus (nat i) {
00281 VERIFY (i < n, "index out of range");
00282 setup_comoduli (); return q[i]; }
00283
00284 inline void direct_transform (I* c, const C& a) const {
00285 C b= Crt::encode (a, P);
00286 Crt::direct (c, b, p, n); }
00287
00288 inline void combine (C& a, const I* c) {
00289 setup_comoduli ();
00290 Crt::combine (a, c, q, n); }
00291
00292 inline void inverse_transform (C& a, const I* c) {
00293 setup_inverse ();
00294 Crt::inverse (a, c, p, P, q, m, n);
00295 a= Crt::decode (a, P, H); }
00296 };
00297
00298 template<typename C, typename M, typename V> inline nat
00299 N (const crt_naive_transformer<C,M,V>& crter) {
00300 return crter.size ();
00301 }
00302
00303
00304
00305
00306
// Shorthands for the types exported by a transformer type.  They only make
// sense inside templates that have a parameter named Crter, and they are
// removed again further down.
#define I typename Crter::modulus_base
#define C typename Crter::base
#define Modulus modulus<C, typename Crter::modulus_variant>
#define M modulus<typename Crter::modulus_base,\
                  typename Crter::modulus_base_variant>
00312
00313 template<typename Crter> inline Modulus
00314 moduli_product (Crter& crter) {
00315 return crter.product (); }
00316
00317 template<typename Crter> inline typename Crter::base
00318 comodulus (const Crter& crter, nat i) {
00319 return crter.comodulus (i);
00320 }
00321
00322 template<typename Crter> inline void
00323 direct_crt (I* dest, const C& s, Crter& crter) {
00324 crter.direct_transform (dest, s);
00325 }
00326
00327 template<typename Crter, typename W> inline void
00328 direct_crt (vector<I,W>& dest, const C& s, Crter& crter) {
00329 nat l= aligned_size<I,W> (N(crter));
00330 I* tmp= mmx_formatted_new<I> (l, CF(dest));
00331 crter.direct_transform (tmp, s);
00332 dest= vector<I,W> (tmp, N(crter), l, CF(dest));
00333 }
00334
00335 template<typename Crter> inline vector<I>
00336 direct_crt (const C& s, Crter& crter) {
00337 vector<I> dest;
00338 direct_crt (dest, s, crter);
00339 return dest;
00340 }
00341
00342 template<typename Crter> inline void
00343 combine_crt (C& d, const I* src, Crter& crter) {
00344 crter.combine (d, src);
00345 }
00346
00347 template<typename Crter, typename W> inline void
00348 combine_crt (C& d, const vector<I,W>& src, Crter& crter) {
00349 crter.combine (d, seg (src));
00350 }
00351
00352 template<typename Crter, typename W> inline C
00353 combine_crt (const vector<I,W>& src, Crter& crter) {
00354 C d; combine_crt (d, seg (src), crter);
00355 return d;
00356 }
00357
00358 template<typename Crter> inline void
00359 inverse_crt (C& d, const I* src, Crter& crter) {
00360 crter.inverse_transform (d, src);
00361 }
00362
00363 template<typename Crter, typename W> inline void
00364 inverse_crt (C& d, const vector<I,W>& src, Crter& crter) {
00365 crter.inverse_transform (d, seg (src));
00366 }
00367
00368 template<typename Crter, typename W> inline C
00369 inverse_crt (const vector<I,W>& src, Crter& crter) {
00370 C d; inverse_crt (d, seg (src), crter);
00371 return d;
00372 }
00373
// Remove the Crter shorthands so that they do not leak out of this header.
#undef Modulus
#undef M
#undef I
#undef C
00378 }
00379 #endif //__MMX__CRT_NAIVE__HPP