From 58ce5c15ed73a2b0e973b29343ee8dd3d7830155 Mon Sep 17 00:00:00 2001
From: Lasse Hinrichsen <l.hinrichsen@fu-berlin.de>
Date: Mon, 6 Jul 2020 14:59:35 +0200
Subject: [PATCH] Add inplaceBinary method

With this tools, one can write quite general code to apply
x = op(x,y)
entrywise for vectors x,y.
---
 dune/matrix-vector/concepts.hh                |  10 +
 dune/matrix-vector/genericvectortools.hh      |  20 ++
 dune/matrix-vector/test/CMakeLists.txt        |   1 +
 .../test/genericvectortoolstest.cc            | 177 ++++++++++++++++++
 4 files changed, 208 insertions(+)
 create mode 100644 dune/matrix-vector/test/genericvectortoolstest.cc

diff --git a/dune/matrix-vector/concepts.hh b/dune/matrix-vector/concepts.hh
index c854388..79293a9 100644
--- a/dune/matrix-vector/concepts.hh
+++ b/dune/matrix-vector/concepts.hh
@@ -30,6 +30,16 @@ struct HasSize {
   auto require(C&& c) -> decltype(c.size());
 };
 
+struct IsLegalBinaryOperator {
+  template <class C, class D, class Op>
+  auto require(C&& c, D&& d, Op&& op) -> decltype(op(c,d));
+};
+
+template <class C, class D, class Op>
+constexpr bool isLegalBinaryOperator(C&& c, D&& d, Op&& op) {
+  return models<IsLegalBinaryOperator, C, D, Op>();
+}
+
 } // end namespace Concept
 } // end namespace MatrixVector
 } // end namespace Dune
diff --git a/dune/matrix-vector/genericvectortools.hh b/dune/matrix-vector/genericvectortools.hh
index 067835d..dd5a61c 100644
--- a/dune/matrix-vector/genericvectortools.hh
+++ b/dune/matrix-vector/genericvectortools.hh
@@ -13,6 +13,7 @@
 #include <dune/common/hybridutilities.hh>
 
 #include <dune/matrix-vector/algorithm.hh>
+#include <dune/matrix-vector/concepts.hh>
 #include <dune/matrix-vector/traits/utilities.hh>
 
 //! \brief Various tools for working with istl vectors of arbitrary nesting depth
@@ -44,6 +45,25 @@ void truncate(Vector& v, const BitVector& tr) {
   Impl::ScalarSwitch<Vector>::truncate(v, tr);
 }
 
