00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX_MATRIX_CRT_HPP
00014 #define __MMX_MATRIX_CRT_HPP
00015 #include <numerix/modular.hpp>
00016 #include <algebramix/matrix.hpp>
00017 #include <algebramix/crt_blocks.hpp>
00018 namespace mmx {
00019
00020
00021
00022
00023
// Storage policy for the modular coefficients used by the CRT matrix
// product.  Every instantiation modular_matrix_crt<C>::modulus_storage<M>
// owns exactly one modulus of type M, value-initialized on first access
// and shared by all users of that instantiation.
template<typename C>
struct modular_matrix_crt {
  template<typename M>
  class modulus_storage {
  public:
    // Current modulus (returned by value).
    static inline M get_modulus () { return slot (); }
    // Replace the current modulus by p.
    static inline void set_modulus (const M& p) { slot ()= p; }
  private:
    // Function-local static: constructed as M () the first time it is used.
    static inline M& slot () {
      static M value= M ();
      return value; }
  };
};
00036
00037
00038
00039
00040
// Matrix variant selecting the multi-modular (CRT) product on top of the
// variant V.  Everything is inherited from V; the Vec, Naive and Positive
// variants are taken over unchanged, while the No_scaled, No_thread and
// No_simd variants of V are rewrapped in matrix_crt.
template<typename V>
struct matrix_crt: V {
  typedef typename V::Positive Positive;
  typedef typename V::Naive    Naive;
  typedef typename V::Vec      Vec;
  typedef matrix_crt<typename V::No_scaled> No_scaled;
  typedef matrix_crt<typename V::No_thread> No_thread;
  typedef matrix_crt<typename V::No_simd>   No_simd;
};
00050
00051 template<typename F, typename V, typename W>
00052 struct implementation<F,V,matrix_crt<W> >:
00053 public implementation<F,V,W> {};
00054
00055
00056
00057
00058
00059 template<typename C>
00060 struct matrix_crt_multiply_helper {
00061 static const nat dimension_threshold= 7;
00062 static const nat ratio_threshold= 100;
00063
00064 typedef crt_naive_transformer<C> crt_transformer;
00065 typedef moduli_helper<C,
00066 modulus<typename crt_transformer::modulus_base,
00067 typename crt_transformer::modulus_base_variant> > moduli_sequence;
00068
00069 static nat size (const C* s1, nat s1_rs, nat s1_cs,
00070 const C* s2, nat s2_rs, nat s2_cs,
00071 nat r, nat l, nat c) {
00072 nat sz= 0;
00073 for (nat k= 0; k < l; k++) {
00074 nat sz1= 0, sz2= 0;
00075 const C* ss1= s1 + k * s1_cs;
00076 const C* ss2= s2 + k * s2_rs;
00077 for (nat i= 0; i < r; i++, ss1 += s1_rs) sz1= max (sz1, N (*ss1));
00078 for (nat j= 0; j < c; j++, ss2 += s2_cs) sz2= max (sz2, N (*ss2));
00079 sz= max (sz, sz1 + sz2);
00080 }
00081 return sz; }
00082 };
00083
00084
00085
00086
00087
00088 template<typename V, typename W>
00089 struct implementation<matrix_multiply,V,matrix_crt<W> >:
00090 public implementation<matrix_multiply_base,V>
00091 {
00092 typedef implementation<matrix_multiply,W> Mat;
00093
00094 template<typename Op, typename D, typename S1, typename S2>
00095 static inline void
00096 mul (D* d, const S1* s1, const S2* s2,
00097 nat r, nat rr, nat l, nat ll, nat c, nat cc) {
00098 Mat::template mul<Op> (d, s1, s2, r, rr, l, ll, c, cc); }
00099
00100 template<typename D, typename S1, typename S2>
00101 static inline void
00102 mul (D* d, const S1* s1, const S2* s2,
00103 nat r, nat l, nat c) {
00104 Mat::template mul<mul_op> (d, s1, s2, r, r, l, l, c, c); }
00105
00106 template<typename C, typename I, typename MV, typename Crter>
00107 static void
00108 mat_direct_crt (matrix<I,MV>* dest, const C* s,
00109 nat s_rs, nat s_cs, nat r, nat c, Crter& crter) {
00110 nat n= N(crter);
00111 for (nat k= 0; k < n; k++)
00112 dest[k]= matrix<I,MV> (I (), r, c);
00113 I* aux= mmx_new<I> (n);
00114 for (nat i= 0; i < r; i++)
00115 for (nat j= 0; j < c; j++) {
00116 direct_crt (aux, s[i * s_rs + j * s_cs], crter);
00117 for (nat k= 0; k < n ; k++) dest[k](i,j)= aux[k];
00118 }
00119 mmx_delete<I> (aux, n); }
00120
00121 template<typename C, typename Modulus, typename MW, typename MV, typename Crter>
00122 static void
00123 mat_direct_crt (matrix<modular<Modulus,MW>,MV>* dest, const C* s,
00124 nat s_rs, nat s_cs, nat r, nat c, Crter& crter) {
00125 typedef modular<Modulus,MW> Modular;
00126 typedef typename Modular::modulus::base I;
00127 nat n= N(crter);
00128 for (nat k= 0; k < n; k++)
00129 dest[k]= matrix<Modular,MV> (Modular (), r, c);
00130 I* aux= mmx_new<I> (n);
00131 for (nat i= 0; i < r; i++)
00132 for (nat j= 0; j < c; j++) {
00133 direct_crt (aux, s[i * s_rs + j * s_cs], crter);
00134 for (nat k= 0; k < n ; k++) dest[k](i,j)= Modular (aux[k], true);
00135 }
00136 mmx_delete<I> (aux, n); }
00137
00138 template<typename C, typename I, typename MV, typename Crter>
00139 static void
00140 mat_inverse_crt (C* d, nat d_rs, nat d_cs, nat r, nat c,
00141 const matrix<I,MV>* s, Crter& crter) {
00142 nat n= N(crter);
00143 I* aux= mmx_new<I> (n);
00144 for (nat i= 0; i < r; i++)
00145 for (nat j= 0; j < c; j++) {
00146 for (nat k= 0; k < n ; k++) aux[k]= s[k](i,j);
00147 inverse_crt (d[i * d_rs + j * d_cs], aux, crter);
00148 }
00149 mmx_delete<I> (aux, n); }
00150
00151 template<typename C, typename Modulus, typename MW, typename MV, typename Crter>
00152 static void
00153 mat_inverse_crt (C* d, nat d_rs, nat d_cs, nat r, nat c,
00154 const matrix<modular<Modulus,MW>,MV>* s, Crter& crter) {
00155 typedef modular<Modulus,MW> Modular;
00156 typedef typename Modular::modulus::base I;
00157 nat n= N(crter);
00158 I* aux= mmx_new<I> (n);
00159 for (nat i= 0; i < r; i++)
00160 for (nat j= 0; j < c; j++) {
00161 for (nat k= 0; k < n ; k++) aux[k]= * s[k](i,j);
00162 inverse_crt (d[i * d_rs + j * d_cs], aux, crter);
00163 }
00164 mmx_delete<I> (aux, n); }
00165
00166 template<typename C, typename Crter> static void
00167 mul (C* d, const C* s1, const C* s2, nat r, nat l, nat c,
00168 Crter& crter) {
00169 typedef typename Crter::modulus_base I;
00170 typedef modulus<I,typename Crter::modulus_base_variant> Modulus;
00171 typedef modular<Modulus,modular_matrix_crt<C> > Modular;
00172 typedef matrix<Modular> Matrix_modular;
00173 nat n= N(crter);
00174 Matrix_modular* mm1= mmx_new<Matrix_modular> (n);
00175 Matrix_modular* mm2= mmx_new<Matrix_modular> (n);
00176 Matrix_modular* mmd= mmx_new<Matrix_modular> (n);
00177 mat_direct_crt (mm1, s1, Mat::index (1, 0, r, l), Mat::index (0, 1, r, l),
00178 r, l, crter);
00179 mat_direct_crt (mm2, s2, Mat::index (1, 0, l, c), Mat::index (0, 1, l, c),
00180 l, c, crter);
00181 for (nat k= 0; k < n; k++) {
00182 Modular::set_modulus (crter[k]);
00183 mmd[k]= mm1[k] * mm2[k];
00184 }
00185 mat_inverse_crt (d, Mat::index (1, 0, r, c), Mat::index (0, 1, r, c),
00186 r, c, mmd, crter);
00187 mmx_delete<Matrix_modular> (mm1, n);
00188 mmx_delete<Matrix_modular> (mm2, n);
00189 mmx_delete<Matrix_modular> (mmd, n); }
00190
00191 template<typename C, typename S, typename CV> static void
00192 mul (C* d, const C* s1, const C* s2, nat r, nat l, nat c,
00193 crt_naive_transformer<C,S,CV>& crter) {
00194 typedef implementation<crt_project,CV> Crt;
00195 typedef crt_naive_transformer<C,S,CV> Crter;
00196 typedef typename Crter::modulus_base I;
00197 typedef modulus<I,typename Crter::modulus_base_variant> Modulus;
00198 typedef modular<Modulus,modular_matrix_crt<C> > Modular;
00199 typedef typename Matrix_variant(Modular) MV;
00200 typedef implementation<matrix_multiply,MV> Mat_mod;
00201
00202 nat spc1= aligned_size<Modular,V> (r * l);
00203 nat spc2= aligned_size<Modular,V> (l * c);
00204 nat spcd= aligned_size<Modular,V> (r * c);
00205 nat spc= spc1 + spc2 + spcd;
00206 Modular* x1= mmx_new<Modular> (spc);
00207 Modular* x2= x1 + spc1;
00208 Modular* xd= x2 + spc2;
00209
00210 if (N(crter) == 1) {
00211 Modulus p (crter[0]); Modular::set_modulus (p);
00212 for (nat i= 0; i < r * l; i++)
00213 x1[i]= Modular (Crt::encode (s1[i], p), true);
00214 for (nat i= 0; i < l * c; i++)
00215 x2[i]= Modular (Crt::encode (s2[i], p), true);
00216 Mat_mod::mul (xd, x1, x2, r, l, c);
00217 for (nat i= 0; i < r * c; i++)
00218 d[i]= Crt::decode (C(* xd[i]), crter.P, crter.H);
00219 }
00220 else {
00221 for (nat k= 0; k < N(crter); k++) {
00222 Modulus p (crter[k]); Modular::set_modulus (p);
00223 for (nat i= 0; i < r * l; i++)
00224 x1[i]= Modular (Crt::mod (Crt::encode (s1[i], crter.P), p));
00225 for (nat i= 0; i < l * c; i++)
00226 x2[i]= Modular (Crt::mod (Crt::encode (s2[i], crter.P), p));
00227 Mat_mod::mul (xd, x1, x2, r, l, c);
00228 I m (crter.m[k]), t; C q (crter.q[k]);
00229 for (nat i= 0; i < r * c; i++) {
00230 mul_mod (t, m, * xd[i], p);
00231 if (k == 0) d[i]= t * q; else mul_add (d[i], t, q);
00232 }
00233 }
00234 for (nat i= 0; i < r * c; i++)
00235 d[i]= Crt::decode (Crt::mod (d[i], crter.P), crter.P, crter.H);
00236 }
00237 mmx_delete<Modular> (x1, spc); }
00238
00239 template<typename C, typename Low, typename High, nat s, typename CV>
00240 static void
00241 mul (C* d, const C* s1, const C* s2, nat r, nat l, nat c,
00242 crt_blocks_transformer<Low,High,s,CV>& crter) {
00243 typedef matrix<C,matrix_crt<W> > Matrix;
00244 nat n= crter.high -> size ();
00245 if (n == 1) {
00246 mul (d, s1, s2, r, l, c, * crter.low[0]);
00247 return;
00248 }
00249 Matrix* mm1= mmx_new<Matrix> (n);
00250 Matrix* mm2= mmx_new<Matrix> (n);
00251 Matrix* mmd= mmx_new<Matrix> (n);
00252 mat_direct_crt (mm1, s1, Mat::index (1, 0, r, l), Mat::index (0, 1, r, l),
00253 r, l, * crter.high);
00254 mat_direct_crt (mm2, s2, Mat::index (1, 0, l, c), Mat::index (0, 1, l, c),
00255 l, c, * crter.high);
00256 for (nat k= 0; k < n; k++)
00257 mul (tab(mmd[k]), tab(mm1[k]), tab(mm2[k]), r, l, c, * crter.low[k]);
00258 mat_inverse_crt (d, Mat::index (1, 0, r, c), Mat::index (0, 1, r, c),
00259 r, c, mmd, * crter.high);
00260 mmx_delete<Matrix> (mm1, n);
00261 mmx_delete<Matrix> (mm2, n);
00262 mmx_delete<Matrix> (mmd, n); }
00263
00264 template<typename C> static void
00265 mul (C* d, const C* s1, const C* s2, nat r, nat l, nat c) {
00266 typedef matrix_crt_multiply_helper<C> Matrix_crt;
00267 typedef typename Matrix_crt::crt_transformer Crter;
00268 typedef typename Matrix_crt::moduli_sequence Sequence;
00269 typedef typename Crter::modulus_base I;
00270 typedef modulus<I,typename Crter::modulus_base_variant> Modulus;
00271 static const nat dim_thr= Matrix_crt::dimension_threshold;
00272 static const nat ratio_thr= Matrix_crt::ratio_threshold;;
00273
00274 if (r <= dim_thr || l <= dim_thr || c <= dim_thr) {
00275 Mat::mul (d, s1, s2, r, l, c);
00276 return;
00277 }
00278 nat sz= matrix_crt_multiply_helper<C>
00279 ::size (s1, Mat::index (1, 0, r, l), Mat::index (0, 1, r, l),
00280 s2, Mat::index (1, 0, l, c), Mat::index (0, 1, l, c),
00281 r, l, c);
00282 nat wd= (ratio_thr * sz) / 100;
00283
00284 if (r < wd || l < wd || c < wd) {
00285 Mat::mul (d, s1, s2, r, l, c);
00286 return;
00287 }
00288 vector<Modulus> mods;
00289 if (! Sequence::covering (mods, sz))
00290
00291 Mat::mul (d, s1, s2, r, l, c);
00292 else {
00293 Crter crter (mods, false);
00294 mul (d, s1, s2, r, l, c, crter); } }
00295
00296 };
00297
00298 }
00299 #endif // __MMX_MATRIX_CRT_HPP