From 24803525b6bc0104837bab7aabc20b663319137d Mon Sep 17 00:00:00 2001
From: Max Kahnt <max.kahnt@fu-berlin.de>
Date: Wed, 27 Jul 2016 12:38:11 +0200
Subject: [PATCH] Allow different types for transformations T1 and T2.

---
 dune/matrix-vector/transformmatrix.hh | 190 ++++++++++++++------------
 1 file changed, 105 insertions(+), 85 deletions(-)

diff --git a/dune/matrix-vector/transformmatrix.hh b/dune/matrix-vector/transformmatrix.hh
index 328f146..e2679ae 100644
--- a/dune/matrix-vector/transformmatrix.hh
+++ b/dune/matrix-vector/transformmatrix.hh
@@ -14,12 +14,14 @@ namespace MatrixVector {
 
   // add transformed matrix A += T1^t*B*T2
   // ******************************************************
-  template <class MatrixA, class MatrixB, class TransformationMatrix,
-            bool AisScalar, bool BisScalar, bool TisScalar>
+  template <class MatrixA, class MatrixB, class TransformationMatrix1,
+            class TransformationMatrix2, bool AisScalar, bool BisScalar,
+            bool T1isScalar, bool T2isScalar>
   struct TransformMatrixHelper {
-    static void addTransformedMatrix(MatrixA& A, const TransformationMatrix& T1,
+    static void addTransformedMatrix(MatrixA& A,
+                                     const TransformationMatrix1& T1,
                                      const MatrixB& B,
-                                     const TransformationMatrix& T2) {
+                                     const TransformationMatrix2& T2) {
       for (size_t i = 0; i < A.N(); ++i) {
         for (auto jIt = A[i].begin(); jIt != A[i].end(); ++jIt) {
           for (size_t k = 0; k < B.N(); ++k) {
@@ -72,7 +74,8 @@ namespace MatrixVector {
   template <class K1, class K2, class K3, int n, int m>
   struct TransformMatrixHelper<
       Dune::FieldMatrix<K1, m, m>, Dune::FieldMatrix<K3, n, n>,
-      Dune::FieldMatrix<K2, n, m>, false, false, false> {
+      Dune::FieldMatrix<K2, n, m>, Dune::FieldMatrix<K2, n, m>, false, false,
+      false, false> {
     typedef Dune::FieldMatrix<K1, m, m> MatrixA;
     typedef Dune::FieldMatrix<K3, n, n> MatrixB;
     typedef Dune::FieldMatrix<K2, n, m> TransformationMatrix;
@@ -111,25 +114,30 @@ namespace MatrixVector {
     // }
   };
 
-  template <class MatrixA, class MatrixB, class ScalarTransform, bool AisScalar,
-            bool BisScalar>
-  struct TransformMatrixHelper<MatrixA, MatrixB, ScalarTransform, AisScalar,
-                               BisScalar, true> {
-    static void addTransformedMatrix(MatrixA& A, const ScalarTransform& T1,
+  template <class MatrixA, class MatrixB, class ScalarTransform1,
+            class ScalarTransform2, bool AisScalar, bool BisScalar>
+  struct TransformMatrixHelper<MatrixA, MatrixB, ScalarTransform1,
+                               ScalarTransform2, AisScalar, BisScalar, true,
+                               true> {
+    static void addTransformedMatrix(MatrixA& A, const ScalarTransform1& T1,
                                      const MatrixB& B,
-                                     const ScalarTransform& T2) {
+                                     const ScalarTransform2& T2) {
       addProduct(A, T1 * T2, B);
     }
   };
 
-  template <class MatrixA, class ScalarB, class TransformationMatrix,
-            bool AisScalar, bool TisScalar>
-  struct TransformMatrixHelper<MatrixA, ScalarB, TransformationMatrix,
-                               AisScalar, true, TisScalar> {
-    static void addTransformedMatrix(MatrixA& A, const TransformationMatrix& T1,
+  template <class MatrixA, class ScalarB, class TransformationMatrix1,
+            class TransformationMatrix2, bool AisScalar, bool T1isScalar,
+            bool T2isScalar>
+  struct TransformMatrixHelper<MatrixA, ScalarB, TransformationMatrix1,
+                               TransformationMatrix2, AisScalar, true,
+                               T1isScalar, T2isScalar> {
+    static void addTransformedMatrix(MatrixA& A,
+                                     const TransformationMatrix1& T1,
                                      const ScalarB& B,
-                                     const TransformationMatrix& T2) {
-      for (size_t k = 0; k < TransformationMatrix::rows; ++k) {
+                                     const TransformationMatrix2& T2) {
+      assert(TransformationMatrix1::rows == TransformationMatrix2::rows);
+      for (size_t k = 0; k < TransformationMatrix1::rows; ++k) {
         for (auto Skj = T2[k].begin(); Skj != T2[k].end(); ++Skj) {
           for (auto Tki = T1[k].begin(); Tki != T1[k].end(); Tki++)
             if (A.exists(Tki.index(), Skj.index()))
@@ -139,41 +147,44 @@ namespace MatrixVector {
     }
   };
 
-  template <class FieldType, int n, class ScalarB, class TransformationMatrix,
-            bool AisScalar, bool TisScalar>
-  struct TransformMatrixHelper<Dune::ScaledIdentityMatrix<FieldType, n>,
-                               ScalarB, TransformationMatrix, AisScalar, true,
-                               TisScalar> {
+  template <class FieldType, int n, class ScalarB, class TransformationMatrix1,
+            class TransformationMatrix2, bool AisScalar, bool T1isScalar,
+            bool T2isScalar>
+  struct TransformMatrixHelper<
+      Dune::ScaledIdentityMatrix<FieldType, n>, ScalarB, TransformationMatrix1,
+      TransformationMatrix2, AisScalar, true, T1isScalar, T2isScalar> {
     typedef Dune::ScaledIdentityMatrix<FieldType, n> MatrixA;
-    static void addTransformedMatrix(MatrixA& A, const TransformationMatrix& T1,
+    static void addTransformedMatrix(MatrixA& A,
+                                     const TransformationMatrix1& T1,
                                      const ScalarB& B,
-                                     const TransformationMatrix& T2) {
-      TransformMatrixHelper<FieldType, ScalarB,
-                            typename TransformationMatrix::field_type, true,
-                            true, true>::addTransformedMatrix(A.scalar(),
-                                                              T1.scalar(), B,
-                                                              T2.scalar());
+                                     const TransformationMatrix2& T2) {
+      TransformMatrixHelper<
+          FieldType, ScalarB, typename TransformationMatrix1::field_type,
+          typename TransformationMatrix2::field_type, true, true, true,
+          true>::addTransformedMatrix(A.scalar(), T1.scalar(), B, T2.scalar());
     }
   };
 
-  template <class ScalarA, class ScalarB, class ScalarTransform>
-  struct TransformMatrixHelper<ScalarA, ScalarB, ScalarTransform, true, true,
-                               true> {
-    static void addTransformedMatrix(ScalarA& A, const ScalarTransform& T1,
+  template <class ScalarA, class ScalarB, class ScalarTransform1,
+            class ScalarTransform2>
+  struct TransformMatrixHelper<ScalarA, ScalarB, ScalarTransform1,
+                               ScalarTransform2, true, true, true, true> {
+    static void addTransformedMatrix(ScalarA& A, const ScalarTransform1& T1,
                                      const ScalarB& B,
-                                     const ScalarTransform& T2) {
+                                     const ScalarTransform2& T2) {
       A += T1 * B * T2;
     }
   };
 
   template <class MatrixA, typename FieldType, int n,
-            class TransformationMatrix>
+            class TransformationMatrix1, class TransformationMatrix2>
   struct TransformMatrixHelper<MatrixA, Dune::DiagonalMatrix<FieldType, n>,
-                               TransformationMatrix, false, false, false> {
+                               TransformationMatrix1, TransformationMatrix2,
+                               false, false, false, false> {
     static void addTransformedMatrix(
-        MatrixA& A, const TransformationMatrix& T1,
+        MatrixA& A, const TransformationMatrix1& T1,
         const Dune::DiagonalMatrix<FieldType, n>& B,
-        const TransformationMatrix& T2) {
+        const TransformationMatrix2& T2) {
       for (size_t k = 0; k < n; ++k) {
         for (auto Skj = T2[k].begin(); Skj != T2[k].end(); ++Skj) {
           for (auto Tki = T1[k].begin(); Tki != T1[k].end(); Tki++)
@@ -186,28 +197,30 @@ namespace MatrixVector {
   };
 
   template <class MatrixA, typename FieldType, int n,
-            class TransformationMatrix>
-  struct TransformMatrixHelper<MatrixA,
-                               Dune::ScaledIdentityMatrix<FieldType, n>,
-                               TransformationMatrix, false, false, false> {
+            class TransformationMatrix1, class TransformationMatrix2>
+  struct TransformMatrixHelper<
+      MatrixA, Dune::ScaledIdentityMatrix<FieldType, n>, TransformationMatrix1,
+      TransformationMatrix2, false, false, false, false> {
     static void addTransformedMatrix(
-        MatrixA& A, const TransformationMatrix& T1,
+        MatrixA& A, const TransformationMatrix1& T1,
         const Dune::ScaledIdentityMatrix<FieldType, n>& B,
-        const TransformationMatrix& T2) {
-      TransformMatrixHelper<MatrixA, FieldType, TransformationMatrix, false,
-                            true, false>::addTransformedMatrix(A, T1,
-                                                               B.scalar(), T2);
+        const TransformationMatrix2& T2) {
+      TransformMatrixHelper<MatrixA, FieldType, TransformationMatrix1,
+                            TransformationMatrix2, false, true, false,
+                            false>::addTransformedMatrix(A, T1, B.scalar(), T2);
     }
   };
 
-  template <typename FieldType, int n, class TransformationMatrix>
+  template <typename FieldType, int n, class TransformationMatrix1,
+            class TransformationMatrix2>
   struct TransformMatrixHelper<Dune::DiagonalMatrix<FieldType, n>,
                                Dune::DiagonalMatrix<FieldType, n>,
-                               TransformationMatrix, false, false, false> {
+                               TransformationMatrix1, TransformationMatrix2,
+                               false, false, false, false> {
     static void addTransformedMatrix(
-        Dune::DiagonalMatrix<FieldType, n>& A, const TransformationMatrix& T1,
+        Dune::DiagonalMatrix<FieldType, n>& A, const TransformationMatrix1& T1,
         const Dune::DiagonalMatrix<FieldType, n>& B,
-        const TransformationMatrix& T2) {
+        const TransformationMatrix2& T2) {
       for (size_t k = 0; k < n; k++) {
         for (auto Tki = T1[k].begin(); Tki != T1[k].end(); ++Tki)
           A.diagonal(Tki.index()) +=
@@ -216,15 +229,17 @@ namespace MatrixVector {
     }
   };
 
-  template <typename FieldType, int n, class TransformationMatrix>
+  template <typename FieldType, int n, class TransformationMatrix1,
+            class TransformationMatrix2>
   struct TransformMatrixHelper<Dune::ScaledIdentityMatrix<FieldType, n>,
                                Dune::ScaledIdentityMatrix<FieldType, n>,
-                               TransformationMatrix, false, false, false> {
+                               TransformationMatrix1, TransformationMatrix2,
+                               false, false, false, false> {
     static void addTransformedMatrix(
         Dune::ScaledIdentityMatrix<FieldType, n>& A,
-        const TransformationMatrix& T1,
+        const TransformationMatrix1& T1,
         const Dune::ScaledIdentityMatrix<FieldType, n>& B,
-        const TransformationMatrix& T2) {
+        const TransformationMatrix2& T2) {
       for (size_t k = 0; k < n; k++)
         A.scalar() += T1[k][0] * B.scalar() * T2[k][0];
     }
@@ -232,50 +247,54 @@ namespace MatrixVector {
 
   template <typename FieldType, int n>
   struct TransformMatrixHelper<Dune::ScaledIdentityMatrix<FieldType, n>,
+                               Dune::ScaledIdentityMatrix<FieldType, n>,
                                Dune::ScaledIdentityMatrix<FieldType, n>,
                                Dune::ScaledIdentityMatrix<FieldType, n>, false,
-                               false, false> {
+                               false, false, false> {
     static void addTransformedMatrix(
         Dune::ScaledIdentityMatrix<FieldType, n>& A,
         const Dune::ScaledIdentityMatrix<FieldType, n>& T1,
         const Dune::ScaledIdentityMatrix<FieldType, n>& B,
         const Dune::ScaledIdentityMatrix<FieldType, n>& T2) {
-      TransformMatrixHelper<FieldType, FieldType, FieldType, true, true,
-                            true>::addTransformedMatrix(A.scalar(), T1.scalar(),
-                                                        B.scalar(),
-                                                        T2.scalar());
+      TransformMatrixHelper<
+          FieldType, FieldType, FieldType, FieldType, true, true, true,
+          true>::addTransformedMatrix(A.scalar(), T1.scalar(), B.scalar(),
+                                      T2.scalar());
     }
   };
 
-  template <class MatrixA, class MatrixB, class TransformationMatrix>
-  void addTransformedMatrix(MatrixA& A, const TransformationMatrix& T1,
-                            const MatrixB& B, const TransformationMatrix& T2) {
+  template <class MatrixA, class MatrixB, class TransformationMatrix1,
+            class TransformationMatrix2>
+  void addTransformedMatrix(MatrixA& A, const TransformationMatrix1& T1,
+                            const MatrixB& B, const TransformationMatrix2& T2) {
     TransformMatrixHelper<
-        MatrixA, MatrixB, TransformationMatrix, ScalarTraits<MatrixA>::isScalar,
-        ScalarTraits<MatrixB>::isScalar,
-        ScalarTraits<TransformationMatrix>::isScalar>::addTransformedMatrix(A,
-                                                                            T1,
-                                                                            B,
-                                                                            T2);
+        MatrixA, MatrixB, TransformationMatrix1, TransformationMatrix2,
+        ScalarTraits<MatrixA>::isScalar, ScalarTraits<MatrixB>::isScalar,
+        ScalarTraits<TransformationMatrix1>::isScalar,
+        ScalarTraits<
+            TransformationMatrix2>::isScalar>::addTransformedMatrix(A, T1, B,
+                                                                    T2);
   }
 
-  template <class MatrixA, class MatrixB, class TransformationMatrix>
-  void transformMatrix(MatrixA& A, const TransformationMatrix& T1,
-                       const MatrixB& B, const TransformationMatrix& T2) {
+  template <class MatrixA, class MatrixB, class TransformationMatrix1,
+            class TransformationMatrix2>
+  void transformMatrix(MatrixA& A, const TransformationMatrix1& T1,
+                       const MatrixB& B, const TransformationMatrix2& T2) {
     A = 0;
     TransformMatrixHelper<
-        MatrixA, MatrixB, TransformationMatrix, ScalarTraits<MatrixA>::isScalar,
-        ScalarTraits<MatrixB>::isScalar,
-        ScalarTraits<TransformationMatrix>::isScalar>::addTransformedMatrix(A,
-                                                                            T1,
-                                                                            B,
-                                                                            T2);
+        MatrixA, MatrixB, TransformationMatrix1, TransformationMatrix2,
+        ScalarTraits<MatrixA>::isScalar, ScalarTraits<MatrixB>::isScalar,
+        ScalarTraits<TransformationMatrix1>::isScalar,
+        ScalarTraits<
+            TransformationMatrix2>::isScalar>::addTransformedMatrix(A, T1, B,
+                                                                    T2);
   }
 
-  template <class MatrixBlockA, class MatrixB, class TransformationMatrix>
+  template <class MatrixBlockA, class MatrixB, class TransformationMatrix1,
+            class TransformationMatrix2>
   static void transformMatrix(Dune::BCRSMatrix<MatrixBlockA>& A,
-                              const TransformationMatrix& T1, const MatrixB& B,
-                              const TransformationMatrix& T2) {
+                              const TransformationMatrix1& T1, const MatrixB& B,
+                              const TransformationMatrix2& T2) {
     transformMatrixPattern(A, T1, B, T2);
     A = 0.0;
     for (size_t k = 0; k < B.N(); ++k) {
@@ -294,11 +313,12 @@ namespace MatrixVector {
     }
   }
 
-  template <class MatrixBlockA, class MatrixB, class TransformationMatrix>
+  template <class MatrixBlockA, class MatrixB, class TransformationMatrix1,
+            class TransformationMatrix2>
   static void transformMatrixPattern(Dune::BCRSMatrix<MatrixBlockA>& A,
-                                     const TransformationMatrix& T1,
+                                     const TransformationMatrix1& T1,
                                      const MatrixB& B,
-                                     const TransformationMatrix& T2) {
+                                     const TransformationMatrix2& T2) {
     Dune::MatrixIndexSet indices(T1.M(), T2.M());
     for (size_t k = 0; k < B.N(); ++k) {
       for (auto BklIt = B[k].begin(); BklIt != B[k].end(); ++BklIt) {
-- 
GitLab