mirror of
https://github.com/yuzu-emu/FasTC.git
synced 2024-11-28 00:24:18 +01:00
Pull out multiplication routines so that they can be specialized if need be
This commit is contained in:
parent
31c799a02a
commit
63b8744917
@ -87,45 +87,6 @@ namespace FasTC {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Matrix multiplication
|
|
||||||
template<typename _T, const int nTarget>
|
|
||||||
MatrixBase<T, nRows, nTarget> MultiplyMatrix(const MatrixBase<_T, nCols, nTarget> &m) const {
|
|
||||||
MatrixBase<T, nRows, nTarget> result;
|
|
||||||
for(int r = 0; r < nRows; r++)
|
|
||||||
for(int c = 0; c < nTarget; c++) {
|
|
||||||
result(r, c) = 0;
|
|
||||||
for(int j = 0; j < nCols; j++) {
|
|
||||||
result(r, c) += (*this)(r, j) * m(j, c);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Vector multiplication -- treat vectors as Nx1 matrices...
|
|
||||||
template<typename _T>
|
|
||||||
VectorBase<T, nCols> MultiplyVectorLeft(const VectorBase<_T, nRows> &v) const {
|
|
||||||
VectorBase<T, nCols> result;
|
|
||||||
for(int j = 0; j < nCols; j++) {
|
|
||||||
result(j) = 0;
|
|
||||||
for(int r = 0; r < nRows; r++) {
|
|
||||||
result(j) += (*this)(r, j) * v(r);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename _T>
|
|
||||||
VectorBase<T, nRows> MultiplyVectorRight(const VectorBase<_T, nCols> &v) const {
|
|
||||||
VectorBase<T, nRows> result;
|
|
||||||
for(int r = 0; r < nRows; r++) {
|
|
||||||
result(r) = 0;
|
|
||||||
for(int j = 0; j < nCols; j++) {
|
|
||||||
result(r) += (*this)(r, j) * v(j);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Transposition
|
// Transposition
|
||||||
MatrixBase<T, nCols, nRows> Transpose() const {
|
MatrixBase<T, nCols, nRows> Transpose() const {
|
||||||
MatrixBase<T, nCols, nRows> result;
|
MatrixBase<T, nCols, nRows> result;
|
||||||
@ -148,6 +109,49 @@ namespace FasTC {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Matrix multiplication
|
||||||
|
template<typename T, typename _T, const int nRows, const int nCols, const int nTarget>
|
||||||
|
inline MatrixBase<T, nRows, nTarget>
|
||||||
|
MultiplyMatrix(const MatrixBase<T, nRows, nCols> &a,
|
||||||
|
const MatrixBase<_T, nCols, nTarget> &b) {
|
||||||
|
MatrixBase<T, nRows, nTarget> result;
|
||||||
|
for(int r = 0; r < nRows; r++)
|
||||||
|
for(int c = 0; c < nTarget; c++) {
|
||||||
|
result(r, c) = 0;
|
||||||
|
for(int j = 0; j < nCols; j++) {
|
||||||
|
result(r, c) += a(r, j) * b(j, c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Vector multiplication -- treat vectors as Nx1 matrices...
|
||||||
|
template<typename T, typename _T, int nRows, int nCols>
|
||||||
|
inline VectorBase<T, nCols> VectorMatrixMultiply(const VectorBase<T, nRows> &v,
|
||||||
|
const MatrixBase<_T, nRows, nCols> &m) {
|
||||||
|
VectorBase<T, nCols> r;
|
||||||
|
for(int j = 0; j < nCols; j++) {
|
||||||
|
r(j) = 0;
|
||||||
|
for(int i = 0; i < nRows; i++) {
|
||||||
|
r(j) += m(i, j) * v(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T, typename _T, int nRows, int nCols>
|
||||||
|
inline VectorBase<T, nRows> MatrixVectorMultiply(const MatrixBase<_T, nRows, nCols> &m,
|
||||||
|
const VectorBase<T, nCols> &v) {
|
||||||
|
VectorBase<T, nRows> r;
|
||||||
|
for(int j = 0; j < nRows; j++) {
|
||||||
|
r(j) = 0;
|
||||||
|
for(int i = 0; i < nCols; i++) {
|
||||||
|
r(j) += m(j, i) * v(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
template<typename T, const int N, const int M>
|
template<typename T, const int N, const int M>
|
||||||
class VectorTraits<MatrixBase<T, N, M> > {
|
class VectorTraits<MatrixBase<T, N, M> > {
|
||||||
public:
|
public:
|
||||||
@ -191,7 +195,7 @@ namespace FasTC {
|
|||||||
MultSwitch(const TypeOne &a, const TypeTwo &b)
|
MultSwitch(const TypeOne &a, const TypeTwo &b)
|
||||||
: m_A(a), m_B(b) { }
|
: m_A(a), m_B(b) { }
|
||||||
|
|
||||||
ResultType GetMultiplication() const { return m_A.MultiplyVectorRight(m_B); }
|
ResultType GetMultiplication() const { return MatrixVectorMultiply(m_A, m_B); }
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename TypeOne, typename TypeTwo>
|
template<typename TypeOne, typename TypeTwo>
|
||||||
@ -209,7 +213,7 @@ namespace FasTC {
|
|||||||
MultSwitch(const TypeOne &a, const TypeTwo &b)
|
MultSwitch(const TypeOne &a, const TypeTwo &b)
|
||||||
: m_A(a), m_B(b) { }
|
: m_A(a), m_B(b) { }
|
||||||
|
|
||||||
ResultType GetMultiplication() const { return m_B.MultiplyVectorLeft(m_A); }
|
ResultType GetMultiplication() const { return VectorMatrixMultiply(m_A, m_B); }
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename TypeOne, typename TypeTwo>
|
template<typename TypeOne, typename TypeTwo>
|
||||||
@ -227,7 +231,7 @@ namespace FasTC {
|
|||||||
MultSwitch(const TypeOne &a, const TypeTwo &b)
|
MultSwitch(const TypeOne &a, const TypeTwo &b)
|
||||||
: m_A(a), m_B(b) { }
|
: m_A(a), m_B(b) { }
|
||||||
|
|
||||||
ResultType GetMultiplication() const { return m_A.MultiplyMatrix(m_B); }
|
ResultType GetMultiplication() const { return MultiplyMatrix(m_A, m_B); }
|
||||||
};
|
};
|
||||||
|
|
||||||
// Outer product...
|
// Outer product...
|
||||||
|
Loading…
Reference in New Issue
Block a user