From 58ce5c15ed73a2b0e973b29343ee8dd3d7830155 Mon Sep 17 00:00:00 2001 From: Lasse Hinrichsen <l.hinrichsen@fu-berlin.de> Date: Mon, 6 Jul 2020 14:59:35 +0200 Subject: [PATCH] Add inplaceBinary method With this tools, one can write quite general code to apply x = op(x,y) entrywise for vectors x,y. --- dune/matrix-vector/concepts.hh | 10 + dune/matrix-vector/genericvectortools.hh | 20 ++ dune/matrix-vector/test/CMakeLists.txt | 1 + .../test/genericvectortoolstest.cc | 177 ++++++++++++++++++ 4 files changed, 208 insertions(+) create mode 100644 dune/matrix-vector/test/genericvectortoolstest.cc diff --git a/dune/matrix-vector/concepts.hh b/dune/matrix-vector/concepts.hh index c854388..79293a9 100644 --- a/dune/matrix-vector/concepts.hh +++ b/dune/matrix-vector/concepts.hh @@ -30,6 +30,16 @@ struct HasSize { auto require(C&& c) -> decltype(c.size()); }; +struct IsLegalBinaryOperator { + template <class C, class D, class Op> + auto require(C&& c, D&& d, Op&& op) -> decltype(op(c,d)); +}; + +template <class C, class D, class Op> +constexpr bool isLegalBinaryOperator(C&& c, D&& d, Op&& op) { + return models<IsLegalBinaryOperator, C, D, Op>(); +} + } // end namespace Concept } // end namespace MatrixVector } // end namespace Dune diff --git a/dune/matrix-vector/genericvectortools.hh b/dune/matrix-vector/genericvectortools.hh index 067835d..dd5a61c 100644 --- a/dune/matrix-vector/genericvectortools.hh +++ b/dune/matrix-vector/genericvectortools.hh @@ -13,6 +13,7 @@ #include <dune/common/hybridutilities.hh> #include <dune/matrix-vector/algorithm.hh> +#include <dune/matrix-vector/concepts.hh> #include <dune/matrix-vector/traits/utilities.hh> //! \brief Various tools for working with istl vectors of arbitrary nesting depth @@ -44,6 +45,25 @@ void truncate(Vector& v, const BitVector& tr) { Impl::ScalarSwitch<Vector>::truncate(v, tr); } +/**! For a given operator "op", compute x = op(x,y) ENTRYWISE. + * + * If the op is not defined for the given types, + * the operator is recursively applied to subblocks + * until it can be used. + */ +template<typename X, typename Y, typename BinaryOp> +void inplaceBinary(X&& x, Y&& y, BinaryOp&& op) { + if constexpr(Concept::isLegalBinaryOperator(x,y,op)) { + x = op(x,y); + } + else { + for (std::size_t i = 0; i < x.size(); ++i) { + inplaceBinary(x[i], y[i], op); + } + } +} + + namespace Impl { //! recursion helper for scalars nested in iterables (vectors) diff --git a/dune/matrix-vector/test/CMakeLists.txt b/dune/matrix-vector/test/CMakeLists.txt index 4d2e82f..83ed102 100644 --- a/dune/matrix-vector/test/CMakeLists.txt +++ b/dune/matrix-vector/test/CMakeLists.txt @@ -1,4 +1,5 @@ dune_add_test(SOURCES arithmetictest.cc) +dune_add_test(SOURCES genericvectortoolstest.cc) dune_add_test(SOURCES resizetest.cc) dune_add_test(SOURCES staticmatrixtoolstest.cc) dune_add_test(SOURCES triangularsolvetest.cc) diff --git a/dune/matrix-vector/test/genericvectortoolstest.cc b/dune/matrix-vector/test/genericvectortoolstest.cc new file mode 100644 index 0000000..b113936 --- /dev/null +++ b/dune/matrix-vector/test/genericvectortoolstest.cc @@ -0,0 +1,177 @@ +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#include <cstddef> +#include <vector> + +#include <dune/common/bitsetvector.hh> +#include <dune/common/parallel/mpihelper.hh> +#include <dune/common/test/testsuite.hh> +#include <dune/common/fvector.hh> + +#include <dune/istl/bvector.hh> + +#include <dune/matrix-vector/genericvectortools.hh> + +using namespace Dune; + +//template<typename X, typename Y> +//bool my_or(const X& x, const Y& y) { + //return x || y; +//} + +TestSuite test_inplace_or() { + TestSuite suite; + + using namespace Dune::MatrixVector::Generic; + + // This would be the STL approach + //auto inplace_or = [](auto& x, const auto& y) { + //inplaceBinary(x,y, std::logical_or<void>()); + //}; + + // For demonstration, let'ss cook something up. + // Note that we cant use "auto" for the types of + // xx and yy, because we explicitly check if the given input is + // bool (or something convertible) and in that case apply + // the operator + auto my_or = [](bool xx, bool yy) { return xx || yy; }; + + auto inplace_or = [&](auto& x, const auto& y) { + inplaceBinary(x,y, my_or); + }; + + // check bools (trivial case) + { + bool t = true; + bool f = false; + + inplace_or(t,f); // t |= f + suite.check(t == true); + + inplace_or(f,t); // f |= t + suite.check(f == true); + + f=false; + inplace_or(f,f); + suite.check(f == false); + + inplace_or(t,t); + suite.check(t == true); + } + + // check BitSetVector + constexpr const int N = 3; + using BSV = BitSetVector<N>; + { + BSV x(2, false); + BSV y(2, true); + + inplace_or(x,y); + for (std::size_t i = 0; i < x.size(); ++i) { + suite.check(x[i].all() == true); + } + } + + // check vector<BSV> + { + auto x = std::vector<BSV>(4); + auto y = std::vector<BSV>(4); + + for (auto&& xx : x) + xx = BSV(2, false); + for (auto&& yy : y) + yy = BSV(2, true); + + inplace_or(x, y); + + for (const auto& bsv : x) { + for (const auto& bs: bsv) { + suite.check(bs.all()); + } + } + + } + + // mix types + { + auto x = std::vector<bool>(10, false); + auto y = std::vector<char>(10, true); + + inplace_or(x,y); + + for(const auto& xi : x) + suite.check(xi == true); + } + + return suite; +} + +TestSuite test_inplace_plus() { + TestSuite suite; + + using namespace Dune::MatrixVector::Generic; + + auto my_plus = std::plus<double>{}; + auto inplace_plus = [&](auto& x, const auto& y) { inplaceBinary(x, y, my_plus); }; + + // test double + { + // both values should be exact in binary representation + auto x = 1.5; + auto y = 2.5; + + inplace_plus(x,y); // x += y + suite.check(x== 4., "inplace_plus failed for double"); + } + + // test vector<double> + { + auto x = std::vector<double>(10, 1.5); + auto y = std::vector<double>(10, 2.5); + + inplace_plus(x,y); // x += y + for (const auto& entry : x) + suite.check(entry == 4.0, "inplace_plus failed for vector<double> entry."); + } + + // test BlockVector<FieldVector<double, 2>> + { + auto x = BlockVector<FieldVector<double, 2>>(10); + auto y = BlockVector<FieldVector<double, 2>>(10); + + x = 1.5; + y = 2.5; + + inplace_plus(x,y); // x += y + for (const auto& block : x) + for (const auto& entry : block) + suite.check(entry == 4.0, "inplace_plus failed for BlockVector<FieldVector<2>> entry."); + } + + + return suite; +} +TestSuite test_inplaceBinary() { + TestSuite suite; + + // we test two example operators, the || and the + operator. + suite.subTest(test_inplace_or()); + suite.subTest(test_inplace_plus()); + return suite; +} + +int main(int argc, char* argv[]) { + MPIHelper::instance(argc, argv); + + + TestSuite suite; + + // TODO: There are more things in genericvectortools + // that need tests! + + suite.subTest(test_inplaceBinary()); + + return suite.exit();; +} -- GitLab