00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef __MMX__MATRIX_FIXED__HPP
00014 #define __MMX__MATRIX_FIXED__HPP
00015 #include <algebramix/matrix_naive.hpp>
00016
00017 namespace mmx {
00018
00019
00020
00021
00022
00023 template<typename Op, nat tp> struct div_type {};
00024
00025 template<typename Op, typename D, typename S1, typename S2,
00026 nat r, nat l, nat c>
00027 struct matrix_multiply_helper {
00028 static const nat tp= (r*c==1 || (r<l && c<l && l>8))? 3: (r<=c? 2: 1);
00029
00030 template<typename rr, typename cc> static inline void
00031 mul (D* d, const S1* m1, const S2* m2) {
00032 matrix_multiply_helper <div_type<Op,tp>, D, S1, S2, r, l, c>::
00033 template mul<rr,cc> (d, m1, m2);
00034 }
00035
00036 static inline void
00037 mul_stride (D* d, const S1* m1, const S2* m2, nat rr, nat ll) {
00038 matrix_multiply_helper <div_type<Op,tp>, D, S1, S2, r, l, c>::
00039 mul_stride (d, m1, m2, rr, ll);
00040 }
00041 };
00042
00043 template<typename Op, typename D, typename S1, typename S2,
00044 nat r, nat l, nat c>
00045 struct matrix_multiply_helper<div_type<Op,1>, D, S1, S2, r, l, c> {
00046
00047 static const nat r1= (r>>1), r2= r-r1;
00048
00049 template<nat rr, nat ll> static inline void
00050 mul (D* d, const S1* m1, const S2* m2) {
00051 {
00052 matrix_multiply_helper <Op, D, S1, S2, r1, l, c>::
00053 template mul<rr,ll> (d, m1, m2);
00054 matrix_multiply_helper <Op, D, S1, S2, r2, l, c>::
00055 template mul<rr,ll> (d + r1, m1 + r1, m2);
00056 }
00057 }
00058
00059 static inline void
00060 mul_stride (D* d, const S1* m1, const S2* m2, nat rr, nat ll) {
00061 {
00062 matrix_multiply_helper <Op, D, S1, S2, r1, l, c>::
00063 mul_stride (d, m1, m2, rr, ll);
00064 matrix_multiply_helper <Op, D, S1, S2, r2, l, c>::
00065 mul_stride (d + r1, m1 + r1, m2, rr, ll);
00066 }
00067 }
00068 };
00069
00070 template<typename Op, typename D, typename S1, typename S2,
00071 nat r, nat l, nat c>
00072 struct matrix_multiply_helper<div_type<Op,2>, D, S1, S2, r, l, c> {
00073
00074 static const nat c1= (c>>1), c2= c-c1;
00075
00076 template<nat rr, nat ll> static inline void
00077 mul (D* d, const S1* m1, const S2* m2) {
00078 {
00079 matrix_multiply_helper <Op, D, S1, S2, r, l, c1>::
00080 template mul<rr,ll> (d, m1, m2);
00081 matrix_multiply_helper <Op, D, S1, S2, r, l, c2>::
00082 template mul<rr,ll> (d + c1 * rr, m1, m2 + c1 * ll);
00083 }
00084 }
00085
00086 static inline void
00087 mul_stride (D* d, const S1* m1, const S2* m2, nat rr, nat ll) {
00088 {
00089 matrix_multiply_helper <Op, D, S1, S2, r, l, c1>::
00090 mul_stride (d, m1, m2, rr, ll);
00091 matrix_multiply_helper <Op, D, S1, S2, r, l, c2>::
00092 mul_stride (d + c1 * rr, m1, m2 + c1 * ll, rr, ll);
00093 }
00094 }
00095 };
00096
00097 template<typename Op, typename D, typename S1, typename S2,
00098 nat r, nat l, nat c>
00099 struct matrix_multiply_helper<div_type<Op,3>, D, S1, S2, r, l, c> {
00100
00101 static const nat l1= (l>>1), l2= l-l1;
00102 typedef typename Op::acc_op Acc;
00103
00104 template<nat rr, nat ll> static inline void
00105 mul (D* d, const S1* m1, const S2* m2) {
00106 {
00107 matrix_multiply_helper <Op , D, S1, S2, r, l1, c>::
00108 template mul<rr,ll> (d, m1, m2);
00109 matrix_multiply_helper <Acc, D, S1, S2, r, l2, c>::
00110 template mul<rr,ll> (d, m1 + l1 * rr, m2 + l1);
00111 }
00112 }
00113
00114 static inline void
00115 mul_stride (D* d, const S1* m1, const S2* m2, nat rr, nat ll) {
00116 {
00117 matrix_multiply_helper <Op , D, S1, S2, r, l1, c>::
00118 mul_stride (d, m1, m2, rr, ll);
00119 matrix_multiply_helper <Acc, D, S1, S2, r, l2, c>::
00120 mul_stride (d, m1 + l1 * rr, m2 + l1, rr, ll);
00121 }
00122 }
00123 };
00124
00125 template<typename Op, typename D, typename S1, typename S2>
00126 struct matrix_multiply_helper<Op, D, S1, S2, 1, 1, 1> {
00127 template<nat rr, nat ll> static inline void
00128 mul (D* d, const S1* m1, const S2* m2) {
00129 Op::set_op (*d, *m1, *m2);
00130 }
00131
00132 static inline void
00133 mul_stride (D* d, const S1* m1, const S2* m2, nat rr, nat ll) {
00134 (void) rr; (void) ll;
00135 Op::set_op (*d, *m1, *m2);
00136 }
00137 };
00138
00139 }
00140 #endif //__MMX__MATRIX_FIXED__HPP