Skip to content
Snippets Groups Projects
Commit 27327d48 authored by Elias Pipping's avatar Elias Pipping Committed by pipping
Browse files

arithmetic: sync with dune-fufem

[[Imported from SVN: r12194]]
parent c73fa760
Branches
Tags
No related merge requests found
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <dune/common/fmatrix.hh> #include <dune/common/fmatrix.hh>
#include <dune/common/diagonalmatrix.hh> #include <dune/common/diagonalmatrix.hh>
#include <dune/istl/bcrsmatrix.hh> #include <dune/istl/bcrsmatrix.hh>
#include <dune/istl/matrixindexset.hh>
#include <dune/istl/scaledidmatrix.hh> #include <dune/istl/scaledidmatrix.hh>
#include <dune/common/typetraits.hh> #include <dune/common/typetraits.hh>
...@@ -20,7 +21,7 @@ namespace Arithmetic ...@@ -20,7 +21,7 @@ namespace Arithmetic
/** \brief Class to identify scalar types /** \brief Class to identify scalar types
* *
* Specialize this class for all type that can be used * Specialize this class for all types that can be used
* like scalar quantities. * like scalar quantities.
*/ */
template<class T> template<class T>
...@@ -68,7 +69,7 @@ namespace Arithmetic ...@@ -68,7 +69,7 @@ namespace Arithmetic
/** \brief Class to identify matrix types and extract information /** \brief Class to identify matrix types and extract information
* *
* Specialize this class for all type that can be used like a matrix. * Specialize this class for all types that can be used like a matrix.
*/ */
template<class T> template<class T>
struct MatrixTraits struct MatrixTraits
...@@ -102,6 +103,12 @@ namespace Arithmetic ...@@ -102,6 +103,12 @@ namespace Arithmetic
enum { cols=n}; enum { cols=n};
}; };
template<class T>
struct MatrixTraits<Dune::BCRSMatrix<T> >
{
enum { isMatrix=true };
};
/** \brief Internal helper class for product operations /** \brief Internal helper class for product operations
* *
...@@ -317,6 +324,84 @@ namespace Arithmetic ...@@ -317,6 +324,84 @@ namespace Arithmetic
} }
}; };
//! Helper class for computing the cross product
template <class T, int n>
struct CrossProductHelper
{
static Dune::FieldVector<T,n> crossProduct(const Dune::FieldVector<T,n>& a, const Dune::FieldVector<T,n>& b)
{
DUNE_THROW(Dune::Exception, "You can only call crossProduct with dim==3");
}
};
//! Specialisation for n=3
template <class T>
struct CrossProductHelper<T,3>
{
static Dune::FieldVector<T,3> crossProduct(const Dune::FieldVector<T,3>& a, const Dune::FieldVector<T,3>& b)
{
Dune::FieldVector<T,3> r;
r[0] = a[1]*b[2] - a[2]*b[1];
r[1] = a[2]*b[0] - a[0]*b[2];
r[2] = a[0]*b[1] - a[1]*b[0];
return r;
}
};
template<class A, class TransposedA>
struct TransposeHelper
{
static void transpose(const A& a, TransposedA& aT)
{
DUNE_THROW(Dune::Exception, "Not implemented for general matrix types!");
}
};
//! Specialization for Dune::FieldMatrix
template<class T, int n, int m>
struct TransposeHelper<Dune::FieldMatrix<T,n,m>, Dune::FieldMatrix<T,m,n> >
{
static void transpose(const Dune::FieldMatrix<T,n,m>& a, Dune::FieldMatrix<T,m,n>& aT)
{
for (int row = 0; row < m; ++row)
for (int col = 0 ; col < n; ++col)
aT[row][col] = a[col][row];
}
};
//! Specialization for Dune::BCRSMatrix Type
template<class A, class TransposedA>
struct TransposeHelper<Dune::BCRSMatrix<A>, Dune::BCRSMatrix<TransposedA> >
{
static void transpose(const Dune::BCRSMatrix<A>& a, Dune::BCRSMatrix<TransposedA>& aT)
{
Dune::MatrixIndexSet idxSetaT(a.M(),a.N());
typedef typename Dune::BCRSMatrix<A>::ConstColIterator ColIterator;
// add indices into transposed matrix
for (int row = 0; row < a.N(); ++row)
{
ColIterator col = a[row].begin();
ColIterator end = a[row].end();
for ( ; col != end; ++col)
idxSetaT.add(col.index(),row);
}
idxSetaT.exportIdx(aT);
for (int row = 0; row < a.N(); ++row)
{
ColIterator col = a[row].begin();
ColIterator end = a[row].end();
for ( ; col != end; ++col)
TransposeHelper<A,TransposedA>::transpose(*col, aT[col.index()][row]);
}
}
};
/** \brief Add a product to some matrix or vector /** \brief Add a product to some matrix or vector
* *
* This function computes a+=b*c. * This function computes a+=b*c.
...@@ -350,6 +435,19 @@ namespace Arithmetic ...@@ -350,6 +435,19 @@ namespace Arithmetic
{ {
ProductHelper<A,B,C,ScalarTraits<A>::isScalar, ScalarTraits<B>::isScalar, ScalarTraits<C>::isScalar>::subtractProduct(a,b,c); ProductHelper<A,B,C,ScalarTraits<A>::isScalar, ScalarTraits<B>::isScalar, ScalarTraits<C>::isScalar>::subtractProduct(a,b,c);
} }
//! Compute the cross product of two vectors. Only works for n==3
template<class T, int n>
Dune::FieldVector<T,n> crossProduct(const Dune::FieldVector<T,n>& a, const Dune::FieldVector<T,n>& b)
{
return CrossProductHelper<T,n>::crossProduct(a,b);
}
//! Compute the transposed of a matrix
template <class A, class TransposedA>
void transpose(const A& a, TransposedA& aT)
{
TransposeHelper<A,TransposedA>::transpose(a,aT);
}
/** \brief Add a scaled product to some matrix or vector /** \brief Add a scaled product to some matrix or vector
* *
...@@ -387,12 +485,17 @@ namespace Arithmetic ...@@ -387,12 +485,17 @@ namespace Arithmetic
ScaledProductHelper<A,B,C,D,ScalarTraits<A>::isScalar, ScalarTraits<C>::isScalar, ScalarTraits<D>::isScalar>::subtractProduct(a,b,c,d); ScaledProductHelper<A,B,C,D,ScalarTraits<A>::isScalar, ScalarTraits<C>::isScalar, ScalarTraits<D>::isScalar>::subtractProduct(a,b,c,d);
} }
namespace { /** \brief Internal helper class for Matrix operations
template <class MatrixType, class VectorType> *
typename VectorType::field_type */
AxyWithTemporary(const MatrixType &A, template<class OperatorType, bool isMatrix>
const VectorType &x, struct OperatorHelper
const VectorType &y) {
template <class VectorType>
static typename VectorType::field_type
Axy(const OperatorType &A,
const VectorType &x,
const VectorType &y)
{ {
VectorType tmp(y.size()); VectorType tmp(y.size());
tmp = 0.0; tmp = 0.0;
...@@ -400,11 +503,25 @@ namespace Arithmetic ...@@ -400,11 +503,25 @@ namespace Arithmetic
return tmp * y; return tmp * y;
} }
template <class MatrixType, class VectorType> template <class VectorType>
typename VectorType::field_type static typename VectorType::field_type
AxyWithoutTemporary(const MatrixType &A, bmAxy(const OperatorType &A, const VectorType &b,
const VectorType &x, const VectorType &x, const VectorType &y)
const VectorType &y) {
VectorType tmp = b;
subtractProduct(tmp, A, x);
return tmp * y;
}
};
template<class MatrixType>
struct OperatorHelper<MatrixType, true>
{
template <class VectorType>
static typename VectorType::field_type
Axy(const MatrixType &A,
const VectorType &x,
const VectorType &y)
{ {
assert(x.N() == A.M()); assert(x.N() == A.M());
assert(y.N() == A.N()); assert(y.N() == A.N());
...@@ -423,53 +540,11 @@ namespace Arithmetic ...@@ -423,53 +540,11 @@ namespace Arithmetic
} }
return outer; return outer;
} }
}
//! Compute \f$(Ax,y)\f$
template <class MatrixType, class VectorType>
typename VectorType::field_type
Axy(const MatrixType &A,
const VectorType &x,
const VectorType &y)
{
return AxyWithTemporary(A, x, y);
}
//! Compute \f$(Ax,y)\f$
template <int n, typename field_type, class VectorType>
typename VectorType::field_type
Axy(const Dune::FieldMatrix<field_type, n, n> &A,
const VectorType &x,
const VectorType &y)
{
return AxyWithoutTemporary(A, x, y);
}
//! Compute \f$(Ax,y)\f$
template <typename block_type, class VectorType>
typename VectorType::field_type
Axy(const Dune::BCRSMatrix<block_type> &A,
const VectorType &x,
const VectorType &y)
{
return AxyWithoutTemporary(A, x, y);
}
namespace {
template <class MatrixType, class VectorType>
typename VectorType::field_type
bmAxyWithTemporary(const MatrixType &A, const VectorType &b,
const VectorType &x, const VectorType &y)
{
VectorType tmp = b;
subtractProduct(tmp, A, x);
return tmp * y;
}
template <class MatrixType, class VectorType> template <class VectorType>
typename VectorType::field_type static typename VectorType::field_type
bmAxyWithoutTemporary(const MatrixType &A, const VectorType &b, bmAxy(const MatrixType &A, const VectorType &b,
const VectorType &x, const VectorType &y) const VectorType &x, const VectorType &y)
{ {
assert(x.N() == A.M()); assert(x.N() == A.M());
assert(y.N() == A.N()); assert(y.N() == A.N());
...@@ -489,33 +564,29 @@ namespace Arithmetic ...@@ -489,33 +564,29 @@ namespace Arithmetic
} }
return outer; return outer;
} }
} };
//! Compute \f$(b-Ax,y)\f$
template <class MatrixType, class VectorType>
typename VectorType::field_type
bmAxy(const MatrixType &A, const VectorType &b,
const VectorType &x, const VectorType &y)
{
return bmAxyWithTemporary(A, b, x, y);
}
//! Compute \f$(b-Ax,y)\f$ //! Compute \f$(Ax,y)\f$
template <int n, typename field_type, class VectorType> template <class OperatorType, class VectorType>
typename VectorType::field_type typename VectorType::field_type
bmAxy(const Dune::FieldMatrix<field_type, n, n> &A, const VectorType &b, Axy(const OperatorType &A,
const VectorType &x, const VectorType &y) const VectorType &x,
const VectorType &y)
{ {
return bmAxyWithoutTemporary(A, b, x, y); return OperatorHelper<OperatorType,
MatrixTraits<OperatorType>::isMatrix>::Axy(A, x, y);
} }
//! Compute \f$(b-Ax,y)\f$ //! Compute \f$(b-Ax,y)\f$
template <typename block_type, class VectorType> template <class OperatorType, class VectorType>
typename VectorType::field_type typename VectorType::field_type
bmAxy(const Dune::BCRSMatrix<block_type> &A, const VectorType &b, bmAxy(const OperatorType &A,
const VectorType &x, const VectorType &y) const VectorType &b,
const VectorType &x,
const VectorType &y)
{ {
return bmAxyWithoutTemporary(A, b, x, y); return OperatorHelper<OperatorType,
MatrixTraits<OperatorType>::isMatrix>::bmAxy(A, b, x, y);
} }
}; };
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment