00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #include <basix/port.hpp>
00014 #include <algebramix/crt_integer.hpp>
00015 namespace mmx {
00016 #if defined (__GNU_MP__)
// Limb type of GMP's low-level mpn layer; every routine below operates on
// arrays of such limbs.
typedef mp_limb_t C;
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040 static inline nat
00041 mpn_size (const C* src, nat n) {
00042 while (n > 0 && src[n-1] == 0) n--;
00043 return n;
00044 }
00045
00046 static inline void
00047 mpn_clear (C* dest, nat n) {
00048 for (; n != 0; n--) {
00049 *dest= 0;
00050 dest++;
00051 }
00052 }
00053
00054 static inline void
00055 mpn_copy (C* dest, const C* src, nat n) {
00056 for (; n != 0; n--) {
00057 *dest= *src;
00058 dest++; src++;
00059 }
00060 }
00061
00062 static inline void
00063 mpn_copy (C* dest, const C* src, nat n1, nat n2) {
00064 mpn_copy (dest, src, n2);
00065 mpn_clear (dest + n2, n1 - n2);
00066 }
00067
00068 static void
00069 mpn_cofactor (C* d, const C* s1, const C* s2, nat n1, nat n2) {
00070
00071
00072 mpz_t src1;
00073 mpz_t src2;
00074 mpz_t g;
00075 mpz_t c1;
00076 mpz_t c2;
00077 src1->_mp_alloc= n1;
00078 src1->_mp_size = mpn_size (s1, n1);
00079 src1->_mp_d = const_cast<C*> (s1);
00080 src2->_mp_alloc= n2;
00081 src2->_mp_size = mpn_size (s2, n2);
00082 src2->_mp_d = const_cast<C*> (s2);
00083 mpz_init (g);
00084 mpz_init (c1);
00085 mpz_init (c2);
00086 mpz_gcdext (g, c1, c2, src1, src2);
00087 if (mpz_sgn (c1) < 0) mpz_add (c1, c1, src2);
00088 mpn_copy (d, c1->_mp_d, n2, c1->_mp_size);
00089 mpz_clear (g);
00090 mpz_clear (c1);
00091 mpz_clear (c2);
00092 }
00093
00094 static void
00095 build_converters (C* enc, const C* src, nat n, nat inc) {
00096
00097
00098
00099
00100
00101
00102
00103 if (n == 1) *enc= (C) *src;
00104 else {
00105 nat n2= n >> 1;
00106 nat n1= n - n2;
00107 C* dec= enc + (inc >> 1);
00108 build_converters (enc + inc, src, n1, inc);
00109 build_converters (enc + inc + n1, src + n1, n2, inc);
00110 mpn_mul (enc, enc + inc, n1, enc + inc + n1, n2);
00111 mpn_cofactor (dec, enc + inc, enc + inc + n1, n1, n2);
00112 }
00113 }
00114
00115 vector<C>
00116 mpz_setup_crt (const vector<C>& mods) {
00117 nat n= N(mods);
00118 if (n == 0) return vector<C> ();
00119 nat inc= n << 1;
00120 nat steps= 1;
00121 for (nat i= n-1; i != 0; steps++, i >>= 1);
00122 vector<C> cv= fill<C> (0, steps * inc);
00123 C* enc= seg (cv);
00124 build_converters (enc, seg (mods), n, inc);
00125 return cv;
00126 }
00127
00128 integer
00129 mpz_moduli_product (const vector<C>& mods, const vector<mp_limb_t>& cv) {
00130 nat n= N(mods);
00131 if (n == 0) return 1;
00132 integer i= raw_integer (n);
00133 mpn_copy ((*i)->_mp_d, seg (cv), n);
00134 (*i)->_mp_size= mpn_size (seg (cv), n);
00135 return i;
00136 }
00137
00138
00139
00140
00141
00142 inline const C* seg (const integer& i) { return (*i)->_mp_d; }
00143 inline C* seg (integer& i) { return (*i)->_mp_d; }
00144
00145 static void
00146 mpn_mod (C* dest, C* tmp,
00147 const C* s1, const C* s2, nat n1, nat n2) {
00148
00149 nat eff_n1= n1, eff_n2= n2;
00150 while (eff_n1 > 0 && s1[eff_n1-1] == 0) eff_n1--;
00151 while (eff_n2 > 0 && s2[eff_n2-1] == 0) eff_n2--;
00152 if (eff_n1 < eff_n2) mpn_copy (dest, s1, n2, eff_n1);
00153 else {
00154 mpn_tdiv_qr (tmp, dest, 0, s1, eff_n1, s2, eff_n2);
00155 if (n2 > eff_n2) mpn_clear (dest + eff_n2, n2 - eff_n2);
00156 }
00157 }
00158
00159 static void
00160 encode (C* dest, C* tmp, const C* enc, nat n, nat inc) {
00161 if (n == 1) return;
00162 nat n2= n >> 1;
00163 nat n1= n - n2;
00164 mpn_copy (tmp, dest, n);
00165 mpn_mod (dest, tmp + n, tmp, enc, n, n1);
00166 mpn_mod (dest + n1, tmp + n, tmp, enc + n1, n, n2);
00167 encode (dest, tmp, enc + inc, n1, inc);
00168 encode (dest + n1, tmp, enc + n1 + inc, n2, inc);
00169 }
00170
00171 void
00172 mpz_encode_crt (C* dest, const integer& src,
00173 const vector<C>& mods, const vector<C>& cv)
00174 {
00175 nat n= N(mods);
00176 if (n == 0) return;
00177 const C* enc= seg (cv);
00178 C* tmp = mmx_new<C> (n << 1);
00179 mpn_copy (dest, seg (src), n, limb_size (src));
00180 encode (dest, tmp, enc + (n << 1), n, n << 1);
00181 if (sign (src) < 0)
00182 for (nat i=0; i<n; i++)
00183 if (dest[i] != 0)
00184 dest[i]= mods[i] - dest[i];
00185 mmx_delete<C> (tmp, n << 1);
00186
00187
00188
00189
00190
00191
00192
00193 }
00194
00195 static bool
00196 mpn_is_zero (const C* src, nat n) {
00197 for (nat i= 0; i<n; i++)
00198 if (src[i] != 0) return false;
00199 return true;
00200 }
00201
00202 static void
00203 mpn_reconstruct (C* dest, C* temp,
00204 const C* x1, const C* x2,
00205 const C* m1, const C* m2, const C* c1,
00206 nat n1, nat n2)
00207 {
00208 mpn_copy (temp, x2, n1, n2);
00209 C borrow= mpn_sub_n (dest, temp, x1, n1);
00210 if (borrow) mpn_sub_n (dest, x1, temp, n1);
00211 mpn_mul (temp, dest, n1, c1, n2);
00212 mpn_mod (dest, temp + n1 + n2, temp, m2, n1 + n2, n2);
00213 if (borrow && !mpn_is_zero (dest, n2)) {
00214 mpn_copy (temp, dest, n2);
00215 mpn_sub_n (dest, m2, temp, n2);
00216 }
00217 mpn_mul (temp, m1, n1, dest, n2);
00218 mpn_add (dest, temp, n1 + n2, x1, n1);
00219 }
00220
00221 static void
00222 decode (C* dest, C* tmp, const C* src,
00223 const C* enc, nat n, nat inc)
00224 {
00225 if (n == 1) *dest= (C) *src;
00226 else {
00227
00228
00229 nat n2= n >> 1;
00230 nat n1= n - n2;
00231 const C* dec= enc + (inc >> 1);
00232 decode (dest, tmp, src, enc + inc, n1, inc);
00233 decode (dest + n1, tmp, src + n1, enc + n1 + inc, n2, inc);
00234 mpn_copy (tmp, dest, n);
00235 mpn_reconstruct (dest, tmp + n, tmp, tmp + n1,
00236 enc + inc, enc + inc + n1, dec, n1, n2);
00237
00238 }
00239 }
00240
00241 integer
00242 mpz_decode_crt (const C* src,
00243 const vector<C>& mods, const vector<C>& cv) {
00244 nat n= N(mods);
00245 if (n == 0) return 0;
00246 integer i= raw_integer (n);
00247 const C* enc= seg (cv);
00248 C* tmp = mmx_new<C> (n << 2);
00249 decode (seg (i), tmp, src, enc, n, n << 1);
00250 mpn_sub_n (tmp, enc, seg (i), n);
00251 if (mpn_cmp (seg (i), tmp, n) <= 0)
00252 (*i)->_mp_size= mpn_size (seg (i), n);
00253 else {
00254 mpn_copy (seg (i), tmp, n);
00255 (*i)->_mp_size= -mpn_size (seg (i), n);
00256 }
00257 mmx_delete<C> (tmp, n << 2);
00258
00259
00260
00261
00262
00263
00264
00265
00266
00267
00268
00269 return i;
00270 }
00271
00272 #endif // __GNU_MP__
00273 }