Skip to content
Snippets Groups Projects
Commit b6129c2c authored by Lasse Hinrichsen's avatar Lasse Hinrichsen
Browse files

Introduce more general "applyEntrywise"

There is no need to restrict oneself to binary operators.
parent 58ce5c15
Branches
No related tags found
1 merge request!9[WIP] Add applyEntrywise method
...@@ -30,14 +30,14 @@ struct HasSize { ...@@ -30,14 +30,14 @@ struct HasSize {
auto require(C&& c) -> decltype(c.size()); auto require(C&& c) -> decltype(c.size());
}; };
struct IsLegalBinaryOperator { struct CanBeApplied {
template <class C, class D, class Op> template <class C, class Op, class... D>
auto require(C&& c, D&& d, Op&& op) -> decltype(op(c,d)); auto require(C&& c, Op&& op, D&&... d) -> decltype(c=op(d...));
}; };
template <class C, class D, class Op> template <class C, class Op, class... D>
constexpr bool isLegalBinaryOperator(C&& c, D&& d, Op&& op) { constexpr auto canBeApplied(C&& c, Op&& op, D&&... d) {
return models<IsLegalBinaryOperator, C, D, Op>(); return models<CanBeApplied, C, Op, D...>();
} }
} // end namespace Concept } // end namespace Concept
......
...@@ -45,20 +45,21 @@ void truncate(Vector& v, const BitVector& tr) { ...@@ -45,20 +45,21 @@ void truncate(Vector& v, const BitVector& tr) {
Impl::ScalarSwitch<Vector>::truncate(v, tr); Impl::ScalarSwitch<Vector>::truncate(v, tr);
} }
/**! For a given operator "op", compute x = op(x,y) ENTRYWISE.
/**! For a given operator "op", compute x = op(y...) ENTRYWISE.
* *
* If the op is not defined for the given types, * If the op is not defined for the given types,
* the operator is recursively applied to subblocks * the operator is recursively applied to subblocks
* until it can be used. * until it can be used.
*/ */
template<typename X, typename Y, typename BinaryOp> template<typename X, typename Op, typename... Y>
void inplaceBinary(X&& x, Y&& y, BinaryOp&& op) { void applyEntrywise(X&& x, Op&& op, Y&&... y) {
if constexpr(Concept::isLegalBinaryOperator(x,y,op)) { if constexpr(decltype(Concept::canBeApplied(x, op, y...))::value) {
x = op(x,y); x = op(y...);
} }
else { else {
for (std::size_t i = 0; i < x.size(); ++i) { for (std::size_t i = 0; i < x.size(); ++i) {
inplaceBinary(x[i], y[i], op); applyEntrywise(x[i], op, y[i]...);
} }
} }
} }
......
...@@ -36,10 +36,11 @@ TestSuite test_inplace_or() { ...@@ -36,10 +36,11 @@ TestSuite test_inplace_or() {
// xx and yy, because we explicitly check if the given input is // xx and yy, because we explicitly check if the given input is
// bool (or something convertible) and in that case apply // bool (or something convertible) and in that case apply
// the operator // the operator
auto my_or = [](bool xx, bool yy) { return xx || yy; }; auto my_or = [](bool xx, bool yy) -> bool
{ return xx || yy; };
auto inplace_or = [&](auto& x, const auto& y) { auto inplace_or = [&](auto& x, const auto& y) {
inplaceBinary(x,y, my_or); applyEntrywise(x, my_or, x, y);
}; };
// check bools (trivial case) // check bools (trivial case)
...@@ -108,13 +109,64 @@ TestSuite test_inplace_or() { ...@@ -108,13 +109,64 @@ TestSuite test_inplace_or() {
return suite; return suite;
} }
TestSuite test_plus() {
TestSuite suite;
using namespace Dune::MatrixVector::Generic;
// we mimic x = y+z
auto my_plus = std::plus<double>{};
auto inplace_plus = [&](auto& x, const auto& y, const auto& z) { applyEntrywise(x, my_plus, y, z); };
// test double
{
// both values should be exact in binary representation
auto x = 0;
auto y = 2.5;
auto z = 1.5;
inplace_plus(x, y, z); // x = y + z
suite.check(x== 4., "inplace_plus failed for double");
}
// test vector<double>
{
auto x = std::vector<double>(10);
auto y = std::vector<double>(10, 2.5);
auto z = std::vector<double>(10, 1.5);
inplace_plus(x, y, z); // x = y + z
for (const auto& entry : x)
suite.check(entry == 4.0, "inplace_plus failed for vector<double> entry.");
}
// test BlockVector<FieldVector<double, 2>>
{
auto x = BlockVector<FieldVector<double, 2>>(10);
auto y = BlockVector<FieldVector<double, 2>>(10);
auto z = BlockVector<FieldVector<double, 2>>(10);
y = 2.5;
z = 1.5;
inplace_plus(x, y, z); // x = y + z
for (const auto& block : x)
for (const auto& entry : block)
suite.check(entry == 4.0, "inplace_plus failed for BlockVector<FieldVector<2>> entry.");
}
return suite;
}
TestSuite test_inplace_plus() { TestSuite test_inplace_plus() {
TestSuite suite; TestSuite suite;
using namespace Dune::MatrixVector::Generic; using namespace Dune::MatrixVector::Generic;
// we mimic x+=y
auto my_plus = std::plus<double>{}; auto my_plus = std::plus<double>{};
auto inplace_plus = [&](auto& x, const auto& y) { inplaceBinary(x, y, my_plus); }; auto inplace_plus = [&](auto& x, const auto& y) { applyEntrywise(x, my_plus, x, y); };
// test double // test double
{ {
...@@ -153,12 +205,12 @@ TestSuite test_inplace_plus() { ...@@ -153,12 +205,12 @@ TestSuite test_inplace_plus() {
return suite; return suite;
} }
TestSuite test_inplaceBinary() { TestSuite test_applyEntrywise() {
TestSuite suite; TestSuite suite;
// we test two example operators, the || and the + operator.
suite.subTest(test_inplace_or()); suite.subTest(test_inplace_or());
suite.subTest(test_inplace_plus()); suite.subTest(test_inplace_plus());
suite.subTest(test_plus());
return suite; return suite;
} }
...@@ -171,7 +223,7 @@ int main(int argc, char* argv[]) { ...@@ -171,7 +223,7 @@ int main(int argc, char* argv[]) {
// TODO: There are more things in genericvectortools // TODO: There are more things in genericvectortools
// that need tests! // that need tests!
suite.subTest(test_inplaceBinary()); suite.subTest(test_applyEntrywise());
return suite.exit();; return suite.exit();;
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment