Skip to content
Snippets Groups Projects
Commit 33df8947 authored by Elias Pipping's avatar Elias Pipping Committed by pipping
Browse files

arithmetic: Sync with fufem

[[Imported from SVN: r8500]]
parent 1f24a6fb
No related branches found
No related tags found
No related merge requests found
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
#define ARITHMETIC_HH #define ARITHMETIC_HH
#include <dune/common/diagonalmatrix.hh>
#include <dune/common/fvector.hh> #include <dune/common/fvector.hh>
#include <dune/common/fmatrix.hh> #include <dune/common/fmatrix.hh>
#include <dune/common/diagonalmatrix.hh>
#include <dune/istl/scaledidmatrix.hh> #include <dune/istl/scaledidmatrix.hh>
#include <dune/common/typetraits.hh>
/** \brief Namespace containing helper classes and functions for arithmetic operations /** \brief Namespace containing helper classes and functions for arithmetic operations
* *
...@@ -113,6 +114,16 @@ namespace Arithmetic ...@@ -113,6 +114,16 @@ namespace Arithmetic
} }
}; };
// Internal helper class for scaled product operations (i.e., b is always a scalar)
template<class A, class B, class C, class D, bool AisScalar, bool CisScalar, bool DisScalar>
struct ScaledProductHelper
{
static void addProduct(A& a, const B& b, const C& c, const D& d)
{
c.usmv(b, d, a);
}
};
template<class T, int n, int m, int k> template<class T, int n, int m, int k>
struct ProductHelper<Dune::FieldMatrix<T,n,k>, Dune::FieldMatrix<T,n,m>, Dune::FieldMatrix<T,m,k>, false, false, false> struct ProductHelper<Dune::FieldMatrix<T,n,k>, Dune::FieldMatrix<T,n,m>, Dune::FieldMatrix<T,m,k>, false, false, false>
{ {
...@@ -132,6 +143,23 @@ namespace Arithmetic ...@@ -132,6 +143,23 @@ namespace Arithmetic
} }
}; };
template<class T, int n, int m, int k>
struct ScaledProductHelper<Dune::FieldMatrix<T,n,k>, T, Dune::FieldMatrix<T,n,m>, Dune::FieldMatrix<T,m,k>, false, false, false>
{
typedef Dune::FieldMatrix<T,n,k> A;
typedef Dune::FieldMatrix<T,n,m> C;
typedef Dune::FieldMatrix<T,m,k> D;
static void addProduct(A& a, const T& b, const C& c, const D& d)
{
for (size_t row = 0; row < n; ++row) {
for (size_t col = 0 ; col < k; ++col) {
for (size_t i = 0; i < m; ++i)
a[row][col] += b * c[row][i]*d[i][col];
}
}
}
};
template<class T, int n> template<class T, int n>
struct ProductHelper<Dune::DiagonalMatrix<T,n>, Dune::DiagonalMatrix<T,n>, Dune::DiagonalMatrix<T,n>, false, false, false> struct ProductHelper<Dune::DiagonalMatrix<T,n>, Dune::DiagonalMatrix<T,n>, Dune::DiagonalMatrix<T,n>, false, false, false>
{ {
...@@ -145,6 +173,19 @@ namespace Arithmetic ...@@ -145,6 +173,19 @@ namespace Arithmetic
} }
}; };
template<class T, int n>
struct ScaledProductHelper<Dune::DiagonalMatrix<T,n>, T, Dune::DiagonalMatrix<T,n>, Dune::DiagonalMatrix<T,n>, false, false, false>
{
typedef Dune::DiagonalMatrix<T,n> A;
typedef A C;
typedef C D;
static void addProduct(A& a, const T& b, const C& c, const D& d)
{
for (size_t i=0; i<n; ++i)
a.diagonal(i) += b * c.diagonal(i)*d.diagonal(i);
}
};
template<class T, int n> template<class T, int n>
struct ProductHelper<Dune::FieldMatrix<T,n,n>, Dune::DiagonalMatrix<T,n>, Dune::DiagonalMatrix<T,n>, false, false, false> struct ProductHelper<Dune::FieldMatrix<T,n,n>, Dune::DiagonalMatrix<T,n>, Dune::DiagonalMatrix<T,n>, false, false, false>
{ {
...@@ -158,7 +199,20 @@ namespace Arithmetic ...@@ -158,7 +199,20 @@ namespace Arithmetic
} }
}; };
/** \brief Specialization for b being a scalar type and (a not a matrix type) template<class T, int n>
struct ScaledProductHelper<Dune::FieldMatrix<T,n,n>, T, Dune::DiagonalMatrix<T,n>, Dune::DiagonalMatrix<T,n>, false, false, false>
{
typedef Dune::FieldMatrix<T,n,n> A;
typedef Dune::DiagonalMatrix<T,n> C;
typedef C D;
static void addProduct(A& a, const T& b, const C& c, const D& d)
{
for (size_t i=0; i<n; ++i)
a[i][i] += b * c.diagonal(i)*d.diagonal(i);
}
};
/** \brief Specialization for b being a scalar type
*/ */
template<class A, class B, class C, bool AisScalar, bool CisScalar> template<class A, class B, class C, bool AisScalar, bool CisScalar>
struct ProductHelper<A, B, C, AisScalar, true, CisScalar> struct ProductHelper<A, B, C, AisScalar, true, CisScalar>
...@@ -169,6 +223,17 @@ namespace Arithmetic ...@@ -169,6 +223,17 @@ namespace Arithmetic
} }
}; };
/** \brief Specialization for c being a scalar type
*/
template<class A, class B, class C, class D, bool AisScalar, bool DisScalar>
struct ScaledProductHelper<A, B, C, D, AisScalar, true, DisScalar>
{
static void addProduct(A& a, const B& b, const C& c, const D& d)
{
a.axpy(b * c, d);
}
};
template<class A, class B, class C> template<class A, class B, class C>
struct ProductHelper<A, B, C, true, true, true> struct ProductHelper<A, B, C, true, true, true>
{ {
...@@ -178,13 +243,20 @@ namespace Arithmetic ...@@ -178,13 +243,20 @@ namespace Arithmetic
} }
}; };
template<class A, class B, class C, class D>
struct ScaledProductHelper<A, B, C, D, true, true, true>
{
static void addProduct(A& a, const B& b, const C& c, const D& d)
{
a += b*c*d;
}
};
/** \brief Add a product to some matrix or vector /** \brief Add a product to some matrix or vector
* *
* This function computes a+=b*c. * This function computes a+=b*c.
* *
* This functions should tolerate all meaningful * This function should tolerate all meaningful
* combinations of scalars, vectors, and matrices. * combinations of scalars, vectors, and matrices.
* *
* a,b,c could be matrices with appropriate * a,b,c could be matrices with appropriate
...@@ -197,7 +269,23 @@ namespace Arithmetic ...@@ -197,7 +269,23 @@ namespace Arithmetic
ProductHelper<A,B,C,ScalarTraits<A>::isScalar, ScalarTraits<B>::isScalar, ScalarTraits<C>::isScalar>::addProduct(a,b,c); ProductHelper<A,B,C,ScalarTraits<A>::isScalar, ScalarTraits<B>::isScalar, ScalarTraits<C>::isScalar>::addProduct(a,b,c);
} }
/** \brief Add a scaled product to some matrix or vector
*
* This function computes a+=b*c.
*
* This function should tolerate all meaningful
* combinations of scalars, vectors, and matrices.
*
* a,c,d could be matrices with appropriate dimensions. But b must
* (currently) and c can also always be a scalar represented by a
* 1-dim vector or a 1 by 1 matrix.
*/
template<class A, class B, class C, class D>
typename Dune::enable_if<ScalarTraits<B>::isScalar, void>::type
addProduct(A& a, const B& b, const C& c, const D& d)
{
ScaledProductHelper<A,B,C,D,ScalarTraits<A>::isScalar, ScalarTraits<C>::isScalar, ScalarTraits<D>::isScalar>::addProduct(a,b,c,d);
}
}; };
#endif #endif
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment