diff --git a/dune/solvers/iterationsteps/blockgssteps.hh b/dune/solvers/iterationsteps/blockgssteps.hh index 86ff72509e30f2fbcd353ba45471d0b634071a65..76cae5f91eeab69d4c6ebf86a4327d472227cfe5 100644 --- a/dune/solvers/iterationsteps/blockgssteps.hh +++ b/dune/solvers/iterationsteps/blockgssteps.hh @@ -240,70 +240,108 @@ auto diagRegularize(double p, LocalSolver&& localSolver) { //! \brief is @LinearSolvers::direct with ignore nodes. //! \param r Regularization parameter. Set to 0.0 to switch off regularization. -template <typename dummy = void> -auto direct(double r = LocalSolverRegularizer::defaultDiagRegularizeParameter) { - return [r](const auto& m, const auto& b, const auto& ignore) { - auto directSolver = LinearSolvers::direct<std::decay_t<decltype(m)>, - std::decay_t<decltype(b)>>; +struct Direct { + Direct(double r = LocalSolverRegularizer::defaultDiagRegularizeParameter) + : r_(r) {} + + template <class M, class V, class I> + V operator()(const M& m, const V& b, const I& ignore) const { + auto directSolver = LinearSolvers::direct<M, V>; auto directTruncatedSolver = LocalSolverFromLinearSolver::truncateSymmetrically(directSolver); auto directTruncatedRegularizedSolver = - LocalSolverRegularizer::diagRegularize(r, directTruncatedSolver); + LocalSolverRegularizer::diagRegularize(r_, directTruncatedSolver); return directTruncatedRegularizedSolver(m, b, ignore); - }; + } + +private: + double r_; +}; + +template <class... Args> +auto direct(Args&&... args) { + return Direct(std::forward<Args>(args)...); } //! \brief is @LinearSolvers::ldlt with ignore nodes. -template <typename dummy = void> -auto ldlt(double r = LocalSolverRegularizer::defaultDiagRegularizeParameter) { - return [r](const auto& m, const auto& b, const auto& ignore) { - auto ldltSolver = LinearSolvers::ldlt<std::decay_t<decltype(m)>, - std::decay_t<decltype(b)>>; +struct LDLt { + LDLt(double r = LocalSolverRegularizer::defaultDiagRegularizeParameter) + : r_(r) {} + + template <class M, class V, class I> + V operator()(const M& m, const V& b, const I& ignore) const { + auto ldltSolver = LinearSolvers::ldlt<M, V>; auto ldltTruncatedSolver = LocalSolverFromLinearSolver::truncateSymmetrically(ldltSolver); auto ldltTruncatedRegularizedSolver = - LocalSolverRegularizer::diagRegularize(r, ldltTruncatedSolver); + LocalSolverRegularizer::diagRegularize(r_, ldltTruncatedSolver); return ldltTruncatedRegularizedSolver(m, b, ignore); - }; + } + +private: + double r_; +}; + +template <class... Args> +auto ldlt(Args&&... args) { + return LDLt(std::forward<Args>(args)...); } //! \brief is @LinearSolvers::cg with ignore nodes. -template <typename dummy = void> -auto cg(size_t maxIter = LinearSolvers::defaultCgMaxIter, - double tol = LinearSolvers::defaultCgTol, - double r = LocalSolverRegularizer::defaultDiagRegularizeParameter) { - return [maxIter, tol, r](const auto& m, const auto& b, const auto& ignore) { +struct CG { + CG(size_t maxIter = LinearSolvers::defaultCgMaxIter, + double tol = LinearSolvers::defaultCgTol, + double r = LocalSolverRegularizer::defaultDiagRegularizeParameter) + : maxIter_(maxIter) + , tol_(tol) + , r_(r) {} + template <class M, class V, class I> + V operator()(const M& m, const V& b, const I& ignore) const { using namespace std::placeholders; - auto cgSolver = std::bind( - LinearSolvers::cg<std::decay_t<decltype(m)>, std::decay_t<decltype(b)>>, - _1, _2, maxIter, tol); - auto cgTruncatedSolver = LocalSolverFromLinearSolver::truncateSymmetrically(cgSolver); + auto cgSolver = std::bind(LinearSolvers::cg<M, V>, _1, _2, maxIter_, tol_); + auto cgTruncatedSolver = + LocalSolverFromLinearSolver::truncateSymmetrically(cgSolver); auto cgTruncatedRegularizedSolver = - LocalSolverRegularizer::diagRegularize(r, cgTruncatedSolver); + LocalSolverRegularizer::diagRegularize(r_, cgTruncatedSolver); return cgTruncatedRegularizedSolver(m, b, ignore); - }; + } + +private: + size_t maxIter_; + double tol_; + double r_; +}; + +template <class... Args> +auto cg(Args&&... args) { + return CG(std::forward<Args>(args)...); } -/** - * \brief is @LinearSolvers::gs with ignore nodes. - * \param tol Tolerance value for skipping potentially zero diagonal entries - * after regularization. - * \param r Regularization parameter. Set to 0.0 to switch off regularization. - */ -template <typename dummy = void> -auto gs(double tol = LinearSolvers::defaultGsTol, - double r = LocalSolverRegularizer::defaultDiagRegularizeParameter) { - return [tol, r](const auto& m, const auto& b, const auto& ignore) { +struct GS { + GS(double tol = LinearSolvers::defaultGsTol, + double r = LocalSolverRegularizer::defaultDiagRegularizeParameter) + : tol_(tol) + , r_(r) {} + + template <class M, class V, class I> + V operator()(const M& m, const V& b, const I& ignore) const { using namespace std::placeholders; - auto gsSolver = std::bind( - LinearSolvers::gs<std::decay_t<decltype(m)>, std::decay_t<decltype(b)>>, - _1, _2, tol); + auto gsSolver = std::bind(LinearSolvers::gs<M, V>, _1, _2, tol_); auto gsTruncatedSolver = LocalSolverFromLinearSolver::truncateSymmetrically(gsSolver); auto gsTruncatedRegularizedSolver = - LocalSolverRegularizer::diagRegularize(r, gsTruncatedSolver); + LocalSolverRegularizer::diagRegularize(r_, gsTruncatedSolver); return gsTruncatedRegularizedSolver(m, b, ignore); - }; + } + +private: + double tol_; + double r_; +}; + +template <class... Args> +auto gs(Args&&... args) { + return GS(std::forward<Args>(args)...); } } // end namespace LocalSolvers @@ -318,12 +356,25 @@ template <class LocalSolver, class Matrix, class Vector, class BitVector = Dune::Solvers::DefaultBitVector_t<Vector>> struct BlockGSStep : public LinearIterationStep<Matrix, Vector, BitVector> { - template<class LS> + //! \brief Implicitly default-construct local solver + BlockGSStep(BlockGS::Direction direction = BlockGS::Direction::FORWARD) + : localSolver_() + , direction_(direction) {} + + //! \brief Use given local solver instance + template <class LS> BlockGSStep(LS&& localSolver, BlockGS::Direction direction = BlockGS::Direction::FORWARD) : localSolver_(std::forward<LS>(localSolver)) , direction_(direction) {} + //! Implicitly construct local solver with forwarded arguments + //! \note Gauss--Seidel direction is hardwired to 'forward' here! + template <class... Args> + BlockGSStep(Args&&... args) + : localSolver_(std::forward<Args>(args)...) + , direction_(BlockGS::Direction::FORWARD) {} + void iterate() { assert(this->mat_ != nullptr); assert(this->x_ != nullptr); diff --git a/dune/solvers/test/gssteptest.cc b/dune/solvers/test/gssteptest.cc index fa283ba306bdedc77188719afc5c2616b4b06054..2296fa9d257afdedf08e13562a8aa63c13be629b 100644 --- a/dune/solvers/test/gssteptest.cc +++ b/dune/solvers/test/gssteptest.cc @@ -164,6 +164,25 @@ struct GSTestSuite { } } + // Test constructor calls + { + using Dune::Solvers::BlockGSStep; + using namespace Dune::Solvers::BlockGS::LocalSolvers; + using Ignore = Dune::Solvers::DefaultBitVector_t<Vector>; + + BlockGSStep<Direct, Matrix, Vector, Ignore> step1; + BlockGSStep<LDLt, Matrix, Vector, Ignore> step2; + BlockGSStep<CG, Matrix, Vector, Ignore> step3; + BlockGSStep<GS, Matrix, Vector, Ignore> step4; + + auto tol = 1e-14; + auto r = 1e-10; + BlockGSStep<Direct, Matrix, Vector, Ignore> step1a(r); + BlockGSStep<LDLt, Matrix, Vector, Ignore> step2a(r); + BlockGSStep<CG, Matrix, Vector, Ignore> step3a(5, tol, r); + BlockGSStep<GS, Matrix, Vector, Ignore> step4a(tol, r); + } + // test projected block GS if (trivialDirichletOnly) // TODO: missing feature in ProjectedBlockGS {