From 33df8947c26aa596f0c46a3853f770f72856e2d4 Mon Sep 17 00:00:00 2001
From: Elias Pipping <elias.pipping@fu-berlin.de>
Date: Tue, 16 Apr 2013 16:10:57 +0000
Subject: [PATCH] arithmetic: Sync with fufem

[[Imported from SVN: r8500]]
---
 dune/solvers/common/arithmetic.hh | 98 +++++++++++++++++++++++++++++--
 1 file changed, 93 insertions(+), 5 deletions(-)

diff --git a/dune/solvers/common/arithmetic.hh b/dune/solvers/common/arithmetic.hh
index 16fe0657..664bb5c5 100644
--- a/dune/solvers/common/arithmetic.hh
+++ b/dune/solvers/common/arithmetic.hh
@@ -2,10 +2,11 @@
 #define ARITHMETIC_HH
 
 
-#include <dune/common/diagonalmatrix.hh>
 #include <dune/common/fvector.hh>
 #include <dune/common/fmatrix.hh>
+#include <dune/common/diagonalmatrix.hh>
 #include <dune/istl/scaledidmatrix.hh>
+#include <dune/common/typetraits.hh>
 
 /** \brief Namespace containing helper classes and functions for arithmetic operations
  *
@@ -113,6 +114,16 @@ namespace Arithmetic
         }
     };
 
+    // Internal helper class for scaled product operations (i.e., b is always a scalar)
+    template<class A, class B, class C, class D, bool AisScalar, bool CisScalar, bool DisScalar>
+    struct ScaledProductHelper
+    {
+        static void addProduct(A& a, const B& b, const C& c, const D& d)
+        {
+            c.usmv(b, d, a);
+        }
+    };
+
     template<class T, int n, int m, int k>
     struct ProductHelper<Dune::FieldMatrix<T,n,k>, Dune::FieldMatrix<T,n,m>, Dune::FieldMatrix<T,m,k>, false, false, false>
     {
@@ -132,6 +143,23 @@ namespace Arithmetic
         }
     };
 
+    template<class T, int n, int m, int k>
+    struct ScaledProductHelper<Dune::FieldMatrix<T,n,k>, T, Dune::FieldMatrix<T,n,m>, Dune::FieldMatrix<T,m,k>, false, false, false>
+    {
+        typedef Dune::FieldMatrix<T,n,k> A;
+        typedef Dune::FieldMatrix<T,n,m> C;
+        typedef Dune::FieldMatrix<T,m,k> D;
+        static void addProduct(A& a, const T& b, const C& c, const D& d)
+        {
+            for (size_t row = 0; row < n; ++row) {
+                for (size_t col = 0 ; col < k; ++col) {
+                    for (size_t i = 0; i < m; ++i)
+                        a[row][col] += b * c[row][i]*d[i][col];
+                }
+            }
+        }
+    };
+
     template<class T, int n>
     struct ProductHelper<Dune::DiagonalMatrix<T,n>, Dune::DiagonalMatrix<T,n>, Dune::DiagonalMatrix<T,n>, false, false, false>
     {
@@ -145,6 +173,19 @@ namespace Arithmetic
         }
     };
 
+    template<class T, int n>
+    struct ScaledProductHelper<Dune::DiagonalMatrix<T,n>, T, Dune::DiagonalMatrix<T,n>, Dune::DiagonalMatrix<T,n>, false, false, false>
+    {
+        typedef Dune::DiagonalMatrix<T,n> A;
+        typedef A C;
+        typedef C D;
+        static void addProduct(A& a, const T& b, const C& c, const D& d)
+        {
+            for (size_t i=0; i<n; ++i)
+                a.diagonal(i) += b * c.diagonal(i)*d.diagonal(i);
+        }
+    };
+
     template<class T, int n>
     struct ProductHelper<Dune::FieldMatrix<T,n,n>, Dune::DiagonalMatrix<T,n>, Dune::DiagonalMatrix<T,n>, false, false, false>
     {
@@ -158,7 +199,20 @@ namespace Arithmetic
         }
     };
 
-    /** \brief Specialization for b being a scalar type and (a not a matrix type)
+    template<class T, int n>
+    struct ScaledProductHelper<Dune::FieldMatrix<T,n,n>, T, Dune::DiagonalMatrix<T,n>, Dune::DiagonalMatrix<T,n>, false, false, false>
+    {
+        typedef Dune::FieldMatrix<T,n,n> A;
+        typedef Dune::DiagonalMatrix<T,n> C;
+        typedef C D;
+        static void addProduct(A& a, const T& b, const C& c, const D& d)
+        {
+            for (size_t i=0; i<n; ++i)
+                a[i][i] += b * c.diagonal(i)*d.diagonal(i);
+        }
+    };
+
+    /** \brief Specialization for b being a scalar type
       */
     template<class A, class B, class C, bool AisScalar, bool CisScalar>
     struct ProductHelper<A, B, C, AisScalar, true, CisScalar>
@@ -169,6 +223,17 @@ namespace Arithmetic
         }
     };
 
+    /** \brief Specialization for c being a scalar type
+     */
+    template<class A, class B, class C, class D, bool AisScalar, bool DisScalar>
+    struct ScaledProductHelper<A, B, C, D, AisScalar, true, DisScalar>
+    {
+        static void addProduct(A& a, const B& b, const C& c, const D& d)
+        {
+            a.axpy(b * c, d);
+        }
+    };
+
     template<class A, class B, class C>
     struct ProductHelper<A, B, C, true, true, true>
     {
@@ -178,13 +243,20 @@ namespace Arithmetic
         }
     };
 
-
+    template<class A, class B, class C, class D>
+    struct ScaledProductHelper<A, B, C, D, true, true, true>
+    {
+        static void addProduct(A& a, const B& b, const C& c, const D& d)
+        {
+            a += b*c*d;
+        }
+    };
 
     /** \brief Add a product to some matrix or vector
      *
      * This function computes a+=b*c.
      *
-     * This functions should tolerate all meaningful
+     * This function should tolerate all meaningful
      * combinations of scalars, vectors, and matrices.
      *
      * a,b,c could be matrices with appropriate
@@ -197,7 +269,23 @@ namespace Arithmetic
         ProductHelper<A,B,C,ScalarTraits<A>::isScalar, ScalarTraits<B>::isScalar, ScalarTraits<C>::isScalar>::addProduct(a,b,c);
     }
 
-
+    /** \brief Add a scaled product to some matrix or vector
+     *
+     * This function computes a+=b*c.
+     *
+     * This function should tolerate all meaningful
+     * combinations of scalars, vectors, and matrices.
+     *
+     * a,c,d could be matrices with appropriate dimensions. But b must
+     * (currently) and c can also always be a scalar represented by a
+     * 1-dim vector or a 1 by 1 matrix.
+     */
+    template<class A, class B, class C, class D>
+    typename Dune::enable_if<ScalarTraits<B>::isScalar, void>::type
+    addProduct(A& a, const B& b, const C& c, const D& d)
+    {
+        ScaledProductHelper<A,B,C,D,ScalarTraits<A>::isScalar, ScalarTraits<C>::isScalar, ScalarTraits<D>::isScalar>::addProduct(a,b,c,d);
+    }
 };
 
 #endif
-- 
GitLab