From 63b8744917d187bde75688d577319077a23bc561 Mon Sep 17 00:00:00 2001 From: Pavel Krajcevski Date: Fri, 21 Mar 2014 01:16:45 -0400 Subject: [PATCH] Pull out multiplication routines so that they can be specialized if need be --- Base/include/MatrixBase.h | 88 ++++++++++++++++++++------------------- 1 file changed, 46 insertions(+), 42 deletions(-) diff --git a/Base/include/MatrixBase.h b/Base/include/MatrixBase.h index 3cbbdf7..80facc4 100644 --- a/Base/include/MatrixBase.h +++ b/Base/include/MatrixBase.h @@ -87,45 +87,6 @@ namespace FasTC { return result; } - // Matrix multiplication - template - MatrixBase MultiplyMatrix(const MatrixBase<_T, nCols, nTarget> &m) const { - MatrixBase result; - for(int r = 0; r < nRows; r++) - for(int c = 0; c < nTarget; c++) { - result(r, c) = 0; - for(int j = 0; j < nCols; j++) { - result(r, c) += (*this)(r, j) * m(j, c); - } - } - return result; - } - - // Vector multiplication -- treat vectors as Nx1 matrices... - template - VectorBase MultiplyVectorLeft(const VectorBase<_T, nRows> &v) const { - VectorBase result; - for(int j = 0; j < nCols; j++) { - result(j) = 0; - for(int r = 0; r < nRows; r++) { - result(j) += (*this)(r, j) * v(r); - } - } - return result; - } - - template - VectorBase MultiplyVectorRight(const VectorBase<_T, nCols> &v) const { - VectorBase result; - for(int r = 0; r < nRows; r++) { - result(r) = 0; - for(int j = 0; j < nCols; j++) { - result(r) += (*this)(r, j) * v(j); - } - } - return result; - } - // Transposition MatrixBase Transpose() const { MatrixBase result; @@ -148,6 +109,49 @@ namespace FasTC { } }; + // Matrix multiplication + template + inline MatrixBase + MultiplyMatrix(const MatrixBase &a, + const MatrixBase<_T, nCols, nTarget> &b) { + MatrixBase result; + for(int r = 0; r < nRows; r++) + for(int c = 0; c < nTarget; c++) { + result(r, c) = 0; + for(int j = 0; j < nCols; j++) { + result(r, c) += a(r, j) * b(j, c); + } + } + return result; + } + + // Vector multiplication -- treat vectors as Nx1 matrices... + template + inline VectorBase VectorMatrixMultiply(const VectorBase &v, + const MatrixBase<_T, nRows, nCols> &m) { + VectorBase r; + for(int j = 0; j < nCols; j++) { + r(j) = 0; + for(int i = 0; i < nRows; i++) { + r(j) += m(i, j) * v(i); + } + } + return r; + } + + template + inline VectorBase MatrixVectorMultiply(const MatrixBase<_T, nRows, nCols> &m, + const VectorBase &v) { + VectorBase r; + for(int j = 0; j < nRows; j++) { + r(j) = 0; + for(int i = 0; i < nCols; i++) { + r(j) += m(j, i) * v(i); + } + } + return r; + } + template class VectorTraits > { public: @@ -191,7 +195,7 @@ namespace FasTC { MultSwitch(const TypeOne &a, const TypeTwo &b) : m_A(a), m_B(b) { } - ResultType GetMultiplication() const { return m_A.MultiplyVectorRight(m_B); } + ResultType GetMultiplication() const { return MatrixVectorMultiply(m_A, m_B); } }; template @@ -209,7 +213,7 @@ namespace FasTC { MultSwitch(const TypeOne &a, const TypeTwo &b) : m_A(a), m_B(b) { } - ResultType GetMultiplication() const { return m_B.MultiplyVectorLeft(m_A); } + ResultType GetMultiplication() const { return VectorMatrixMultiply(m_A, m_B); } }; template @@ -227,7 +231,7 @@ namespace FasTC { MultSwitch(const TypeOne &a, const TypeTwo &b) : m_A(a), m_B(b) { } - ResultType GetMultiplication() const { return m_A.MultiplyMatrix(m_B); } + ResultType GetMultiplication() const { return MultiplyMatrix(m_A, m_B); } }; // Outer product...