+/**! For a given operator "op", compute x = op(x,y) ENTRYWISE.
+ *
+ * If the op is not defined for the given types,
+ * the operator is recursively applied to subblocks
+ * until it can be used.
+ */
+template<typename X, typename Y, typename BinaryOp>
+void inplaceBinary(X&&  x, Y&& y, BinaryOp&& op) {
+  if constexpr(Concept::isLegalBinaryOperator(x,y,op)) {
+    x = op(x,y);
+  }
+  else {
+    for (std::size_t i = 0; i < x.size(); ++i) {
+      inplaceBinary(x[i], y[i], op);
+    }
+  }
+}
+
+
 namespace Impl {
 
 //! recursion helper for scalars nested in iterables (vectors)
diff --git a/dune/matrix-vector/test/CMakeLists.txt b/dune/matrix-vector/test/CMakeLists.txt
index 4d2e82f..83ed102 100644
--- a/dune/matrix-vector/test/CMakeLists.txt
+++ b/dune/matrix-vector/test/CMakeLists.txt
@@ -1,4 +1,5 @@
 dune_add_test(SOURCES arithmetictest.cc)
+dune_add_test(SOURCES genericvectortoolstest.cc)
 dune_add_test(SOURCES resizetest.cc)
 dune_add_test(SOURCES staticmatrixtoolstest.cc)
 dune_add_test(SOURCES triangularsolvetest.cc)
diff --git a/dune/matrix-vector/test/genericvectortoolstest.cc b/dune/matrix-vector/test/genericvectortoolstest.cc
new file mode 100644
index 0000000..b113936
--- /dev/null
+++ b/dune/matrix-vector/test/genericvectortoolstest.cc
@@ -0,0 +1,177 @@
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include <cstddef>
+#include <vector>
+
+#include <dune/common/bitsetvector.hh>
+#include <dune/common/parallel/mpihelper.hh>
+#include <dune/common/test/testsuite.hh>
+#include <dune/common/fvector.hh>
+
+#include <dune/istl/bvector.hh>
+
+#include <dune/matrix-vector/genericvectortools.hh>
+
+using namespace Dune;
+
+//template<typename X, typename Y>
+//bool my_or(const X& x, const Y& y) {
+  //return x || y;
+//}
+
+TestSuite test_inplace_or() {
+  TestSuite suite;
+
+  using namespace Dune::MatrixVector::Generic;
+
+  // This would be the STL approach
+  //auto inplace_or = [](auto& x, const auto& y) {
+    //inplaceBinary(x,y, std::logical_or<void>());
+  //};
+
+  // For demonstration, let'ss cook something up.
+  // Note that we cant use "auto" for the types of
+  // xx and yy, because we explicitly check if the given input is
+  // bool (or something convertible) and in that case apply
+  // the operator
+  auto my_or = [](bool xx, bool yy) { return xx || yy; };
+
+  auto inplace_or = [&](auto& x, const auto& y) {
+    inplaceBinary(x,y, my_or);
+  };
+
+  // check bools (trivial case)
+  {
+    bool t = true;
+    bool f = false;
+
+    inplace_or(t,f); // t |= f
+    suite.check(t == true);
+
+    inplace_or(f,t); // f |= t
+    suite.check(f == true);
+
+    f=false;
+    inplace_or(f,f);
+    suite.check(f == false);
+
+    inplace_or(t,t);
+    suite.check(t == true);
+  }
+
+  // check BitSetVector
+  constexpr const int N = 3;
+  using BSV = BitSetVector<N>;
+  {
+    BSV x(2, false);
+    BSV y(2, true);
+
+    inplace_or(x,y);
+    for (std::size_t i = 0; i < x.size(); ++i) {
+      suite.check(x[i].all() == true);
+    }
+  }
+
+  // check vector<BSV>
+  {
+    auto x = std::vector<BSV>(4);
+    auto y = std::vector<BSV>(4);
+
+    for (auto&& xx : x)
+      xx = BSV(2, false);
+    for (auto&& yy : y)
+      yy = BSV(2, true);
+
+    inplace_or(x, y);
+
+    for (const auto& bsv : x) {
+      for (const auto& bs: bsv) {
+        suite.check(bs.all());
+      }
+    }
+
+  }
+
+  // mix types
+  {
+    auto x = std::vector<bool>(10, false);
+    auto y = std::vector<char>(10, true);
+
+    inplace_or(x,y);
+
+    for(const auto& xi : x)
+      suite.check(xi == true);
+  }
+
+  return suite;
+}
+
+TestSuite test_inplace_plus() {
+  TestSuite suite;
+
+  using namespace Dune::MatrixVector::Generic;
+
+  auto my_plus = std::plus<double>{};
+  auto inplace_plus = [&](auto& x, const auto& y) { inplaceBinary(x, y, my_plus); };
+
+  // test double
+  {
+    // both values should be exact in binary representation
+    auto x = 1.5;
+    auto y = 2.5;
+
+    inplace_plus(x,y); // x += y
+    suite.check(x== 4., "inplace_plus failed for double");
+  }
+
+  // test vector<double>
+  {
+    auto x = std::vector<double>(10, 1.5);
+    auto y = std::vector<double>(10, 2.5);
+
+    inplace_plus(x,y); // x += y
+    for (const auto& entry : x)
+      suite.check(entry == 4.0, "inplace_plus failed for vector<double> entry.");
+  }
+
+  // test BlockVector<FieldVector<double, 2>>
+  {
+    auto x = BlockVector<FieldVector<double, 2>>(10);
+    auto y = BlockVector<FieldVector<double, 2>>(10);
+
+    x = 1.5;
+    y = 2.5;
+
+    inplace_plus(x,y); // x += y
+    for (const auto& block : x)
+      for (const auto& entry : block)
+      suite.check(entry == 4.0, "inplace_plus failed for BlockVector<FieldVector<2>> entry.");
+  }
+
+
+  return suite;
+}
+TestSuite test_inplaceBinary() {
+  TestSuite suite;
+
+  // we test two example operators, the || and the + operator.
+  suite.subTest(test_inplace_or());
+  suite.subTest(test_inplace_plus());
+  return suite;
+}
+
+int main(int argc, char* argv[]) {
+  MPIHelper::instance(argc, argv);
+
+
+  TestSuite suite;
+
+  // TODO: There are more things in genericvectortools
+  // that need tests!
+
+  suite.subTest(test_inplaceBinary());
+
+  return suite.exit();;
+}
-- 
GitLab