Pull out multiplication routines so that they can be specialized if need be

This commit is contained in:
Pavel Krajcevski 2014-03-21 01:16:45 -04:00
parent 31c799a02a
commit 63b8744917

View File

@ -87,45 +87,6 @@ namespace FasTC {
return result; return result;
} }
// Matrix multiplication
template<typename _T, const int nTarget>
MatrixBase<T, nRows, nTarget> MultiplyMatrix(const MatrixBase<_T, nCols, nTarget> &m) const {
MatrixBase<T, nRows, nTarget> result;
for(int r = 0; r < nRows; r++)
for(int c = 0; c < nTarget; c++) {
result(r, c) = 0;
for(int j = 0; j < nCols; j++) {
result(r, c) += (*this)(r, j) * m(j, c);
}
}
return result;
}
// Vector multiplication -- treat vectors as Nx1 matrices...
template<typename _T>
VectorBase<T, nCols> MultiplyVectorLeft(const VectorBase<_T, nRows> &v) const {
VectorBase<T, nCols> result;
for(int j = 0; j < nCols; j++) {
result(j) = 0;
for(int r = 0; r < nRows; r++) {
result(j) += (*this)(r, j) * v(r);
}
}
return result;
}
template<typename _T>
VectorBase<T, nRows> MultiplyVectorRight(const VectorBase<_T, nCols> &v) const {
VectorBase<T, nRows> result;
for(int r = 0; r < nRows; r++) {
result(r) = 0;
for(int j = 0; j < nCols; j++) {
result(r) += (*this)(r, j) * v(j);
}
}
return result;
}
// Transposition // Transposition
MatrixBase<T, nCols, nRows> Transpose() const { MatrixBase<T, nCols, nRows> Transpose() const {
MatrixBase<T, nCols, nRows> result; MatrixBase<T, nCols, nRows> result;
@ -148,6 +109,49 @@ namespace FasTC {
} }
}; };
// Matrix multiplication
template<typename T, typename _T, const int nRows, const int nCols, const int nTarget>
inline MatrixBase<T, nRows, nTarget>
MultiplyMatrix(const MatrixBase<T, nRows, nCols> &a,
const MatrixBase<_T, nCols, nTarget> &b) {
MatrixBase<T, nRows, nTarget> result;
for(int r = 0; r < nRows; r++)
for(int c = 0; c < nTarget; c++) {
result(r, c) = 0;
for(int j = 0; j < nCols; j++) {
result(r, c) += a(r, j) * b(j, c);
}
}
return result;
}
// Vector multiplication -- treat vectors as Nx1 matrices...
template<typename T, typename _T, int nRows, int nCols>
inline VectorBase<T, nCols> VectorMatrixMultiply(const VectorBase<T, nRows> &v,
const MatrixBase<_T, nRows, nCols> &m) {
VectorBase<T, nCols> r;
for(int j = 0; j < nCols; j++) {
r(j) = 0;
for(int i = 0; i < nRows; i++) {
r(j) += m(i, j) * v(i);
}
}
return r;
}
template<typename T, typename _T, int nRows, int nCols>
inline VectorBase<T, nRows> MatrixVectorMultiply(const MatrixBase<_T, nRows, nCols> &m,
const VectorBase<T, nCols> &v) {
VectorBase<T, nRows> r;
for(int j = 0; j < nRows; j++) {
r(j) = 0;
for(int i = 0; i < nCols; i++) {
r(j) += m(j, i) * v(i);
}
}
return r;
}
template<typename T, const int N, const int M> template<typename T, const int N, const int M>
class VectorTraits<MatrixBase<T, N, M> > { class VectorTraits<MatrixBase<T, N, M> > {
public: public:
@ -191,7 +195,7 @@ namespace FasTC {
MultSwitch(const TypeOne &a, const TypeTwo &b) MultSwitch(const TypeOne &a, const TypeTwo &b)
: m_A(a), m_B(b) { } : m_A(a), m_B(b) { }
ResultType GetMultiplication() const { return m_A.MultiplyVectorRight(m_B); } ResultType GetMultiplication() const { return MatrixVectorMultiply(m_A, m_B); }
}; };
template<typename TypeOne, typename TypeTwo> template<typename TypeOne, typename TypeTwo>
@ -209,7 +213,7 @@ namespace FasTC {
MultSwitch(const TypeOne &a, const TypeTwo &b) MultSwitch(const TypeOne &a, const TypeTwo &b)
: m_A(a), m_B(b) { } : m_A(a), m_B(b) { }
ResultType GetMultiplication() const { return m_B.MultiplyVectorLeft(m_A); } ResultType GetMultiplication() const { return VectorMatrixMultiply(m_A, m_B); }
}; };
template<typename TypeOne, typename TypeTwo> template<typename TypeOne, typename TypeTwo>
@ -227,7 +231,7 @@ namespace FasTC {
MultSwitch(const TypeOne &a, const TypeTwo &b) MultSwitch(const TypeOne &a, const TypeTwo &b)
: m_A(a), m_B(b) { } : m_A(a), m_B(b) { }
ResultType GetMultiplication() const { return m_A.MultiplyMatrix(m_B); } ResultType GetMultiplication() const { return MultiplyMatrix(m_A, m_B); }
}; };
// Outer product... // Outer product...