diff --git a/dune/subgrid/common/variant.hh b/dune/subgrid/common/variant.hh index 6f57b8b265f7ebb838367e00686f62666e4029a0..e2132ea6c66de8266669eb87f8b9a44fe6143837 100644 --- a/dune/subgrid/common/variant.hh +++ b/dune/subgrid/common/variant.hh @@ -70,7 +70,6 @@ namespace Impl { template<typename Tp> struct TypeStorage_<Tp, false> { TypeStorage_(Tp t) { - //tp_ = ::new Tp(t); ::new (&tp_) Tp(t); } @@ -106,18 +105,6 @@ namespace Impl { constexpr variant_union_(std::integral_constant<size_t, N>, Args&&... args) : tail_(std::integral_constant<size_t, N-1>(), std::forward<Args...>(args)...) {} - // TODO: This should not be a copy! However, if I return by reference, - // compiler tells me it can not return an temporary as a non-const ref. (which is of course - // true, but I don't see why this is a temporary?). Then again, one does not need this - // function anyway. Probably one should just drop it. - template<typename Tp> - auto getByType() { - return Dune::Hybrid::ifElse(std::is_same<Tp, Head_>(), - [this](auto) { return this->head_.get();}, - [this](auto id) { return id(this->tail_).template getByType<Tp>();} - ); - } - auto& getByIndex(std::integral_constant<size_t, 0>) { return head_.get(); } @@ -171,21 +158,65 @@ namespace Impl { template<typename Tp> auto& get() { constexpr size_t idx = index_in_pack<Tp, T...>::value; + if (index_ != idx) + DUNE_THROW(Dune::Exception, "Bad variant access."); + return get<idx>(); } template<typename Tp> const auto& get() const { constexpr size_t idx = index_in_pack<Tp, T...>::value; + if (index_ != idx) + DUNE_THROW(Dune::Exception, "Bad variant access."); + return get<idx>(); } + template<typename Tp> + Tp* get_if() { + if (not holds_alternative<Tp>()) + return (Tp*) nullptr; + else + return &(get<Tp>()); + } + + template<typename Tp> + const Tp* get_if() const { + if (not holds_alternative<Tp>()) + return (Tp*) nullptr; + else + return &(get<Tp>()); + } + + template<size_t N> + auto* get_if() { + using Tp = std::decay_t<decltype(get<N>())>; + if (not holds_alternative<N>()) + return (Tp*) nullptr; + else + return &(get<Tp>()); + } + + template<size_t N> + const auto* get_if() const { + using Tp = std::decay_t<decltype(get<N>())>; + if (not holds_alternative<N>()) + return (Tp*) nullptr; + else + return &(get<Tp>()); + } + template<size_t N> auto& get() { + if (index_ != N) + DUNE_THROW(Dune::Exception, "Bad variant access."); return unions_.template getByIndex(std::integral_constant<size_t, N>()); } template<size_t N> const auto& get() const { + if (index_ != N) + DUNE_THROW(Dune::Exception, "Bad variant access."); return unions_.template getByIndex(std::integral_constant<size_t, N>()); } @@ -197,7 +228,7 @@ namespace Impl { return unions_.getByIndex(std::integral_constant<size_t,index>()); } - constexpr std::size_t index() const { + constexpr std::size_t index() const noexcept { return index_; } @@ -271,6 +302,14 @@ namespace Impl { return (index_in_pack<Tp, T...>::value == index_); } + /** \brief Check if a given type is the one that is currently active in the variant. */ + template<size_t N> + constexpr bool holds_alternative() const { + // I have no idea how this could be really constexpr, but for STL-conformity, + // I'll leave the modifier there. + return (N == index_); + } + private: variant_union_<T...> unions_; std::size_t index_; @@ -313,6 +352,26 @@ namespace Impl { return var.template get<Tp>(); } + template<typename Tp, typename ...T> + const auto* get_if(const variant<T...>& var) { + return var.template get_if<Tp>(); + } + + template<typename Tp, typename ...T> + auto* get_if(variant<T...>& var) { + return var.template get_if<Tp>(); + } + + template<size_t N, typename ...T> + const auto* get_if(const variant<T...>& var) { + return var.template get_if<N>(); + } + + template<size_t N, typename ...T> + auto* get_if(variant<T...>& var) { + return var.template get_if<N>(); + } + template<typename Tp, typename ...T> constexpr bool holds_alternative(const variant<T...>& var) { return var.template holds_alternative<Tp>(); diff --git a/dune/subgrid/test/testvariant.cc b/dune/subgrid/test/testvariant.cc index d5a7d077ac3f29958c3fee7477a0bc94a09d43da..a05dc440d3db85d16093bfb1ee6e98848090557d 100644 --- a/dune/subgrid/test/testvariant.cc +++ b/dune/subgrid/test/testvariant.cc @@ -9,7 +9,7 @@ #include <dune/common/exceptions.hh> // We use exceptions #include <dune/common/test/testsuite.hh> -#include <dune/subgrid/common/variant.hh> +#include <dune/common/std/variant.hh> // some non-default constructible type struct F { @@ -17,11 +17,6 @@ struct F { F() = delete; F(int j) : i(j) {} - - F& operator *=(int factor) { - i*=factor; - return *this; - } }; Dune::TestSuite testVariant() { @@ -35,41 +30,58 @@ Dune::TestSuite testVariant() { auto variant = Std::variant<int, double, F, V>(); - suite.check(Std::variant_size_v(variant) == 4); + suite.check(Std::variant_size_v(variant) == 4, "Test variant_size_v"); variant = d; - suite.check(Std::holds_alternative<double>(variant)); + suite.check(Std::holds_alternative<double>(variant), "Test holds_alternative"); variant = f; - suite.check(Std::holds_alternative<F>(variant)); + suite.check(Std::holds_alternative<F>(variant), "Test holds_alternative"); variant = i; - suite.check(Std::holds_alternative<int>(variant)); - // TODO: actual compare operators - suite.check(Std::get<int>(variant) == i); - suite.check(Std::get<0>(variant) == i); + suite.check(Std::holds_alternative<int>(variant), "Test holds_alternative"); + + suite.check(Std::get<int>(variant) == i, "Test get<Type>"); + suite.check(Std::get<0>(variant) == i, "Test get<Index>"); + + suite.check(Std::get_if<int>(variant) != nullptr, "Test get_if on right type"); + suite.check(Std::get_if<double>(variant) == nullptr, "Test get_if on wrong type"); + + suite.check(Std::get_if<0>(variant) != nullptr, "Test get_if on right index"); + suite.check(Std::get_if<1>(variant) == nullptr, "Test get_if on wrong index"); + + // test if get<Type> throws if one tries to get the wrong type: + try { + // currently hold type is still int, so double should throw + Std::get<double>(variant); + suite.check(false, "Test get<Type> on wrong type should have thrown"); + } + catch (...) { + suite.check(true, "Test get<Type> on wrong type has thrown"); + } variant = V(1); - suite.check(Std::get<V>(variant).size() == 1); + suite.check(Std::get<V>(variant).size() == 1, "Test with non-trivial type"); variant = f; - suite.check(variant.index() == 2); // we're at type F, which has position 2 + suite.check(variant.index() == 2, "Test index()"); // we're at type F, which has position 2 - // Demonstrate visit concept: + // Demonstrate visit concept and using vector as an example of a non-POD type using V2 = std::vector<double>; Std::variant<V, V2> variant2; variant2 = V(1); auto size = [](auto&& v) {return v.size();}; - suite.check(Std::visit(size, variant2)== 1); + suite.check(Std::visit(size, variant2)== 1, "Test visit"); variant2 = V2(2); - suite.check(Std::visit(size, variant2)== 2); + suite.check(Std::visit(size, variant2)== 2, "Test visit"); // try on a const reference: const auto& constv2 = variant2; - suite.check(Std::visit(size, constv2)== 2); + suite.check(Std::visit(size, constv2)== 2, "Test const visit"); + suite.check(Std::get_if<V2>(constv2) != nullptr, "Test const get_if"); return suite;