#ifndef DUNE_TNNMG_ITERATIONSTEPS_LINEARCORRECTION_HH #define DUNE_TNNMG_ITERATIONSTEPS_LINEARCORRECTION_HH 1 #include <functional> #include <memory> #include <dune/solvers/iterationsteps/lineariterationstep.hh> #include <dune/solvers/solvers/iterativesolver.hh> #include <dune/solvers/solvers/linearsolver.hh> #include <dune/solvers/common/canignore.hh> #include <dune/solvers/common/resize.hh> #include <dune/solvers/solvers/umfpacksolver.hh> #include "localproblem.hh" /** * \brief linear correction step for use by \ref TNNMGStep * * The function object should solve the linear equation \f$ A x = b \f$. * The ignore bitfield identifies the dofs to be truncated by the * linear solver. The truncation dofs are passed by setting its ignore * member. * Be aware that full rows or columns of `A` might contain only zeroes. */ template<typename Matrix, typename Vector> using LinearCorrection = std::function<void(const Matrix& A, Vector& x, const Vector& b, const Dune::Solvers::DefaultBitVector_t<Vector>& ignore)>; template<typename Vector> Dune::Solvers::DefaultBitVector_t<Vector> emptyIgnore(const Vector& v) { // TNNMGStep assumes that the linearization and the solver for the // linearized problem will not use the ignoreNodes field Dune::Solvers::DefaultBitVector_t<Vector> ignore; Dune::Solvers::resizeInitialize(ignore, v, false); return ignore; } template<typename Matrix, typename Vector> LinearCorrection<Matrix, Vector> makeLinearCorrection(std::shared_ptr< Dune::Solvers::LinearSolver<Matrix, Vector> > linearSolver) { return [=](const Matrix& A, Vector& x, const Vector& b, const Dune::Solvers::DefaultBitVector_t<Vector>& ignore) { auto canIgnoreCast = std::dynamic_pointer_cast<CanIgnore<Dune::Solvers::DefaultBitVector_t<Vector>>>( linearSolver ); if (canIgnoreCast) canIgnoreCast->setIgnore(ignore); else DUNE_THROW(Dune::Exception, "LinearCorrection: linearSolver cannot set ignore member!"); linearSolver->setProblem(A, x, b); linearSolver->preprocess(); linearSolver->solve(); }; } template<typename Matrix, typename Vector> LinearCorrection<Matrix, Vector> makeLinearCorrection(std::shared_ptr< Dune::Solvers::IterativeSolver<Vector> > iterativeSolver) { return [=](const Matrix& A, Vector& x, const Vector& b, const Dune::Solvers::DefaultBitVector_t<Vector>& ignore) { // compute reference solution directly LocalProblem<Matrix, Vector> localProblem(A, b, ignore); Vector newR, directX; Dune::Solvers::resizeInitializeZero(directX, b); localProblem.getLocalRhs(directX, newR); /*print(*this->ignoreNodes_, "ignoreNodes:"); print(A, "A:"); print(localProblem.getMat(), "localMat:");*/ auto directSolver = std::make_shared<Dune::Solvers::UMFPackSolver<Matrix, Vector>>(); directSolver->setProblem(localProblem.getMat(), directX, newR); directSolver->preprocess(); directSolver->solve(); x = directX; /*using LinearIterationStep = Dune::Solvers::LinearIterationStep<Matrix, Vector>; auto linearIterationStep = dynamic_cast<LinearIterationStep*>(&iterativeSolver->getIterationStep()); if (not linearIterationStep) DUNE_THROW(Dune::Exception, "iterative solver must use a linear iteration step"); auto empty = emptyIgnore(x); linearIterationStep->setIgnore(empty); linearIterationStep->setProblem(A, x, b); iterativeSolver->preprocess(); iterativeSolver->solve(); const auto& norm = iterativeSolver->getErrorNorm(); auto error = norm.diff(linearIterationStep->getSol(), directX); std::cout << "Linear solver iterations: " << iterativeSolver->getResult().iterations << " Error: " << error << std::endl;*/ }; } template<typename Matrix, typename Vector> LinearCorrection<Matrix, Vector> makeLinearCorrection(std::shared_ptr< Dune::Solvers::LinearIterationStep<Matrix, Vector> > linearIterationStep, int nIterationSteps = 1) { return [=](const Matrix& A, Vector& x, const Vector& b, const Dune::Solvers::DefaultBitVector_t<Vector>& ignore) { linearIterationStep->setIgnore(ignore); linearIterationStep->setProblem(A, x, b); linearIterationStep->preprocess(); for (int i = 0; i < nIterationSteps; ++i) linearIterationStep->iterate(); }; } #endif