From 24803525b6bc0104837bab7aabc20b663319137d Mon Sep 17 00:00:00 2001 From: Max Kahnt <max.kahnt@fu-berlin.de> Date: Wed, 27 Jul 2016 12:38:11 +0200 Subject: [PATCH] Allow different types for transformations T1 and T2. --- dune/matrix-vector/transformmatrix.hh | 190 ++++++++++++++------------ 1 file changed, 105 insertions(+), 85 deletions(-) diff --git a/dune/matrix-vector/transformmatrix.hh b/dune/matrix-vector/transformmatrix.hh index 328f146..e2679ae 100644 --- a/dune/matrix-vector/transformmatrix.hh +++ b/dune/matrix-vector/transformmatrix.hh @@ -14,12 +14,14 @@ namespace MatrixVector { // add transformed matrix A += T1^t*B*T2 // ****************************************************** - template <class MatrixA, class MatrixB, class TransformationMatrix, - bool AisScalar, bool BisScalar, bool TisScalar> + template <class MatrixA, class MatrixB, class TransformationMatrix1, + class TransformationMatrix2, bool AisScalar, bool BisScalar, + bool T1isScalar, bool T2isScalar> struct TransformMatrixHelper { - static void addTransformedMatrix(MatrixA& A, const TransformationMatrix& T1, + static void addTransformedMatrix(MatrixA& A, + const TransformationMatrix1& T1, const MatrixB& B, - const TransformationMatrix& T2) { + const TransformationMatrix2& T2) { for (size_t i = 0; i < A.N(); ++i) { for (auto jIt = A[i].begin(); jIt != A[i].end(); ++jIt) { for (size_t k = 0; k < B.N(); ++k) { @@ -72,7 +74,8 @@ namespace MatrixVector { template <class K1, class K2, class K3, int n, int m> struct TransformMatrixHelper< Dune::FieldMatrix<K1, m, m>, Dune::FieldMatrix<K3, n, n>, - Dune::FieldMatrix<K2, n, m>, false, false, false> { + Dune::FieldMatrix<K2, n, m>, Dune::FieldMatrix<K2, n, m>, false, false, + false, false> { typedef Dune::FieldMatrix<K1, m, m> MatrixA; typedef Dune::FieldMatrix<K3, n, n> MatrixB; typedef Dune::FieldMatrix<K2, n, m> TransformationMatrix; @@ -111,25 +114,30 @@ namespace MatrixVector { // } }; - template <class MatrixA, class MatrixB, class ScalarTransform, bool AisScalar, - bool BisScalar> - struct TransformMatrixHelper<MatrixA, MatrixB, ScalarTransform, AisScalar, - BisScalar, true> { - static void addTransformedMatrix(MatrixA& A, const ScalarTransform& T1, + template <class MatrixA, class MatrixB, class ScalarTransform1, + class ScalarTransform2, bool AisScalar, bool BisScalar> + struct TransformMatrixHelper<MatrixA, MatrixB, ScalarTransform1, + ScalarTransform2, AisScalar, BisScalar, true, + true> { + static void addTransformedMatrix(MatrixA& A, const ScalarTransform1& T1, const MatrixB& B, - const ScalarTransform& T2) { + const ScalarTransform2& T2) { addProduct(A, T1 * T2, B); } }; - template <class MatrixA, class ScalarB, class TransformationMatrix, - bool AisScalar, bool TisScalar> - struct TransformMatrixHelper<MatrixA, ScalarB, TransformationMatrix, - AisScalar, true, TisScalar> { - static void addTransformedMatrix(MatrixA& A, const TransformationMatrix& T1, + template <class MatrixA, class ScalarB, class TransformationMatrix1, + class TransformationMatrix2, bool AisScalar, bool T1isScalar, + bool T2isScalar> + struct TransformMatrixHelper<MatrixA, ScalarB, TransformationMatrix1, + TransformationMatrix2, AisScalar, true, + T1isScalar, T2isScalar> { + static void addTransformedMatrix(MatrixA& A, + const TransformationMatrix1& T1, const ScalarB& B, - const TransformationMatrix& T2) { - for (size_t k = 0; k < TransformationMatrix::rows; ++k) { + const TransformationMatrix2& T2) { + assert(TransformationMatrix1::rows == TransformationMatrix2::rows); + for (size_t k = 0; k < TransformationMatrix1::rows; ++k) { for (auto Skj = T2[k].begin(); Skj != T2[k].end(); ++Skj) { for (auto Tki = T1[k].begin(); Tki != T1[k].end(); Tki++) if (A.exists(Tki.index(), Skj.index())) @@ -139,41 +147,44 @@ namespace MatrixVector { } }; - template <class FieldType, int n, class ScalarB, class TransformationMatrix, - bool AisScalar, bool TisScalar> - struct TransformMatrixHelper<Dune::ScaledIdentityMatrix<FieldType, n>, - ScalarB, TransformationMatrix, AisScalar, true, - TisScalar> { + template <class FieldType, int n, class ScalarB, class TransformationMatrix1, + class TransformationMatrix2, bool AisScalar, bool T1isScalar, + bool T2isScalar> + struct TransformMatrixHelper< + Dune::ScaledIdentityMatrix<FieldType, n>, ScalarB, TransformationMatrix1, + TransformationMatrix2, AisScalar, true, T1isScalar, T2isScalar> { typedef Dune::ScaledIdentityMatrix<FieldType, n> MatrixA; - static void addTransformedMatrix(MatrixA& A, const TransformationMatrix& T1, + static void addTransformedMatrix(MatrixA& A, + const TransformationMatrix1& T1, const ScalarB& B, - const TransformationMatrix& T2) { - TransformMatrixHelper<FieldType, ScalarB, - typename TransformationMatrix::field_type, true, - true, true>::addTransformedMatrix(A.scalar(), - T1.scalar(), B, - T2.scalar()); + const TransformationMatrix2& T2) { + TransformMatrixHelper< + FieldType, ScalarB, typename TransformationMatrix1::field_type, + typename TransformationMatrix2::field_type, true, true, true, + true>::addTransformedMatrix(A.scalar(), T1.scalar(), B, T2.scalar()); } }; - template <class ScalarA, class ScalarB, class ScalarTransform> - struct TransformMatrixHelper<ScalarA, ScalarB, ScalarTransform, true, true, - true> { - static void addTransformedMatrix(ScalarA& A, const ScalarTransform& T1, + template <class ScalarA, class ScalarB, class ScalarTransform1, + class ScalarTransform2> + struct TransformMatrixHelper<ScalarA, ScalarB, ScalarTransform1, + ScalarTransform2, true, true, true, true> { + static void addTransformedMatrix(ScalarA& A, const ScalarTransform1& T1, const ScalarB& B, - const ScalarTransform& T2) { + const ScalarTransform2& T2) { A += T1 * B * T2; } }; template <class MatrixA, typename FieldType, int n, - class TransformationMatrix> + class TransformationMatrix1, class TransformationMatrix2> struct TransformMatrixHelper<MatrixA, Dune::DiagonalMatrix<FieldType, n>, - TransformationMatrix, false, false, false> { + TransformationMatrix1, TransformationMatrix2, + false, false, false, false> { static void addTransformedMatrix( - MatrixA& A, const TransformationMatrix& T1, + MatrixA& A, const TransformationMatrix1& T1, const Dune::DiagonalMatrix<FieldType, n>& B, - const TransformationMatrix& T2) { + const TransformationMatrix2& T2) { for (size_t k = 0; k < n; ++k) { for (auto Skj = T2[k].begin(); Skj != T2[k].end(); ++Skj) { for (auto Tki = T1[k].begin(); Tki != T1[k].end(); Tki++) @@ -186,28 +197,30 @@ namespace MatrixVector { }; template <class MatrixA, typename FieldType, int n, - class TransformationMatrix> - struct TransformMatrixHelper<MatrixA, - Dune::ScaledIdentityMatrix<FieldType, n>, - TransformationMatrix, false, false, false> { + class TransformationMatrix1, class TransformationMatrix2> + struct TransformMatrixHelper< + MatrixA, Dune::ScaledIdentityMatrix<FieldType, n>, TransformationMatrix1, + TransformationMatrix2, false, false, false, false> { static void addTransformedMatrix( - MatrixA& A, const TransformationMatrix& T1, + MatrixA& A, const TransformationMatrix1& T1, const Dune::ScaledIdentityMatrix<FieldType, n>& B, - const TransformationMatrix& T2) { - TransformMatrixHelper<MatrixA, FieldType, TransformationMatrix, false, - true, false>::addTransformedMatrix(A, T1, - B.scalar(), T2); + const TransformationMatrix2& T2) { + TransformMatrixHelper<MatrixA, FieldType, TransformationMatrix1, + TransformationMatrix2, false, true, false, + false>::addTransformedMatrix(A, T1, B.scalar(), T2); } }; - template <typename FieldType, int n, class TransformationMatrix> + template <typename FieldType, int n, class TransformationMatrix1, + class TransformationMatrix2> struct TransformMatrixHelper<Dune::DiagonalMatrix<FieldType, n>, Dune::DiagonalMatrix<FieldType, n>, - TransformationMatrix, false, false, false> { + TransformationMatrix1, TransformationMatrix2, + false, false, false, false> { static void addTransformedMatrix( - Dune::DiagonalMatrix<FieldType, n>& A, const TransformationMatrix& T1, + Dune::DiagonalMatrix<FieldType, n>& A, const TransformationMatrix1& T1, const Dune::DiagonalMatrix<FieldType, n>& B, - const TransformationMatrix& T2) { + const TransformationMatrix2& T2) { for (size_t k = 0; k < n; k++) { for (auto Tki = T1[k].begin(); Tki != T1[k].end(); ++Tki) A.diagonal(Tki.index()) += @@ -216,15 +229,17 @@ namespace MatrixVector { } }; - template <typename FieldType, int n, class TransformationMatrix> + template <typename FieldType, int n, class TransformationMatrix1, + class TransformationMatrix2> struct TransformMatrixHelper<Dune::ScaledIdentityMatrix<FieldType, n>, Dune::ScaledIdentityMatrix<FieldType, n>, - TransformationMatrix, false, false, false> { + TransformationMatrix1, TransformationMatrix2, + false, false, false, false> { static void addTransformedMatrix( Dune::ScaledIdentityMatrix<FieldType, n>& A, - const TransformationMatrix& T1, + const TransformationMatrix1& T1, const Dune::ScaledIdentityMatrix<FieldType, n>& B, - const TransformationMatrix& T2) { + const TransformationMatrix2& T2) { for (size_t k = 0; k < n; k++) A.scalar() += T1[k][0] * B.scalar() * T2[k][0]; } @@ -232,50 +247,54 @@ namespace MatrixVector { template <typename FieldType, int n> struct TransformMatrixHelper<Dune::ScaledIdentityMatrix<FieldType, n>, + Dune::ScaledIdentityMatrix<FieldType, n>, Dune::ScaledIdentityMatrix<FieldType, n>, Dune::ScaledIdentityMatrix<FieldType, n>, false, - false, false> { + false, false, false> { static void addTransformedMatrix( Dune::ScaledIdentityMatrix<FieldType, n>& A, const Dune::ScaledIdentityMatrix<FieldType, n>& T1, const Dune::ScaledIdentityMatrix<FieldType, n>& B, const Dune::ScaledIdentityMatrix<FieldType, n>& T2) { - TransformMatrixHelper<FieldType, FieldType, FieldType, true, true, - true>::addTransformedMatrix(A.scalar(), T1.scalar(), - B.scalar(), - T2.scalar()); + TransformMatrixHelper< + FieldType, FieldType, FieldType, FieldType, true, true, true, + true>::addTransformedMatrix(A.scalar(), T1.scalar(), B.scalar(), + T2.scalar()); } }; - template <class MatrixA, class MatrixB, class TransformationMatrix> - void addTransformedMatrix(MatrixA& A, const TransformationMatrix& T1, - const MatrixB& B, const TransformationMatrix& T2) { + template <class MatrixA, class MatrixB, class TransformationMatrix1, + class TransformationMatrix2> + void addTransformedMatrix(MatrixA& A, const TransformationMatrix1& T1, + const MatrixB& B, const TransformationMatrix2& T2) { TransformMatrixHelper< - MatrixA, MatrixB, TransformationMatrix, ScalarTraits<MatrixA>::isScalar, - ScalarTraits<MatrixB>::isScalar, - ScalarTraits<TransformationMatrix>::isScalar>::addTransformedMatrix(A, - T1, - B, - T2); + MatrixA, MatrixB, TransformationMatrix1, TransformationMatrix2, + ScalarTraits<MatrixA>::isScalar, ScalarTraits<MatrixB>::isScalar, + ScalarTraits<TransformationMatrix1>::isScalar, + ScalarTraits< + TransformationMatrix2>::isScalar>::addTransformedMatrix(A, T1, B, + T2); } - template <class MatrixA, class MatrixB, class TransformationMatrix> - void transformMatrix(MatrixA& A, const TransformationMatrix& T1, - const MatrixB& B, const TransformationMatrix& T2) { + template <class MatrixA, class MatrixB, class TransformationMatrix1, + class TransformationMatrix2> + void transformMatrix(MatrixA& A, const TransformationMatrix1& T1, + const MatrixB& B, const TransformationMatrix2& T2) { A = 0; TransformMatrixHelper< - MatrixA, MatrixB, TransformationMatrix, ScalarTraits<MatrixA>::isScalar, - ScalarTraits<MatrixB>::isScalar, - ScalarTraits<TransformationMatrix>::isScalar>::addTransformedMatrix(A, - T1, - B, - T2); + MatrixA, MatrixB, TransformationMatrix1, TransformationMatrix2, + ScalarTraits<MatrixA>::isScalar, ScalarTraits<MatrixB>::isScalar, + ScalarTraits<TransformationMatrix1>::isScalar, + ScalarTraits< + TransformationMatrix2>::isScalar>::addTransformedMatrix(A, T1, B, + T2); } - template <class MatrixBlockA, class MatrixB, class TransformationMatrix> + template <class MatrixBlockA, class MatrixB, class TransformationMatrix1, + class TransformationMatrix2> static void transformMatrix(Dune::BCRSMatrix<MatrixBlockA>& A, - const TransformationMatrix& T1, const MatrixB& B, - const TransformationMatrix& T2) { + const TransformationMatrix1& T1, const MatrixB& B, + const TransformationMatrix2& T2) { transformMatrixPattern(A, T1, B, T2); A = 0.0; for (size_t k = 0; k < B.N(); ++k) { @@ -294,11 +313,12 @@ namespace MatrixVector { } } - template <class MatrixBlockA, class MatrixB, class TransformationMatrix> + template <class MatrixBlockA, class MatrixB, class TransformationMatrix1, + class TransformationMatrix2> static void transformMatrixPattern(Dune::BCRSMatrix<MatrixBlockA>& A, - const TransformationMatrix& T1, + const TransformationMatrix1& T1, const MatrixB& B, - const TransformationMatrix& T2) { + const TransformationMatrix2& T2) { Dune::MatrixIndexSet indices(T1.M(), T2.M()); for (size_t k = 0; k < B.N(); ++k) { for (auto BklIt = B[k].begin(); BklIt != B[k].end(); ++BklIt) { -- GitLab