00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX__MATRIX_NAIVE__HPP
00014 #define __MMX__MATRIX_NAIVE__HPP
00015 #include <basix/vector.hpp>
00016
00017 namespace mmx {
00018 #define TMPL template<typename C>
00019
00020
00021
00022
00023
00024 struct matrix_naive {
00025 typedef vector_naive Vec;
00026 typedef matrix_naive Naive;
00027 typedef matrix_naive Positive;
00028 typedef matrix_naive No_aligned;
00029 typedef matrix_naive No_simd;
00030 typedef matrix_naive No_thread;
00031 typedef matrix_naive No_scaled;
00032 };
00033
00034 template<typename C>
00035 struct matrix_variant_helper {
00036 typedef matrix_naive MV;
00037 };
00038
00039
00040
00041
00042
00043 struct matrix_defaults {};
00044
00045 template<typename V>
00046 struct implementation<matrix_defaults,V,matrix_naive> {
00047 static const nat def_rows = 0;
00048 static const nat def_cols = 0;
00049 static const nat init_rows= 1;
00050 static const nat init_cols= 1;
00051 };
00052
00053
00054
00055
00056
00057 template<typename V>
00058 struct implementation<vector_defaults,V,matrix_naive>:
00059 public implementation<vector_defaults,V,typename V::Vec> {};
00060
00061 template<typename V>
00062 struct implementation<vector_allocate,V,matrix_naive>:
00063 public implementation<vector_allocate,V,typename V::Vec> {};
00064
00065 template<typename V>
00066 struct implementation<vector_abstractions,V,matrix_naive>:
00067 public implementation<vector_abstractions,V,typename V::Vec> {};
00068
00069 template<typename V>
00070 struct implementation<vector_abstractions_stride,V,matrix_naive>:
00071 public implementation<vector_abstractions_stride,V,typename V::Vec> {};
00072
00073 template<typename V>
00074 struct implementation<vector_linear,V,matrix_naive>:
00075 public implementation<vector_linear,V,typename V::Vec> {};
00076
00077 template<typename C, typename V> inline nat
00078 aligned_size (nat r, nat c) {
00079 return aligned_size<C,V> (r * c); }
00080
00081 template<typename C> inline nat
00082 default_aligned_size (nat r, nat c) {
00083 return default_aligned_size<C> (r * c); }
00084
00085
00086
00087
00088
00089 struct matrix_vectorial {};
00090
00091 template<typename V>
00092 struct implementation<matrix_vectorial,V,matrix_naive>:
00093 public implementation<matrix_defaults,V>
00094 {
00095 typedef implementation<vector_linear,V> Vec;
00096
00097 static inline nat
00098 index (nat row, nat col, nat rows, nat cols) {
00099 (void) cols;
00100 return col * rows + row; }
00101
00102 TMPL
00103 static inline const C&
00104 entry (const C* m, nat row, nat col, nat rows, nat cols) {
00105 return * (m + index (row, col, rows, cols)); }
00106
00107
00108
00109
00110
00111
00112
00113
00114 template<typename Op, typename T, typename C> static inline void
00115 mat_unary_stride (T* dest, nat dest_rs, nat dest_cs,
00116 const C* src, nat src_rs, nat src_cs,
00117 nat rows, nat cols) {
00118 for (; cols != 0; dest += dest_cs, src += src_cs, cols--)
00119 Vec::template vec_unary_stride<Op> (dest, dest_rs, src, src_rs, rows); }
00120
00121 template<typename Op, typename T, typename C1, typename C2> static inline void
00122 mat_binary_stride (T* dest, nat dest_rs, nat dest_cs,
00123 const C1* src1, nat src1_rs, nat src1_cs,
00124 const C2* src2, nat src2_rs, nat src2_cs,
00125 nat rows, nat cols) {
00126 for (; cols != 0; dest += dest_cs, src1 += src1_cs, src2 += src2_cs, cols--)
00127 Vec::template vec_binary_stride<Op>
00128 (dest, dest_rs, src1, src1_rs, src2, src2_rs, rows); }
00129
00130 template<typename Op, typename T, typename C, typename X> static inline void
00131 mat_binary_scalar_stride (T* dest, nat dest_rs, nat dest_cs,
00132 const C* src1, nat src1_rs, nat src1_cs,
00133 const X& x, nat rows, nat cols) {
00134 for (; cols != 0; dest += dest_cs, src1 += src1_cs, cols--)
00135 Vec::template vec_binary_scalar_stride<Op>
00136 (dest, dest_rs, src1, src1_rs, x, rows); }
00137
00138 TMPL static inline void
00139 clear (C* dest, nat r, nat c) {
00140 Vec::clear (dest, r * c);
00141 }
00142
00143 TMPL static void
00144 set (C* dest, const C& c, nat rows, nat cols) {
00145 clear (dest, rows, cols);
00146 nat m= min (rows, cols), plus= index (1, 1, rows, cols);
00147 while (m != 0) {
00148 *dest= c;
00149 dest += plus; m--; } }
00150
00151 TMPL static void
00152 get_range (C* d, const C* s, nat r1, nat c1, nat r2, nat c2, nat r, nat c) {
00153 nat rr= r2 - r1, cc = c2 - c1;
00154 mat_unary_stride<id_op>
00155 (d, index (1, 0, rr, cc), index (0, 1, rr, cc),
00156 s + index (r1, c1, r, c), index (1, 0, r, c), index (0, 1, r, c),
00157 rr, cc);
00158 }
00159
00160 TMPL static void
00161 clear_range (C* d, nat r1, nat c1, nat r2, nat c2, nat r, nat c) {
00162 d += index (r1, c1, r, c);
00163 C* dd= d;
00164 nat rp= index (1, 0, r, c), cp= index (0, 1, r, c);
00165 for (nat j=0; j<(c2-c1); j++, dd += cp, d = dd)
00166 for (nat i=0; i<(r2-r1); i++, d += rp)
00167 clear_op::set_op (*d);
00168 }
00169
00170 TMPL static void
00171 set_range (C* d, const C& s, nat r1, nat c1, nat r2, nat c2, nat r, nat c) {
00172 d += index (r1, c1, r, c);
00173 C* dd= d;
00174 nat rp= index (1, 0, r, c), cp= index (0, 1, r, c);
00175 for (nat j=0; j<(c2-c1); j++, dd += cp, d = dd)
00176 for (nat i=0; i<(r2-r1); i++, d += rp)
00177 if (i == j) *d = s;
00178 else clear_op::set_op (*d);
00179 }
00180
00181 TMPL static void
00182 copy (C* dest, const C* s, nat r, nat c) {
00183 Vec::copy (dest, s, r * c); }
00184
00185 TMPL static void
00186 transpose (C* dest, const C* src, nat rows, nat cols) {
00187 mat_unary_stride<id_op>
00188 (dest, index (0, 1, cols, rows), index (1, 0, cols, rows),
00189 src, index (1, 0, rows, cols), index (0, 1, rows, cols),
00190 rows, cols); }
00191
00192 TMPL static inline void
00193 print (const port& out, const C* s, nat rows, nat cols) {
00194 out << "[ ";
00195 for (nat i=0; i < rows; i++) {
00196 for (nat j=0; j < cols; j++) {
00197 out << * (s + index (i, j, rows, cols));
00198 if (j + 1 != cols) out << ", ";
00199 }
00200 if (i + 1 != rows) out << ";" << lf;
00201 }
00202 out << " ]";
00203 }
00204
00205 };
00206
00207
00208
00209
00210
00211 struct matrix_linear {};
00212
00213 template<typename V>
00214 struct implementation<matrix_linear,V,matrix_naive>:
00215 public implementation<matrix_vectorial,V>
00216 {
00217 typedef implementation<matrix_vectorial,V> Mat;
00218 typedef implementation<vector_linear,V> Vec;
00219
00220 template<typename Op, typename C, typename X> static inline void
00221 col_unary_scalar (C* dest, const X& x, nat c1, nat r, nat rr, nat cc) {
00222 Vec::template vec_unary_scalar_stride<Op>
00223 (dest + Mat::index (0, c1, rr, cc), Mat::index (1, 0, rr, cc), x, r);
00224 }
00225
00226 template<typename Op, typename C, typename X> static inline void
00227 col_binary_scalar (C* dest, const X& x, nat c1, nat c2, nat r,
00228 nat rr, nat cc) {
00229 Vec::template vec_binary_scalar_stride<Op>
00230 (dest + Mat::index (0, c1, rr, cc), Mat::index (1, 0, rr, cc),
00231 dest + Mat::index (0, c2, rr, cc), Mat::index (1, 0, rr, cc), x, r);
00232 }
00233
00234 template<typename Op, typename C> static inline void
00235 col_binary_combine (C* dest, nat c1, nat c2, nat r, nat rr, nat cc) {
00236 Vec::template vec_binary_combine_stride<Op>
00237 (dest + Mat::index (0, c1, rr, cc), Mat::index (1, 0, rr, cc),
00238 dest + Mat::index (0, c2, rr, cc), Mat::index (1, 0, rr, cc), r);
00239 }
00240
00241 template<typename Op, typename C, typename X> static inline void
00242 row_unary_scalar (C* dest, const X& x, nat r1, nat c, nat rr, nat cc) {
00243 Vec::template vec_unary_scalar_stride<Op>
00244 (dest + Mat::index (r1, 0, rr, cc), Mat::index (0, 1, rr, cc), x, c);
00245 }
00246
00247 template<typename Op, typename C, typename X> static inline void
00248 row_binary_scalar (C* dest, const X& x, nat r1, nat r2, nat c,
00249 nat rr, nat cc) {
00250 Vec::template vec_binary_scalar_stride<Op>
00251 (dest + Mat::index (r1, 0, rr, cc), Mat::index (0, 1, rr, cc),
00252 dest + Mat::index (r2, 0, rr, cc), Mat::index (0, 1, rr, cc), x, c);
00253 }
00254
00255 template<typename Op, typename C> static inline void
00256 row_binary_combine (C* dest, nat r1, nat r2, nat c, nat rr, nat cc) {
00257 Vec::template vec_binary_combine_stride<Op>
00258 (dest + Mat::index (r1, 0, rr, cc), Mat::index (0, 1, rr, cc),
00259 dest + Mat::index (r2, 0, rr, cc), Mat::index (0, 1, rr, cc), c);
00260 }
00261
00262 TMPL static void col_mul (C* s, const C& sc, nat i, nat r, nat c) {
00263 ASSERT (i<c, "out of range");
00264 col_unary_scalar<rmul_op> (s, sc, i, r, r, c); }
00265 TMPL static void col_div (C* s, const C& sc, nat i, nat r, nat c) {
00266 ASSERT (i<c, "out of range");
00267 col_unary_scalar<rdiv_op> (s, sc, i, r, r, c); }
00268 TMPL static void row_mul (C* s, const C& sc, nat i, nat r, nat c) {
00269 ASSERT (i<r, "out of range");
00270 row_unary_scalar<rmul_op> (s, sc, i, c, r, c); }
00271 TMPL static void row_div (C* s, const C& sc, nat i, nat r, nat c) {
00272 ASSERT (i<r, "out of range");
00273 row_unary_scalar<rdiv_op> (s, sc, i, c, r, c); }
00274
00275 TMPL static void col_sweep (C* s, nat i, nat j, const C& sc, nat r, nat c) {
00276 ASSERT (i<c && j<c, "out of range");
00277 col_binary_scalar<mul_sub_op> (s, sc, i, j, r, r, c); }
00278 TMPL static void row_sweep (C* s, nat i, nat j, const C& sc, nat r, nat c) {
00279 ASSERT (i<r && j<r, "out of range");
00280 row_binary_scalar<mul_sub_op> (s, sc, i, j, c, r, c); }
00281 TMPL static void col_sweep (C* s, nat i, nat j, nat r, const C& sc,
00282 nat rr, nat cc) {
00283 col_binary_scalar<mul_sub_op> (s, sc, i, j, r, rr, cc); }
00284 TMPL static void row_sweep (C* s, nat i, nat j, nat c, const C& sc,
00285 nat rr, nat cc) {
00286 row_binary_scalar<mul_sub_op> (s, sc, i, j, c, rr, cc); }
00287
00288 TMPL static void col_swap (C* s, nat i, nat j, nat r, nat c) {
00289 ASSERT (i<c && j<c, "out of range");
00290 col_binary_combine<swap_op> (s, i, j, r, r, c); }
00291 TMPL static void row_swap (C* s, nat i, nat j, nat r, nat c) {
00292 ASSERT (i<r && j<r, "out of range");
00293 row_binary_combine<swap_op> (s, i, j, c, r, c); }
00294 TMPL static void col_swap (C* s, nat i, nat j, nat r, nat rr, nat cc) {
00295 col_binary_combine<swap_op> (s, i, j, r, rr, cc); }
00296 TMPL static void row_swap (C* s, nat i, nat j, nat c, nat rr, nat cc) {
00297 row_binary_combine<swap_op> (s, i, j, c, rr, cc); }
00298
00299 TMPL static void
00300 row_combine_sub (C* s, nat i, const C& si, nat j, const C& sj, nat r, nat c) {
00301 ASSERT (i<r && j<r, "out of range");
00302 C* iti= s + Mat::index (i, 0, r, c);
00303 C* itj= s + Mat::index (j, 0, r, c);
00304 nat plus= Mat::index (0, 1, r, c);
00305 while (c>0) {
00306 *iti= si * (*iti) - sj * (*itj);
00307 iti += plus; itj += plus; c--; } }
00308
00309 TMPL static void
00310 col_combine_sub (C* s, nat i, const C& si, nat j, const C& sj, nat r, nat c) {
00311 ASSERT (i<c && j<c, "out of range");
00312 C* iti= s + Mat::index (0, i, r, c);
00313 C* itj= s + Mat::index (0, j, r, c);
00314 nat plus= Mat::index (1, 0, r, c);
00315 while (r>0) {
00316 *iti= si * (*iti) - sj * (*itj);
00317 iti += plus; itj += plus; r--; } }
00318
00319 TMPL static bool
00320 row_is_zero (const C* s, nat i, nat r, nat c) {
00321 if (r == 0 || c == 0) return true;
00322 return Vec::template vec_binary_test_scalar_stride<equal_op>
00323 (s + Mat::index (i, 0, r, c), Mat::index (0, 1, r, c), promote (0, *s), c); }
00324
00325 TMPL static bool
00326 col_is_zero (const C* s, nat j, nat r, nat c) {
00327 if (r == 0 || c == 0) return true;
00328 return Vec::template vec_binary_test_scalar_stride<equal_op>
00329 (s + Mat::index (0, j, r, c), Mat::index (1, 0, r, c), promote (0, *s), r); }
00330
00331 };
00332
00333
00334
00335
00336
00337 template<typename V>
00338 struct matrix_multiply_threshold {};
00339
00340 struct matrix_multiply_base {};
00341 struct matrix_multiply {};
00342
00343 template<typename V>
00344 struct implementation<matrix_multiply_base,V,matrix_naive>:
00345 public implementation<matrix_linear,V>
00346 {
00347 typedef implementation<matrix_linear,V> Mat;
00348
00349 template<typename Op, typename C>
00350 struct clear_helper {
00351 static inline void op (C* d, nat r, nat rr, nat c, nat cc) {
00352 Mat::clear_range (d, 0, 0, r, c, rr, cc); }
00353 };
00354
00355 template<typename C>
00356 struct clear_helper<mul_add_op,C> {
00357 static inline void op (C* d, nat r, nat rr, nat c, nat cc) {}
00358 };
00359
00360 template<typename Op, typename C> static inline void
00361 clr (C* d, nat r, nat rr, nat c, nat cc) {
00362 clear_helper<Op,C>::op (d, r, rr, c, cc);
00363 }
00364
00365 template<typename Op, typename D, typename S1, typename S2>
00366 static inline void
00367 mul (D* d, const S1* s1, const S2* s2,
00368 nat r, nat rr, nat l, nat ll, nat c, nat cc) {
00369 typedef typename Op::acc_op Acc;
00370 if (l == 0) clr<Op> (d, r, rr, c, cc);
00371 else {
00372 nat ri, ci, li;
00373 D *dr, *dc;
00374 const S1 *s1r, *s1c;
00375 const S2 *s2r, *s2c;
00376 nat drp = Mat::index (1, 0, rr, cc), dcp = Mat::index (0, 1, rr, cc);
00377 nat s1rp= Mat::index (1, 0, rr, ll), s1cp= Mat::index (0, 1, rr, ll);
00378 nat s2rp= Mat::index (1, 0, ll, cc), s2cp= Mat::index (0, 1, ll, cc);
00379 for (ri= r, dr= d, s1r= s1; ri!=0; ri--, dr += drp, s1r += s1rp)
00380 for (ci= c, dc= dr, s2c= s2; ci!=0; ci--, dc += dcp, s2c += s2cp) {
00381 D tmp= *dc;
00382 for (li= l, s1c= s1r, s2r= s2c; li==l; li--, s1c += s1cp, s2r += s2rp)
00383 Op ::set_op (tmp, *s1c, *s2r);
00384 for ( ; li!=0; li--, s1c += s1cp, s2r += s2rp)
00385 Acc::set_op (tmp, *s1c, *s2r);
00386 *dc= tmp;
00387 }
00388 }
00389 }
00390
00391 };
00392
00393 template<typename V>
00394 struct implementation<matrix_multiply,V,matrix_naive>:
00395 public implementation<matrix_multiply_base,V>
00396 {
00397 typedef implementation<matrix_multiply_base,V> Mat;
00398
00399 template<typename Op, typename D, typename S1, typename S2>
00400 static inline void
00401 mul (D* d, const S1* s1, const S2* s2,
00402 nat r, nat rr, nat l, nat ll, nat c, nat cc) {
00403 Mat::template mul<Op> (d, s1, s2, r, rr, l, ll, c, cc);
00404 }
00405
00406 template<typename D, typename S1, typename S2> static inline void
00407 mul (D* dest, const S1* m1, const S2* m2, nat r, nat l, nat c) {
00408 Mat::template mul<mul_op> (dest, m1, m2, r, r, l, l, c, c);
00409 }
00410
00411 };
00412
00413
00414
00415
00416
00417 struct matrix_iterate {};
00418
00419 template<typename V>
00420 struct implementation<matrix_iterate,V,matrix_naive>:
00421 public implementation<matrix_multiply,V>
00422 {
00423 typedef implementation<matrix_multiply,V> Mat;
00424
00425 TMPL static inline void
00426 iterate_mul (C** d, const C* s, const C* x,
00427 nat rs, nat cs, nat rx, nat cx, nat n) {
00428
00429
00430 VERIFY (cs == rx && rs == cs, "sizes do not match");
00431 if (rs == 0) return;
00432 Mat::get_range (d[0], x, 0, 0, rx, cx, rx, cx);
00433 for (nat i = 1; i < n; i++)
00434 Mat::mul (d[i], s, d[i-1], rs, cs, cx);
00435 }
00436
00437 TMPL static inline void
00438 project_iterate_mul (C** d, const C* y, const C* s, const C* x,
00439 nat ry, nat cy, nat rs, nat cs, nat rx, nat cx, nat n) {
00440
00441 VERIFY (cy == rs && rs == cs && cs == rx, "sizes do not match");
00442 if (n == 0) return;
00443 nat l= aligned_size<C,V> (rx * cx);
00444 C* buf= mmx_new<C> (l << 1);
00445 C* t1= buf, * t2= t1 + l;
00446 Mat::get_range (t1, x, 0, 0, rx, cx, rx, cx);
00447 Mat::mul (d[0], y, t1, ry, cy, cx);
00448 for (nat i = 1; i < n; i++) {
00449 Mat::mul (t2, s, t1, rs, cs, cx);
00450 Mat::mul (d[i], y, t2, ry, cy, cx);
00451 swap (t1, t2);
00452 }
00453 mmx_delete<C> (buf, l << 1);
00454 }
00455
00456 };
00457
00458
00459
00460
00461
00462 struct matrix_permute {};
00463
00464 template<typename V>
00465 struct implementation<matrix_permute,V,matrix_naive>:
00466 public implementation<matrix_multiply,V>
00467 {
00468 typedef implementation<matrix_multiply,V> Mat;
00469 typedef implementation<vector_linear,V> Vec;
00470
00471 TMPL static void
00472 col_permute (C* dest, const C* m, const nat* p, nat r, nat c) {
00473
00474 nat rs= Mat::index (1, 0, r, c);
00475 for (nat i=0; i<c; i++)
00476 Vec::template vec_unary_stride<id_op>
00477 (dest + Mat::index (0, i, r, c), rs,
00478 m + Mat::index (0, p[i], r, c), rs, r);
00479 }
00480
00481 TMPL static void
00482 row_permute (C* dest, const C* m, const nat* p, nat r, nat c) {
00483
00484 nat cs= Mat::index (0, 1, r, c);
00485 for (nat i=0; i<r; i++)
00486 Vec::template vec_unary_stride<id_op>
00487 (dest + Mat::index (i, 0, r, c), cs,
00488 m + Mat::index (p[i], 0, r, c), cs, c);
00489 }
00490
00491 TMPL static void
00492 col_permute (C* m, const nat* p, nat r, nat c, bool inv= false) {
00493
00494
00495 nat q[c];
00496 if (inv) for (nat i=0; i<c; i++) q[i]= p[i];
00497 else for (nat i=0; i<c; i++) q[p[i]]= i;
00498 for (nat i=0; i<c; ) {
00499 if (q[i] == i) i++;
00500 else {
00501 nat j= q[i], k= q[j];
00502 Mat::col_swap (m, i, j, r, c);
00503 q[i]= k; q[j]= j;
00504 }
00505 }
00506 }
00507
00508 TMPL static void
00509 row_permute (C* m, const nat* p, nat r, nat c, bool inv= false) {
00510
00511
00512 nat q[r];
00513 if (inv) for (nat i=0; i<c; i++) q[i]= p[i];
00514 else for (nat i=0; i<c; i++) q[p[i]]= i;
00515 for (nat i=0; i<r; ) {
00516 if (q[i] == i) i++;
00517 else {
00518 nat j= q[i], k= q[j];
00519 Mat::row_swap (m, i, q[i], r, c);
00520 q[i]= k; q[j]= j;
00521 }
00522 }
00523 }
00524
00525 };
00526
00527
00528
00529
00530
/// Pivot preference policy: better (x1, x2) tells whether x1 should be
/// preferred over x2 as a pivot. The generic policy never prefers, so a
/// pivot search using it keeps its first nonzero candidate; presumably
/// specialized elsewhere for types where the choice matters.
template<typename C>
struct pivot_helper {
  static inline bool
  better (const C&, const C&) {
    return false;
  }
};

/// Convenience front end for pivot_helper<C>::better.
template<typename C> inline bool
better_pivot (const C& x1, const C& x2) {
  bool preferred= pivot_helper<C>::better (x1, x2);
  return preferred;
}
00541
00542
00543
00544
00545
00546 struct matrix_echelon {};
00547
00548 template<typename V>
00549 struct implementation<matrix_echelon,V,matrix_naive>:
00550 public implementation<matrix_multiply,V>
00551 {
00552 typedef implementation<matrix_multiply,V> Mat;
00553
00554 private:
00555
00556 TMPL static void
00557 col_echelon (C* m, C* k, nat ri, nat ci, nat rf, nat cf, nat rm, nat cm,
00558 C& num, C& den, bool reduced, nat* permut)
00559 {
00560
00561
00562
00563
00564
00565
00566
00567
00568
00569
00570 VERIFY ((ri < rm) && (ci < cm) && (rf <= rm) && (cf <= cm), "out of range");
00571 format<C> fm= get_format (m[0]);
00572 num= promote (1, fm); den= promote (1, fm);
00573 nat cb= ci;
00574 if (permut != NULL) for (nat i= 0; i < cm; i++) permut[i]= i;
00575 while (ri < rf && ci < cf) {
00576
00577 nat i= ri, j= ci;
00578 nat best_index;
00579 C best_val= promote (0, fm);
00580 for (i= ri; i<rf; i++) {
00581 best_index= cf;
00582 for (j= ci; j<cf; j++) {
00583 C next_val= Mat::entry (m, i, j, rm, cm);
00584 if (next_val != promote (0, fm))
00585 if (best_index == cf || better_pivot (next_val, best_val)) {
00586 best_index= j;
00587 best_val= next_val;
00588 }
00589 }
00590 if (best_index != cf) { j= best_index; break; }
00591 }
00592
00593 if (i<rf && j<cf) {
00594
00595 if (j != ci) {
00596 if (k != NULL) Mat::col_swap (k, ci, j, cm, cm);
00597 Mat::col_swap (m, ci, j, rm, cm);
00598 den= -den;
00599 if (permut != NULL) swap (permut[j], permut[ci]);
00600 }
00601 ri= i;
00602
00603 C p= Mat::entry (m, ri, ci, rm, cm);
00604 num *= p;
00605 for (nat index= reduced ? cb : ci + 1; index < cf; index++) {
00606 if (index == ci) continue;
00607 C t= Mat::entry (m, ri, index, rm, cm);
00608 if (k != NULL) Mat::col_sweep (k, index, ci, (C) (t/p), cm, cm);
00609 Mat::col_sweep (m, index, ci, (C) (t/p), rm, cm);
00610 }
00611 ri++; ci++;
00612 }
00613 else break;
00614 }
00615 }
00616
00617 public:
00618
00619 TMPL static inline void
00620 col_echelon (C* m, C* k, nat rm, nat cm, C& num, C& den,
00621 bool reduced= false, nat* permut= NULL) {
00622 if (k != NULL && cm != 0) Mat::set (k, promote (1, k[0]), cm, cm);
00623 col_echelon (m, k, 0, 0, rm, cm, rm, cm, num, den, reduced, permut);
00624 }
00625
00626 TMPL static inline void
00627 col_echelon (C* m, C* k, nat rm, nat cm,
00628 bool reduced= false, nat* permut= NULL) {
00629 if (k != NULL && cm != 0) Mat::set (k, promote (1, k[0]), cm, cm);
00630 C num, den;
00631 col_echelon (m, k, rm, cm, num, den, reduced, permut);
00632 }
00633
00634 };
00635
00636
00637
00638
00639
00640 struct matrix_determinant {};
00641
00642 template<typename V>
00643 struct implementation<matrix_determinant,V,matrix_naive>:
00644 public implementation<matrix_echelon,V>
00645 {
00646 typedef implementation<matrix_echelon,V> Mat;
00647
00648 TMPL static void
00649 det (C& r, const C* m, nat n) {
00650 if (n == 0) { set_as (r, 1); return; }
00651 nat len_c= aligned_size<C,V> (n * n);
00652 C* c= mmx_new<C> (len_c), * k= NULL;
00653 C num, den;
00654 Mat::copy (c, m, n, n);
00655 Mat::col_echelon (c, k, n, n, num, den);
00656 if (Mat::col_is_zero (c, n-1, n, n)) set_as (r, 0);
00657 else r= num / den;
00658 mmx_delete<C> (c, len_c);
00659 }
00660
00661 };
00662
00663
00664
00665
00666
00667 struct matrix_kernel {};
00668
00669 template<typename V>
00670 struct implementation<matrix_kernel,V,matrix_naive>:
00671 public implementation<matrix_echelon,V>
00672 {
00673 typedef implementation<matrix_echelon,V> Mat;
00674
00675 TMPL static nat
00676 kernel (C* k, const C* m, nat rm, nat cm) {
00677
00678 VERIFY (rm > 0 && cm > 0, "unexpected empty matrix");
00679 nat i, j, dim, len_c= aligned_size<C,V> (rm * cm);
00680 C* c= mmx_new<C> (len_c);
00681 Mat::copy (c, m, rm, cm);
00682 Mat::col_echelon (c, k, rm, cm);
00683 for (i=0; i<cm; i++)
00684 if (Mat::col_is_zero (c, i, rm, cm)) break;
00685 mmx_delete<C> (c, len_c);
00686 dim= cm - i;
00687 if (dim < cm)
00688 for (j=0; j < dim; j++)
00689 Mat::col_swap (k, j, i+j, cm, cm);
00690
00691 C x;
00692 for (j=0; j < dim; j++)
00693 for (i=0; i < cm; i++) {
00694 x= *(k + Mat::index (i, j, cm, cm));
00695 if (x != promote (0, x))
00696 Mat::col_div (k, x, j, cm, cm);
00697 }
00698 return dim;
00699 }
00700
00701 };
00702
00703
00704
00705
00706
00707 struct matrix_image {};
00708
00709 template<typename V>
00710 struct implementation<matrix_image,V,matrix_naive>:
00711 public implementation<matrix_echelon,V>
00712 {
00713 typedef implementation<matrix_echelon,V> Mat;
00714
00715 TMPL static nat
00716 image (C* k, const C* m, nat rm, nat cm) {
00717
00718 VERIFY (rm > 0 && cm > 0, "unexpected empty matrix");
00719 nat i, len_c= aligned_size<C,V> (rm * cm);
00720 C* c= mmx_new<C> (len_c);
00721 Mat::copy (c, m, rm, cm);
00722 Mat::col_echelon (c, (C*) NULL, rm, cm);
00723 for (i=0; i<cm; i++)
00724 if (Mat::col_is_zero (c, i, rm, cm)) break;
00725 Mat::get_range (k, c, 0, 0, rm, i, rm, cm);
00726
00727 mmx_delete<C> (c, len_c);
00728 return i;
00729 }
00730
00731 TMPL static nat
00732 rank (const C* m, nat rm, nat cm) {
00733 VERIFY (rm > 0 && cm > 0, "unexpected empty matrix");
00734 nat i, len_c= aligned_size<C,V> (rm * cm);
00735 C* c= mmx_new<C> (len_c);
00736 Mat::copy (c, m, rm, cm);
00737 Mat::col_echelon (c, (C*) NULL, rm, cm);
00738 for (i=0; i<cm; i++)
00739 if (Mat::col_is_zero (c, i, rm, cm)) break;
00740 mmx_delete<C> (c, len_c);
00741 return i;
00742 }
00743
00744 };
00745
00746
00747
00748
00749
00750 struct matrix_invert {};
00751
00752 template<typename V>
00753 struct implementation<matrix_invert,V,matrix_naive>:
00754 public implementation<matrix_echelon,V>
00755 {
00756 typedef implementation<matrix_echelon,V> Mat;
00757
00758 TMPL static void
00759 invert_lower_triangular (C* inv, const C* m, nat n) {
00760 for (nat i=0; i<n; i++) {
00761 C d= Mat::entry (m, i, i, n, n);
00762 ASSERT (d != promote (0, d), "non-invertible matrix");
00763 for (nat k=0; k<=i; k++) {
00764 C sum= promote (0, d);
00765 for (nat j=k; j<i; j++)
00766 sum -= Mat::entry (m, i, j, n, n) * Mat::entry (inv, j, k, n, n);
00767 if (i == k) sum += promote (1, d);
00768 inv[Mat::index (i, k, n, n)]= sum / d;
00769 }
00770 for (nat k=i+1; k<n; k++)
00771 inv[Mat::index (i, k, n, n)]= promote (0, d);
00772 }
00773 }
00774
00775 TMPL static void
00776 invert (C* k, const C* m, nat n) {
00777 nat l= aligned_size<C,V> (n * n);
00778 C* a= mmx_new<C> (l << 1);
00779 C* b= a + l;
00780 Mat::copy (k, m, n, n);
00781 Mat::col_echelon (k, b, n, n);
00782
00783
00784 invert_lower_triangular (a, k, n);
00785
00786 Mat::mul (k, b, a, n, n, n);
00787 mmx_delete<C> (a, l << 1);
00788 }
00789
00790 };
00791
00792
00793
00794
00795
00796 struct matrix_orthogonalization {};
00797
00798 template<typename V>
00799 struct implementation<matrix_orthogonalization,V,matrix_naive>:
00800 public implementation<matrix_multiply,V>
00801 {
00802 typedef implementation<matrix_linear,V> Mat;
00803 typedef implementation<vector_linear,V> Vec;
00804
00805 TMPL static inline C
00806 row_inn_prod (C* m, nat i, nat j, nat r, nat c) {
00807
00808 return Vec::template vec_binary_big_stride<mul_add_op> (
00809 m + Mat::index (i, 0, r, c), Mat::index (0, 1, r, c),
00810 m + Mat::index (j, 0, r, c), Mat::index (0, 1, r, c), c);
00811 }
00812
00813 TMPL static inline C
00814 col_inn_prod (C* m, nat i, nat j, nat r, nat c) {
00815
00816 return Vec::template vec_binary_big_stride<mul_add_op> (
00817 m + Mat::index (0, i, r, c), Mat::index (1, 0, r, c),
00818 m + Mat::index (0, j, r, c), Mat::index (1, 0, r, c), r);
00819 }
00820
00821 TMPL static void
00822 row_orthogonalize (C* m, nat r, nat c, C* n) {
00823
00824
00825
00826 for (nat i= 0; i < r; i++) {
00827 for (nat j= 0; j < i; j++) {
00828 if (n[j] == 0) continue;
00829 C mu= row_inn_prod (m, i, j, r, c) / n[j];
00830 Mat::row_combine_sub (m, i, C(1), j, mu, r, c);
00831 }
00832 n[i]= row_inn_prod (m, i, i, r, c);
00833 }
00834 }
00835
00836 TMPL static void
00837 col_orthogonalize (C* m, nat r, nat c, C* n) {
00838
00839 for (nat i= 0; i < c; i++) {
00840 for (nat j= 0; j < i; j++) {
00841 if (n[j] == 0) continue;
00842 C mu= col_inn_prod (m, i, j, r, c) / n[j];
00843 Mat::col_combine_sub (m, i, C(1), j, mu, r, c);
00844 }
00845 n[i]= col_inn_prod (m, i, i, r, c);
00846 }
00847 }
00848
00849 TMPL static void
00850 row_orthogonalize (C* m, nat r, nat c, C* l, C* n) {
00851
00852
00853
00854
00855 for (nat i= 0; i < r; i++) {
00856 for (nat j= 0; j < i; j++) {
00857 if (n[j] == 0) continue;
00858 C mu= row_inn_prod (m, i, j, r, c) / n[j];
00859 Mat::row_combine_sub (m, i, C(1), j, mu, r, c);
00860 *(l + Mat::index (i, j, r, r))= mu;
00861 }
00862 n[i]= row_inn_prod (m, i, i, r, c);
00863 *(l + Mat::index (i, i, r, r))= C(1);
00864 }
00865 }
00866
00867 TMPL static void
00868 col_orthogonalize (C* m, nat r, nat c, C* l, C* n) {
00869
00870 for (nat i= 0; i < c; i++) {
00871 for (nat j= 0; j < i; j++) {
00872 if (n[j] == 0) continue;
00873 C mu= col_inn_prod (m, i, j, r, c) / n[j];
00874 Mat::col_combine_sub (m, i, C(1), j, mu, r, c);
00875 *(l + Mat::index (j, i, c, c))= mu;
00876 }
00877 n[i]= col_inn_prod (m, i, i, r, c);
00878 *(l + Mat::index (i, i, c, c))= C(1);
00879 }
00880 }
00881
00882 TMPL static void
00883 row_orthonormalize (C* m, nat r, nat c) {
00884
00885 for (nat i= 0; i < r; i++) {
00886 for (nat j= 0; j < i; j++)
00887 Mat::row_combine_sub (m, i, C(1), j, row_inn_prod (m, i, j, r, c), r, c);
00888 C a= row_inn_prod (m, i, i, r, c);
00889 if (a != 0) Mat::row_div (m, sqrt (a), i, r, c);
00890 }
00891 }
00892
00893 TMPL static void
00894 col_orthonormalize (C* m, nat r, nat c) {
00895
00896 for (nat i= 0; i < c; i++) {
00897 for (nat j= 0; j < i; j++)
00898 Mat::col_combine_sub (m, i, C(1), j, col_inn_prod (m, i, j, r, c), r, c);
00899 C a= col_inn_prod (m, i, i, r, c);
00900 if (a != 0) Mat::col_div (m, sqrt (a), i, r, c);
00901 }
00902 }
00903
00904 TMPL static void
00905 row_orthonormalize (C* m, nat r, nat c, C* l) {
00906
00907
00908
00909 for (nat i= 0; i < r; i++) {
00910 for (nat j= 0; j < i; j++) {
00911 C mu= row_inn_prod (m, i, j, r, c);
00912 Mat::row_combine_sub (m, i, C(1), j, mu, r, c);
00913 *(l + Mat::index (i, j, r, r))= mu;
00914 }
00915 C a= row_inn_prod (m, i, i, r, c);
00916 if (a != 0) {
00917 C b= sqrt (a);
00918 Mat::row_div (m, b, i, r, c);
00919 *(l + Mat::index (i, i, r, r))= b;
00920 }
00921 }
00922 }
00923
00924 TMPL static void
00925 col_orthonormalize (C* m, nat r, nat c, C* l) {
00926
00927 for (nat i= 0; i < c; i++) {
00928 for (nat j= 0; j < i; j++) {
00929 C mu= col_inn_prod (m, i, j, r, c);
00930 Mat::col_combine_sub (m, i, C(1), j, mu, r, c);
00931 *(l + Mat::index (j, i, c, c))= mu;
00932 }
00933 C a= col_inn_prod (m, i, i, r, c);
00934 if (a != 0) {
00935 C b= sqrt (a);
00936 Mat::col_div (m, b, i, r, c);
00937 *(l + Mat::index (i, i, c, c))= b;
00938 }
00939 }
00940 }
00941
00942 };
00943
00944 #undef TMPL
00945 }
00946 #endif //__MMX__MATRIX_NAIVE__HPP