00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX_MATRIX_QUOTIENT_HPP
00014 #define __MMX_MATRIX_QUOTIENT_HPP
00015 #include <numerix/rational.hpp>
00016 #include <algebramix/matrix_integer.hpp>
00017 namespace mmx {
00018 #define TMPL template<typename C>
00019 #define NC Numerator_type(C)
00020 #define DC Denominator_type(C)
00021
00022
00023
00024
00025
/// Variant wrapper marking a matrix variant V as a "quotient" variant
/// (matrix_rational below is matrix_quotient<matrix_naive>).
/// All members of V are inherited; the derived variants exported by V
/// are re-wrapped so that the quotient marker is propagated to them.
template<typename V>
struct matrix_quotient: public V {
  // Types taken over unchanged from the underlying variant.
  typedef typename V::Naive Naive;
  typedef typename V::Vec   Vec;
  // Derived variants remain wrapped in matrix_quotient.
  typedef matrix_quotient<typename V::No_scaled> No_scaled;
  typedef matrix_quotient<typename V::No_thread> No_thread;
  typedef matrix_quotient<typename V::No_simd>   No_simd;
  typedef matrix_quotient<typename V::Positive>  Positive;
};
00035
00036 template<typename F, typename V, typename W>
00037 struct implementation<F,V,matrix_quotient<W> >:
00038 public implementation<F,V,W> {};
00039
00040 DEFINE_VARIANT (matrix_rational,
00041 matrix_quotient<matrix_naive>)
00042
00043 STMPL
00044 struct matrix_variant_helper<rational> {
00045 typedef matrix_rational MV;
00046 };
00047
00048
00049
00050
00051
00052 TMPL DC
00053 least_common_denominator (const C* s, nat str, nat n) {
00054 switch (n) {
00055 case 0: return 1;
00056 case 1: return denominator (s[0]);
00057 case 2: return lcm (denominator (s[0]), denominator (s[str]));
00058 case 3: return lcm (denominator (s[0]), lcm (denominator (s[str]),
00059 denominator (s[str<<1])));
00060 default:
00061 return lcm (least_common_denominator (s, str, n>>1),
00062 least_common_denominator (s + (n>>1) * str, str, n - (n>>1)));
00063 }
00064 }
00065
00066 TMPL void
00067 denominator_mul (NC* dest, const C* src, nat str, nat n, const DC& d) {
00068 for (; n != 0; n--, dest += str, src += str)
00069 *dest= numerator (*src) * (d / denominator (*src));
00070 }
00071
00072 TMPL void
00073 denominator_div (C* dest, const NC* src,
00074 nat row_str, nat rows,
00075 nat col_str, nat cols,
00076 const DC* row_d, const DC* col_d)
00077 {
00078 for (nat i=0; i < rows; i++, dest += row_str, src += row_str, row_d++) {
00079 C* Dest = dest;
00080 const NC* Src = src;
00081 const DC* Col_d= col_d;
00082 for (nat j=0; j < cols; j++, Dest += col_str, Src += col_str, Col_d++)
00083 *Dest= C (*Src) / ((*row_d) * (*Col_d));
00084 }
00085 }
00086
00087
00088
00089
00090
00091 template<typename V, typename W>
00092 struct implementation<matrix_multiply,V,matrix_quotient<W> >:
00093 public implementation<matrix_linear,V>
00094 {
00095 static const nat thr= 7;
00096
00097 TMPL static void
00098 mul (C* d, const C* s1, const C* s2,
00099 nat r, nat l, nat c)
00100 {
00101 typedef typename Matrix_variant(Numerator_type(C)) NV;
00102 typedef implementation<vector_linear,NV> Vec_N;
00103 typedef implementation<vector_linear,NV> Vec_D;
00104 typedef implementation<matrix_multiply,V> Mat;
00105 typedef implementation<matrix_multiply,NV> Mat_N;
00106 typedef implementation<matrix_multiply,NV> Mat_D;
00107 typedef implementation<matrix_multiply,typename V::Naive> Naive_Mat;
00108
00109 if (r <= thr || l <= thr || c <= thr)
00110 Naive_Mat::template mul<mul_op> (d, s1, s2, r, r, l, l, c, c);
00111 else {
00112 nat sd1 = aligned_size<DC,NV> (r);
00113 nat sd2 = aligned_size<DC,NV> (c);
00114 DC* d1 = mmx_new<DC> (sd1 + sd2);
00115 DC* d2 = d1 + sd1;
00116 nat sSrc1= aligned_size<NC,NV> (r*l);
00117 nat sSrc2= aligned_size<NC,NV> (c*l);
00118 nat sDest= aligned_size<NC,NV> (r*c);
00119 NC* Src1 = mmx_new<NC> (sSrc1 + sSrc2 + sDest);
00120 NC* Src2 = Src1 + sSrc1;
00121 NC* Dest = Src2 + sSrc2;
00122 for (nat i=0; i<r; i++)
00123 d1[i] = least_common_denominator (s1 + Mat::index (i, 0, r, l),
00124 Mat::index (0, 1, r, l), l);
00125
00126 for (nat j=0; j<c; j++)
00127 d2[j] = least_common_denominator (s2 + Mat::index (0, j, l, c),
00128 Mat::index (1, 0, l, c), l);
00129
00130
00131 for (nat i=0; i<r; i++)
00132 denominator_mul
00133 (Src1 + Mat::index (i, 0, r, l), s1 + Mat::index (i, 0, r, l),
00134 Mat::index (0, 1, r, l), l, d1[i]);
00135
00136 for (nat j=0; j<c; j++)
00137 denominator_mul
00138 (Src2 + Mat::index (0, j, l, c), s2 + Mat::index (0, j, l, c),
00139 Mat::index (1, 0, l, c), l, d2[j]);
00140
00141
00142 Mat_N::mul (Dest, Src1, Src2, r, l, c);
00143
00144
00145 denominator_div (d, Dest,
00146 Mat::index (1, 0, r, c), r,
00147 Mat::index (0, 1, r, c), c,
00148 d1, d2);
00149
00150 mmx_delete<DC> (d1, sd1 + sd2);
00151 mmx_delete<NC> (Src1, sSrc1 + sSrc2 + sDest);
00152 }
00153 }
00154
00155 };
00156
00157 #undef TMPL
00158 #undef NC
00159 #undef DC
00160 }
00161 #endif // __MMX_MATRIX_QUOTIENT_HPP
00162