From 33df8947c26aa596f0c46a3853f770f72856e2d4 Mon Sep 17 00:00:00 2001 From: Elias Pipping <elias.pipping@fu-berlin.de> Date: Tue, 16 Apr 2013 16:10:57 +0000 Subject: [PATCH] arithmetic: Sync with fufem [[Imported from SVN: r8500]] --- dune/solvers/common/arithmetic.hh | 98 +++++++++++++++++++++++++++++-- 1 file changed, 93 insertions(+), 5 deletions(-) diff --git a/dune/solvers/common/arithmetic.hh b/dune/solvers/common/arithmetic.hh index 16fe0657..664bb5c5 100644 --- a/dune/solvers/common/arithmetic.hh +++ b/dune/solvers/common/arithmetic.hh @@ -2,10 +2,11 @@ #define ARITHMETIC_HH -#include <dune/common/diagonalmatrix.hh> #include <dune/common/fvector.hh> #include <dune/common/fmatrix.hh> +#include <dune/common/diagonalmatrix.hh> #include <dune/istl/scaledidmatrix.hh> +#include <dune/common/typetraits.hh> /** \brief Namespace containing helper classes and functions for arithmetic operations * @@ -113,6 +114,16 @@ namespace Arithmetic } }; + // Internal helper class for scaled product operations (i.e., b is always a scalar) + template<class A, class B, class C, class D, bool AisScalar, bool CisScalar, bool DisScalar> + struct ScaledProductHelper + { + static void addProduct(A& a, const B& b, const C& c, const D& d) + { + c.usmv(b, d, a); + } + }; + template<class T, int n, int m, int k> struct ProductHelper<Dune::FieldMatrix<T,n,k>, Dune::FieldMatrix<T,n,m>, Dune::FieldMatrix<T,m,k>, false, false, false> { @@ -132,6 +143,23 @@ namespace Arithmetic } }; + template<class T, int n, int m, int k> + struct ScaledProductHelper<Dune::FieldMatrix<T,n,k>, T, Dune::FieldMatrix<T,n,m>, Dune::FieldMatrix<T,m,k>, false, false, false> + { + typedef Dune::FieldMatrix<T,n,k> A; + typedef Dune::FieldMatrix<T,n,m> C; + typedef Dune::FieldMatrix<T,m,k> D; + static void addProduct(A& a, const T& b, const C& c, const D& d) + { + for (size_t row = 0; row < n; ++row) { + for (size_t col = 0 ; col < k; ++col) { + for (size_t i = 0; i < m; ++i) + a[row][col] += b * c[row][i]*d[i][col]; + } + } + } + }; + template<class T, int n> struct ProductHelper<Dune::DiagonalMatrix<T,n>, Dune::DiagonalMatrix<T,n>, Dune::DiagonalMatrix<T,n>, false, false, false> { @@ -145,6 +173,19 @@ namespace Arithmetic } }; + template<class T, int n> + struct ScaledProductHelper<Dune::DiagonalMatrix<T,n>, T, Dune::DiagonalMatrix<T,n>, Dune::DiagonalMatrix<T,n>, false, false, false> + { + typedef Dune::DiagonalMatrix<T,n> A; + typedef A C; + typedef C D; + static void addProduct(A& a, const T& b, const C& c, const D& d) + { + for (size_t i=0; i<n; ++i) + a.diagonal(i) += b * c.diagonal(i)*d.diagonal(i); + } + }; + template<class T, int n> struct ProductHelper<Dune::FieldMatrix<T,n,n>, Dune::DiagonalMatrix<T,n>, Dune::DiagonalMatrix<T,n>, false, false, false> { @@ -158,7 +199,20 @@ namespace Arithmetic } }; - /** \brief Specialization for b being a scalar type and (a not a matrix type) + template<class T, int n> + struct ScaledProductHelper<Dune::FieldMatrix<T,n,n>, T, Dune::DiagonalMatrix<T,n>, Dune::DiagonalMatrix<T,n>, false, false, false> + { + typedef Dune::FieldMatrix<T,n,n> A; + typedef Dune::DiagonalMatrix<T,n> C; + typedef C D; + static void addProduct(A& a, const T& b, const C& c, const D& d) + { + for (size_t i=0; i<n; ++i) + a[i][i] += b * c.diagonal(i)*d.diagonal(i); + } + }; + + /** \brief Specialization for b being a scalar type */ template<class A, class B, class C, bool AisScalar, bool CisScalar> struct ProductHelper<A, B, C, AisScalar, true, CisScalar> @@ -169,6 +223,17 @@ namespace Arithmetic } }; + /** \brief Specialization for c being a scalar type + */ + template<class A, class B, class C, class D, bool AisScalar, bool DisScalar> + struct ScaledProductHelper<A, B, C, D, AisScalar, true, DisScalar> + { + static void addProduct(A& a, const B& b, const C& c, const D& d) + { + a.axpy(b * c, d); + } + }; + template<class A, class B, class C> struct ProductHelper<A, B, C, true, true, true> { @@ -178,13 +243,20 @@ namespace Arithmetic } }; - + template<class A, class B, class C, class D> + struct ScaledProductHelper<A, B, C, D, true, true, true> + { + static void addProduct(A& a, const B& b, const C& c, const D& d) + { + a += b*c*d; + } + }; /** \brief Add a product to some matrix or vector * * This function computes a+=b*c. * - * This functions should tolerate all meaningful + * This function should tolerate all meaningful * combinations of scalars, vectors, and matrices. * * a,b,c could be matrices with appropriate @@ -197,7 +269,23 @@ namespace Arithmetic ProductHelper<A,B,C,ScalarTraits<A>::isScalar, ScalarTraits<B>::isScalar, ScalarTraits<C>::isScalar>::addProduct(a,b,c); } - + /** \brief Add a scaled product to some matrix or vector + * + * This function computes a+=b*c. + * + * This function should tolerate all meaningful + * combinations of scalars, vectors, and matrices. + * + * a,c,d could be matrices with appropriate dimensions. But b must + * (currently) and c can also always be a scalar represented by a + * 1-dim vector or a 1 by 1 matrix. + */ + template<class A, class B, class C, class D> + typename Dune::enable_if<ScalarTraits<B>::isScalar, void>::type + addProduct(A& a, const B& b, const C& c, const D& d) + { + ScaledProductHelper<A,B,C,D,ScalarTraits<A>::isScalar, ScalarTraits<C>::isScalar, ScalarTraits<D>::isScalar>::addProduct(a,b,c,d); + } }; #endif -- GitLab