Skip to content
Snippets Groups Projects
Commit 952c54cd authored by Max Kahnt's avatar Max Kahnt
Browse files

Adapt to moved traits and use traits utilities.

parent 1d64d27b
Branches
No related tags found
1 merge request!5Feature/organize traits
...@@ -4,13 +4,13 @@ ...@@ -4,13 +4,13 @@
#define DUNE_MATRIX_VECTOR_ALGORITHM_HH #define DUNE_MATRIX_VECTOR_ALGORITHM_HH
#include <dune/common/hybridutilities.hh> #include <dune/common/hybridutilities.hh>
#include <dune/matrix-vector/traitutilities.hh> #include <dune/matrix-vector/traits/utilities.hh>
namespace Dune { namespace Dune {
namespace MatrixVector { namespace MatrixVector {
//! \brief Hybrid for loop over sparse range (static/tuple-like candidate) //! \brief Hybrid for loop over sparse range (static/tuple-like candidate)
template <class Range, class F, std::enable_if_t<isTupleOrDerived<Range>(), int> = 0> template <class Range, class F, EnableTupleOrDerived<Range, int> = 0>
void sparseRangeFor(Range&& range, F&& f) { void sparseRangeFor(Range&& range, F&& f) {
using namespace Dune::Hybrid; using namespace Dune::Hybrid;
forEach(integralRange(size(range)), [&](auto&& i) { forEach(integralRange(size(range)), [&](auto&& i) {
...@@ -19,7 +19,7 @@ void sparseRangeFor(Range&& range, F&& f) { ...@@ -19,7 +19,7 @@ void sparseRangeFor(Range&& range, F&& f) {
} }
//! \brief Hybrid for loop over sparse range (dynamic/sparse candidate) //! \brief Hybrid for loop over sparse range (dynamic/sparse candidate)
template<class Range, class F, std::enable_if_t<not isTupleOrDerived<Range>(), int> = 0> template<class Range, class F, DisableTupleOrDerived<Range, int> = 0>
void sparseRangeFor(Range&& range, F&& f) void sparseRangeFor(Range&& range, F&& f)
{ {
auto it = range.begin(); auto it = range.begin();
...@@ -29,12 +29,12 @@ void sparseRangeFor(Range&& range, F&& f) ...@@ -29,12 +29,12 @@ void sparseRangeFor(Range&& range, F&& f)
} }
//! \brief Hybrid access to first sparse range element (static/tuple-like candiate) //! \brief Hybrid access to first sparse range element (static/tuple-like candiate)
template <class Range, class F, std::enable_if_t<isTupleOrDerived<Range>(), int> = 0> template <class Range, class F, EnableTupleOrDerived<Range, int> = 0>
void sparseRangeFirst(Range&& range, F&& f) { void sparseRangeFirst(Range&& range, F&& f) {
f(range[Indices::_0]); f(range[Indices::_0]);
} }
//! \brief Hybrid access to first sparse range element (dynamic/sparse candiate) //! \brief Hybrid access to first sparse range element (dynamic/sparse candiate)
template<class Range, class F, std::enable_if_t<not isTupleOrDerived<Range>(), int> = 0> template <class Range, class F, DisableTupleOrDerived<Range, int> = 0>
void sparseRangeFirst(Range&& range, F&& f) void sparseRangeFirst(Range&& range, F&& f)
{ {
f(*range.begin()); f(*range.begin());
......
...@@ -8,9 +8,8 @@ ...@@ -8,9 +8,8 @@
#include <dune/istl/bcrsmatrix.hh> #include <dune/istl/bcrsmatrix.hh>
#include <dune/istl/scaledidmatrix.hh> #include <dune/istl/scaledidmatrix.hh>
#include "algorithm.hh" #include <dune/matrix-vector/algorithm.hh>
#include "matrixtraits.hh" #include <dune/matrix-vector/traits/utilities.hh>
#include "scalartraits.hh"
namespace Dune { namespace Dune {
namespace MatrixVector { namespace MatrixVector {
...@@ -38,8 +37,8 @@ namespace MatrixVector { ...@@ -38,8 +37,8 @@ namespace MatrixVector {
*/ */
template <class A, class B, class C> template <class A, class B, class C>
void addProduct(A& a, const B& b, const C& c) { void addProduct(A& a, const B& b, const C& c) {
ProductHelper<A, B, C, ScalarTraits<A>::isScalar, ScalarTraits<B>::isScalar, ProductHelper<A, B, C, isScalar<A>(), isScalar<B>(),
ScalarTraits<C>::isScalar>::addProduct(a, b, c); isScalar<C>()>::addProduct(a, b, c);
} }
/** \brief Subtract a product from some matrix or vector /** \brief Subtract a product from some matrix or vector
...@@ -55,9 +54,8 @@ namespace MatrixVector { ...@@ -55,9 +54,8 @@ namespace MatrixVector {
*/ */
template <class A, class B, class C> template <class A, class B, class C>
void subtractProduct(A& a, const B& b, const C& c) { void subtractProduct(A& a, const B& b, const C& c) {
ScaledProductHelper<A, int, B, C, ScalarTraits<A>::isScalar, ScaledProductHelper<A, int, B, C, isScalar<A>(), isScalar<B>(),
ScalarTraits<B>::isScalar, isScalar<C>()>::addProduct(a, -1, b, c);
ScalarTraits<C>::isScalar>::addProduct(a, -1, b, c);
} }
/** \brief Add a scaled product to some matrix or vector /** \brief Add a scaled product to some matrix or vector
...@@ -72,11 +70,10 @@ namespace MatrixVector { ...@@ -72,11 +70,10 @@ namespace MatrixVector {
* 1-dim vector or a 1 by 1 matrix. * 1-dim vector or a 1 by 1 matrix.
*/ */
template <class A, class B, class C, class D> template <class A, class B, class C, class D>
typename std::enable_if_t<ScalarTraits<B>::isScalar, void> addProduct( EnableScalar<B> addProduct(
A& a, const B& b, const C& c, const D& d) { A& a, const B& b, const C& c, const D& d) {
ScaledProductHelper<A, B, C, D, ScalarTraits<A>::isScalar, ScaledProductHelper<A, B, C, D, isScalar<A>(), isScalar<C>(),
ScalarTraits<C>::isScalar, isScalar<D>()>::addProduct(a, b, c, d);
ScalarTraits<D>::isScalar>::addProduct(a, b, c, d);
} }
/** \brief Subtract a scaled product from some matrix or vector /** \brief Subtract a scaled product from some matrix or vector
...@@ -91,11 +88,10 @@ namespace MatrixVector { ...@@ -91,11 +88,10 @@ namespace MatrixVector {
* 1-dim vector or a 1 by 1 matrix. * 1-dim vector or a 1 by 1 matrix.
*/ */
template <class A, class B, class C, class D> template <class A, class B, class C, class D>
typename std::enable_if_t<ScalarTraits<B>::isScalar, void> EnableScalar<B>
subtractProduct(A& a, const B& b, const C& c, const D& d) { subtractProduct(A& a, const B& b, const C& c, const D& d) {
ScaledProductHelper<A, B, C, D, ScalarTraits<A>::isScalar, ScaledProductHelper<A, B, C, D, isScalar<A>(), isScalar<C>(),
ScalarTraits<C>::isScalar, isScalar<D>()>::addProduct(a, -b, c, d);
ScalarTraits<D>::isScalar>::addProduct(a, -b, c, d);
} }
/** \brief Internal helper class for product operations /** \brief Internal helper class for product operations
...@@ -104,16 +100,12 @@ namespace MatrixVector { ...@@ -104,16 +100,12 @@ namespace MatrixVector {
template <class A, class B, class C, bool AisScalar, bool BisScalar, template <class A, class B, class C, bool AisScalar, bool BisScalar,
bool CisScalar> bool CisScalar>
struct ProductHelper { struct ProductHelper {
template < template <class ADummy = A, DisableMatrix<ADummy, int> = 0>
class ADummy = A,
std::enable_if_t<!MatrixTraits<ADummy>::isMatrix, int> SFINAE_Dummy = 0>
static void addProduct(A& a, const B& b, const C& c) { static void addProduct(A& a, const B& b, const C& c) {
b.umv(c, a); b.umv(c, a);
} }
template < template <class ADummy = A, EnableMatrix<ADummy, int> = 0>
class ADummy = A,
std::enable_if_t<MatrixTraits<ADummy>::isMatrix, int> SFINAE_Dummy = 0>
static void addProduct(A& a, const B& b, const C& c) { static void addProduct(A& a, const B& b, const C& c) {
sparseRangeFor(b, [&](auto&& bi, auto&& i) { sparseRangeFor(b, [&](auto&& bi, auto&& i) {
sparseRangeFor(bi, [&](auto&& bik, auto&& k) { sparseRangeFor(bi, [&](auto&& bik, auto&& k) {
...@@ -130,16 +122,12 @@ namespace MatrixVector { ...@@ -130,16 +122,12 @@ namespace MatrixVector {
template <class A, class Scalar, class B, class C, bool AisScalar, template <class A, class Scalar, class B, class C, bool AisScalar,
bool BisScalar, bool CisScalar> bool BisScalar, bool CisScalar>
struct ScaledProductHelper { struct ScaledProductHelper {
template < template <class ADummy = A, DisableMatrix<ADummy, int> = 0>
class ADummy = A,
std::enable_if_t<!MatrixTraits<ADummy>::isMatrix, int> SFINAE_Dummy = 0>
static void addProduct(A& a, const Scalar& scalar, const B& b, const C& c) { static void addProduct(A& a, const Scalar& scalar, const B& b, const C& c) {
b.usmv(scalar, c, a); b.usmv(scalar, c, a);
} }
template < template <class ADummy = A, EnableMatrix<ADummy, int> = 0>
class ADummy = A,
std::enable_if_t<MatrixTraits<ADummy>::isMatrix, int> SFINAE_Dummy = 0>
static void addProduct(A& a, const Scalar& scalar, const B& b, const C& c) { static void addProduct(A& a, const Scalar& scalar, const B& b, const C& c) {
sparseRangeFor(b, [&](auto&& bi, auto&& i) { sparseRangeFor(b, [&](auto&& bi, auto&& i) {
sparseRangeFor(bi, [&](auto&& bik, auto&& k) { sparseRangeFor(bi, [&](auto&& bik, auto&& k) {
...@@ -317,16 +305,12 @@ namespace MatrixVector { ...@@ -317,16 +305,12 @@ namespace MatrixVector {
struct ProductHelper<A, ScalarB, C, AisScalar, true, CisScalar> { struct ProductHelper<A, ScalarB, C, AisScalar, true, CisScalar> {
typedef ScalarB B; typedef ScalarB B;
template < template <class ADummy = A, DisableMatrix<ADummy, int> = 0>
class ADummy = A,
std::enable_if_t<!MatrixTraits<ADummy>::isMatrix, int> SFINAE_Dummy = 0>
static void addProduct(A& a, const B& b, const C& c) { static void addProduct(A& a, const B& b, const C& c) {
a.axpy(b, c); a.axpy(b, c);
} }
template < template <class ADummy = A, EnableMatrix<ADummy, int> = 0>
class ADummy = A,
std::enable_if_t<MatrixTraits<ADummy>::isMatrix, int> SFINAE_Dummy = 0>
static void addProduct(A& a, const B& b, const C& c) { static void addProduct(A& a, const B& b, const C& c) {
sparseRangeFor(c, [&](auto&& ci, auto && i) { sparseRangeFor(c, [&](auto&& ci, auto && i) {
sparseRangeFor(ci, [&](auto&& cij, auto && j) { sparseRangeFor(ci, [&](auto&& cij, auto && j) {
...@@ -342,16 +326,12 @@ namespace MatrixVector { ...@@ -342,16 +326,12 @@ namespace MatrixVector {
CisScalar> { CisScalar> {
typedef ScalarB B; typedef ScalarB B;
template < template <class ADummy = A, DisableMatrix<ADummy, int> = 0>
class ADummy = A,
std::enable_if_t<!MatrixTraits<ADummy>::isMatrix, int> SFINAE_Dummy = 0>
static void addProduct(A& a, const Scalar& scalar, const B& b, const C& c) { static void addProduct(A& a, const Scalar& scalar, const B& b, const C& c) {
a.axpy(scalar * b, c); a.axpy(scalar * b, c);
} }
template < template <class ADummy = A, EnableMatrix<ADummy, int> = 0>
class ADummy = A,
std::enable_if_t<MatrixTraits<ADummy>::isMatrix, int> SFINAE_Dummy = 0>
static void addProduct(A& a, const Scalar& scalar, const B& b, const C& c) { static void addProduct(A& a, const Scalar& scalar, const B& b, const C& c) {
sparseRangeFor(c, [&](auto&& ci, auto&& i) { sparseRangeFor(c, [&](auto&& ci, auto&& i) {
sparseRangeFor(ci, [&](auto&& cij, auto&& j) { sparseRangeFor(ci, [&](auto&& cij, auto&& j) {
...@@ -398,6 +378,7 @@ namespace MatrixVector { ...@@ -398,6 +378,7 @@ namespace MatrixVector {
} }
}; };
} } // end namespace MatrixVector
} } // end namespace Dune
#endif
#endif // DUNE_MATRIX_VECTOR_AXPY_HH
...@@ -3,17 +3,21 @@ ...@@ -3,17 +3,21 @@
#include <cassert> #include <cassert>
#include "axpy.hh" #include <dune/matrix-vector/algorithm.hh>
#include "matrixtraits.hh" #include <dune/matrix-vector/axpy.hh>
#include "algorithm.hh" #include <dune/matrix-vector/traits/utilities.hh>
namespace Dune { namespace Dune {
namespace MatrixVector { namespace MatrixVector {
/** \brief Internal helper class for Matrix operations /** \brief Internal helper class for Matrix operations
* *
*/ */
template <class OperatorType, bool isMatrix> template <class T, typename Enable = void>
struct OperatorHelper { struct OperatorHelper;
template <class OperatorType>
struct OperatorHelper<OperatorType, DisableMatrix<OperatorType>> {
template <class VectorType, class VectorType2> template <class VectorType, class VectorType2>
static typename VectorType::field_type Axy(const OperatorType &A, static typename VectorType::field_type Axy(const OperatorType &A,
const VectorType &x, const VectorType &x,
...@@ -36,7 +40,7 @@ namespace MatrixVector { ...@@ -36,7 +40,7 @@ namespace MatrixVector {
}; };
template <class MatrixType> template <class MatrixType>
struct OperatorHelper<MatrixType, true> { struct OperatorHelper<MatrixType, EnableMatrix<MatrixType>> {
template <class VectorType, class VectorType2> template <class VectorType, class VectorType2>
static typename VectorType::field_type Axy(const MatrixType &A, static typename VectorType::field_type Axy(const MatrixType &A,
const VectorType &x, const VectorType &x,
...@@ -90,8 +94,7 @@ namespace MatrixVector { ...@@ -90,8 +94,7 @@ namespace MatrixVector {
typename VectorType::field_type Axy(const OperatorType &A, typename VectorType::field_type Axy(const OperatorType &A,
const VectorType &x, const VectorType &x,
const VectorType2 &y) { const VectorType2 &y) {
return OperatorHelper<OperatorType, return OperatorHelper<OperatorType>::Axy(A, x, y);
MatrixTraits<OperatorType>::isMatrix>::Axy(A, x, y);
} }
//! Compute \f$(b-Ax,y)\f$ //! Compute \f$(b-Ax,y)\f$
...@@ -100,11 +103,10 @@ namespace MatrixVector { ...@@ -100,11 +103,10 @@ namespace MatrixVector {
const VectorType2 &b, const VectorType2 &b,
const VectorType &x, const VectorType &x,
const VectorType2 &y) { const VectorType2 &y) {
return OperatorHelper<OperatorType, return OperatorHelper<OperatorType>::bmAxy(A, b, x, y);
MatrixTraits<OperatorType>::isMatrix>::bmAxy(A, b, x,
y);
}
}
} }
#endif } // end namespace MatrixVector
} // end namespace Dune
#endif // DUNE_MATRIX_VECTOR_AXY_HH
...@@ -9,14 +9,11 @@ ...@@ -9,14 +9,11 @@
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <dune/common/classname.hh>
#include <dune/common/concept.hh>
#include <dune/common/fvector.hh> #include <dune/common/fvector.hh>
#include <dune/common/hybridutilities.hh> #include <dune/common/hybridutilities.hh>
#include <dune/common/typetraits.hh>
#include <dune/matrix-vector/algorithm.hh> #include <dune/matrix-vector/algorithm.hh>
#include <dune/matrix-vector/matrixtraits.hh> #include <dune/matrix-vector/traits/utilities.hh>
//! \brief Various tools for working with istl vectors of arbitrary nesting depth //! \brief Various tools for working with istl vectors of arbitrary nesting depth
namespace Dune { namespace Dune {
...@@ -51,8 +48,7 @@ namespace Impl { ...@@ -51,8 +48,7 @@ namespace Impl {
//! recursion helper for scalars nested in iterables (vectors) //! recursion helper for scalars nested in iterables (vectors)
template <class Vector> template <class Vector>
struct ScalarSwitch<Vector, struct ScalarSwitch<Vector, DisableNumber<Vector>> {
typename std::enable_if_t<not IsNumber<Vector>::value>> {
static void writeBinary(std::ostream& s, const Vector& v) { static void writeBinary(std::ostream& s, const Vector& v) {
Hybrid::forEach(v, [&s](auto&& vi) { Generic::writeBinary(s, vi); }); Hybrid::forEach(v, [&s](auto&& vi) { Generic::writeBinary(s, vi); });
...@@ -72,8 +68,7 @@ struct ScalarSwitch<Vector, ...@@ -72,8 +68,7 @@ struct ScalarSwitch<Vector,
//! recursion anchors for scalars //! recursion anchors for scalars
template <class Scalar> template <class Scalar>
struct ScalarSwitch<Scalar, struct ScalarSwitch<Scalar, EnableNumber<Scalar>> {
typename std::enable_if_t<IsNumber<Scalar>::value>> {
static void writeBinary(std::ostream& s, const Scalar& v) { static void writeBinary(std::ostream& s, const Scalar& v) {
s.write(reinterpret_cast<const char*>(&v), sizeof(Scalar)); s.write(reinterpret_cast<const char*>(&v), sizeof(Scalar));
......
...@@ -12,8 +12,7 @@ ...@@ -12,8 +12,7 @@
#include <dune/matrix-vector/algorithm.hh> #include <dune/matrix-vector/algorithm.hh>
#include <dune/matrix-vector/concepts.hh> #include <dune/matrix-vector/concepts.hh>
#include <dune/matrix-vector/matrixtraits.hh> #include <dune/matrix-vector/traits/utilities.hh>
#include <dune/matrix-vector/traitutilities.hh>
namespace Dune { namespace Dune {
namespace MatrixVector { namespace MatrixVector {
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
#include <dune/matrix-vector/algorithm.hh> #include <dune/matrix-vector/algorithm.hh>
#include <dune/matrix-vector/resize.hh> #include <dune/matrix-vector/resize.hh>
/// custom multitype types to allow for BlockVector nesting /// custom multitype types to allow for BlockVector nesting
template <class... Args> template <class... Args>
struct CustomMultiTypeBlockVector : public Dune::MultiTypeBlockVector<Args...> { struct CustomMultiTypeBlockVector : public Dune::MultiTypeBlockVector<Args...> {
...@@ -28,19 +27,19 @@ struct CustomMultiTypeBlockMatrix : public Dune::MultiTypeBlockMatrix<Args...> { ...@@ -28,19 +27,19 @@ struct CustomMultiTypeBlockMatrix : public Dune::MultiTypeBlockMatrix<Args...> {
constexpr static int blocklevel = 1; // fake value constexpr static int blocklevel = 1; // fake value
}; };
// inject matrix traits for CustomMultiTypeBlockMatrix // inject matrix traits for CustomMultiTypeBlockMatrix
namespace Dune { namespace MatrixVector { namespace Dune { namespace MatrixVector { namespace Traits {
template <class... Args> template <class... Args>
struct MatrixTraits<CustomMultiTypeBlockMatrix<Args...>> { struct MatrixTraits<CustomMultiTypeBlockMatrix<Args...>> {
constexpr static bool isMatrix = true; constexpr static bool isMatrix = true;
}; };
}} }}}
// inject vector identification trait for CustomMultiTypeBlockVector // inject vector identification trait for CustomMultiTypeBlockVector
namespace Dune { namespace MatrixVector { namespace Dune { namespace MatrixVector { namespace Traits {
template <class... Args> template <class... Args>
struct VectorTraits<CustomMultiTypeBlockVector<Args...>> { struct VectorTraits<CustomMultiTypeBlockVector<Args...>> {
constexpr static bool isVector = true; constexpr static bool isVector = true;
}; };
}} }}}
class ResizeTestSuite { class ResizeTestSuite {
......
...@@ -3,11 +3,12 @@ ...@@ -3,11 +3,12 @@
#include <dune/common/diagonalmatrix.hh> #include <dune/common/diagonalmatrix.hh>
#include <dune/common/fmatrix.hh> #include <dune/common/fmatrix.hh>
#include <dune/istl/matrixindexset.hh> #include <dune/istl/matrixindexset.hh>
#include <dune/istl/scaledidmatrix.hh> #include <dune/istl/scaledidmatrix.hh>
#include "axpy.hh" #include <dune/matrix-vector/axpy.hh>
#include "scalartraits.hh" #include <dune/matrix-vector/traits/utilities.hh>
namespace Dune { namespace Dune {
namespace MatrixVector { namespace MatrixVector {
...@@ -255,12 +256,8 @@ namespace MatrixVector { ...@@ -255,12 +256,8 @@ namespace MatrixVector {
const MatrixB& B, const TransformationMatrix2& T2) { const MatrixB& B, const TransformationMatrix2& T2) {
TransformMatrixHelper< TransformMatrixHelper<
MatrixA, TransformationMatrix1, MatrixB, TransformationMatrix2, MatrixA, TransformationMatrix1, MatrixB, TransformationMatrix2,
ScalarTraits<MatrixA>::isScalar, isScalar<MatrixA>(), isScalar<TransformationMatrix1>(), isScalar<MatrixB>(),
ScalarTraits<TransformationMatrix1>::isScalar, isScalar<TransformationMatrix2>()>::addTransformedMatrix(A, T1, B, T2);
ScalarTraits<MatrixB>::isScalar,
ScalarTraits<
TransformationMatrix2>::isScalar>::addTransformedMatrix(A, T1, B,
T2);
} }
template <class MatrixA, class TransformationMatrix1, class MatrixB, template <class MatrixA, class TransformationMatrix1, class MatrixB,
...@@ -270,12 +267,8 @@ namespace MatrixVector { ...@@ -270,12 +267,8 @@ namespace MatrixVector {
A = 0; A = 0;
TransformMatrixHelper< TransformMatrixHelper<
MatrixA, TransformationMatrix1, MatrixB, TransformationMatrix2, MatrixA, TransformationMatrix1, MatrixB, TransformationMatrix2,
ScalarTraits<MatrixA>::isScalar, isScalar<MatrixA>(), isScalar<TransformationMatrix1>(), isScalar<MatrixB>(),
ScalarTraits<TransformationMatrix1>::isScalar, isScalar<TransformationMatrix2>()>::addTransformedMatrix(A, T1, B, T2);
ScalarTraits<MatrixB>::isScalar,
ScalarTraits<
TransformationMatrix2>::isScalar>::addTransformedMatrix(A, T1, B,
T2);
} }
template <class MatrixBlockA, class TransformationMatrix1, class MatrixB, template <class MatrixBlockA, class TransformationMatrix1, class MatrixB,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment