diff --git a/dune/solvers/common/arithmetic.hh b/dune/solvers/common/arithmetic.hh new file mode 100644 index 0000000000000000000000000000000000000000..bc05477d7598c2ed528878a595f4b911d6d844c5 --- /dev/null +++ b/dune/solvers/common/arithmetic.hh @@ -0,0 +1,203 @@ +#ifndef ARITHMETIC_HH +#define ARITHMETIC_HH + + +#include <dune/common/fvector.hh> +#include <dune/common/fmatrix.hh> +#include <dune/istl/diagonalmatrix.hh> +#include <dune/istl/scaledidmatrix.hh> + +/** \brief Namespace containing helper classes and functions for arithmetic operations + * + * Everything in this namespace is experimental and might change + * in the near future. Especially the naming of namespace, structs, + * and functions is not final. + */ +namespace Arithmetic +{ + + /** \brief Class to identify scalar types + * + * Specialize this class for all type that can be used + * like scalar quantities. + */ + template<class T> + struct ScalarTraits + { + enum { isScalar=false }; + }; + + template<> + struct ScalarTraits<float> + { + enum { isScalar=true }; + }; + + template<> + struct ScalarTraits<double> + { + enum { isScalar=true }; + }; + + template<class T> + struct ScalarTraits<Dune::FieldVector<T,1> > + { + enum { isScalar=true }; + }; + + template<class T> + struct ScalarTraits<Dune::FieldMatrix<T,1,1> > + { + enum { isScalar=true }; + }; + + template<class T> + struct ScalarTraits<Dune::DiagonalMatrix<T,1> > + { + enum { isScalar=true }; + }; + + template<class T> + struct ScalarTraits<Dune::ScaledIdentityMatrix<T,1> > + { + enum { isScalar=true }; + }; + + + /** \brief Class to identify matrix types and extract information + * + * Specialize this class for all type that can be used like a matrix. + */ + template<class T> + struct MatrixTraits + { + enum { isMatrix=false }; + enum { rows=-1}; + enum { cols=-1}; + }; + + template<class T, int n, int m> + struct MatrixTraits<Dune::FieldMatrix<T,n,m> > + { + enum { isMatrix=true }; + enum { rows=n}; + enum { cols=m}; + }; + + template<class T, int n> + struct MatrixTraits<Dune::DiagonalMatrix<T,n> > + { + enum { isMatrix=true }; + enum { rows=n}; + enum { cols=n}; + }; + + template<class T, int n> + struct MatrixTraits<Dune::ScaledIdentityMatrix<T,n> > + { + enum { isMatrix=true }; + enum { rows=n}; + enum { cols=n}; + }; + + + /** \brief Internal helper class for product operations + * + */ + template<class A, class B, class C, bool AisScalar, bool BisScalar, bool CisScalar> + struct ProductHelper + { + static void addProduct(A& a, const B& b, const C& c) + { + b.umv(c, a); + } + }; + + template<class T, int n, int m, int k> + struct ProductHelper<Dune::FieldMatrix<T,n,k>, Dune::FieldMatrix<T,n,m>, Dune::FieldMatrix<T,m,k>, false, false, false> + { + typedef Dune::FieldMatrix<T,n,k> A; + typedef Dune::FieldMatrix<T,n,m> B; + typedef Dune::FieldMatrix<T,m,k> C; + static void addProduct(A& a, const B& b, const C& c) + { + for (size_t row = 0; row < n; ++row) + { + for (size_t col = 0 ; col < k; ++col) + { + for (size_t i = 0; i < m; ++i) + a[row][col] += b[row][i]*c[i][col]; + } + } + } + }; + + template<class T, int n> + struct ProductHelper<Dune::DiagonalMatrix<T,n>, Dune::DiagonalMatrix<T,n>, Dune::DiagonalMatrix<T,n>, false, false, false> + { + typedef Dune::DiagonalMatrix<T,n> A; + typedef A B; + typedef B C; + static void addProduct(A& a, const B& b, const C& c) + { + for (size_t i=0; i<n; ++i) + a.diagonal(i) += b.diagonal(i)*c.diagonal(i); + } + }; + + template<class T, int n> + struct ProductHelper<Dune::FieldMatrix<T,n,n>, Dune::DiagonalMatrix<T,n>, Dune::DiagonalMatrix<T,n>, false, false, false> + { + typedef Dune::FieldMatrix<T,n,n> A; + typedef Dune::DiagonalMatrix<T,n> B; + typedef B C; + static void addProduct(A& a, const B& b, const C& c) + { + for (size_t i=0; i<n; ++i) + a[i][i] += b.diagonal(i)*c.diagonal(i); + } + }; + + /** \brief Specialization for b being a scalar type and (a not a matrix type) + */ + template<class A, class B, class C, bool AisScalar, bool CisScalar> + struct ProductHelper<A, B, C, AisScalar, true, CisScalar> + { + static void addProduct(A& a, const B& b, const C& c) + { + a.axpy(b, c); + } + }; + + template<class A, class B, class C> + struct ProductHelper<A, B, C, true, true, true> + { + static void addProduct(A& a, const B& b, const C& c) + { + a += b*c; + } + }; + + + + /** \brief Add a product to some matrix or vector + * + * This function computes a+=b*c. + * + * This functions should tolerate all meaningful + * combinations of scalars, vectors, and matrices. + * + * a,b,c could be matrices with appropriate + * dimensions. But b can also always be a scalar + * represented by a 1-dim vector or a 1 by 1 matrix. + */ + template<class A, class B, class C> + void addProduct(A& a, const B& b, const C& c) + { + ProductHelper<A,B,C,ScalarTraits<A>::isScalar, ScalarTraits<B>::isScalar, ScalarTraits<C>::isScalar>::addProduct(a,b,c); + } + + +}; + +#endif