// linearcorrection.hh
// Forked from agnumpde/dune-tectonic (148 commits ahead of the upstream repository).
#ifndef DUNE_TECTONIC_ITERATIONSTEPS_LINEARCORRECTION_HH
#define DUNE_TECTONIC_ITERATIONSTEPS_LINEARCORRECTION_HH

#include <functional>
#include <memory>

#include <dune/solvers/iterationsteps/lineariterationstep.hh>
#include <dune/solvers/solvers/iterativesolver.hh>
#include <dune/solvers/solvers/linearsolver.hh>
#include <dune/solvers/common/canignore.hh>

#include <dune/solvers/common/resize.hh>
#include <dune/solvers/solvers/umfpacksolver.hh>

/**
 * \brief Type of the linear correction step used by \ref TNNMGStep.
 *
 * A callable of this type is invoked as `correction(A, x, b, ignore)` and is
 * expected to solve the linear system \f$ A x = b \f$, writing the result
 * into `x`.
 *
 * `ignore` is a bit field over the degrees of freedom. Set entries mark the
 * dofs the linear solver should truncate. An implementation passes this field
 * to its solver through the solver's ignore member (`setIgnore`).
 *
 * Note that `A` may contain rows or columns that consist entirely of zeroes.
 */
template<class Matrix, class Vector>
using LinearCorrection =
    std::function<void(const Matrix& /* A */,
                       Vector& /* x */,
                       const Vector& /* b */,
                       const Dune::Solvers::DefaultBitVector_t<Vector>& /* ignore */)>;

/**
 * \brief Create an ignore bit field in which no dof is marked.
 *
 * The field is shaped after `v` via Dune::Solvers::resizeInitialize, with
 * every entry initialized to `false`.
 *
 * \param v vector whose structure the bit field should follow
 * \return a bit field with all entries `false`
 */
template<class Vector>
Dune::Solvers::DefaultBitVector_t<Vector>
emptyIgnore(const Vector& v)
{
  using BitVector = Dune::Solvers::DefaultBitVector_t<Vector>;

  // TNNMGStep expects that neither the linearization nor the solver of the
  // linearized problem relies on the ignoreNodes field, so nothing is marked.
  BitVector nothingIgnored;
  Dune::Solvers::resizeInitialize(nothingIgnored, v, false);
  return nothingIgnored;
}

/**
 * \brief Wrap a linear solver as a LinearCorrection.
 *
 * On each invocation the returned callable performs these steps:
 * - installs `ignore` on the solver, which must derive from CanIgnore;
 * - sets the problem (A, x, b);
 * - calls the solver's preprocess() and solve().
 *
 * \param linearSolver the solver to wrap; it is shared with the returned callable
 * \throws Dune::Exception (when the callable is invoked) if `linearSolver`
 *         cannot be cast to CanIgnore
 */
template<class Matrix, class Vector>
LinearCorrection<Matrix, Vector>
buildLinearCorrection(std::shared_ptr< Dune::Solvers::LinearSolver<Matrix, Vector> > linearSolver)
{
  using BitVector = Dune::Solvers::DefaultBitVector_t<Vector>;
  using IgnoreSetter = CanIgnore<BitVector>;

  return [linearSolver](const Matrix& A, Vector& x, const Vector& b, const BitVector& ignore) {
    // Truncation can only be forwarded if the solver accepts an ignore field.
    auto const ignoreSetter = std::dynamic_pointer_cast<IgnoreSetter>(linearSolver);
    if (!ignoreSetter)
      DUNE_THROW(Dune::Exception, "LinearCorrection: linearSolver cannot set ignore member!");
    ignoreSetter->setIgnore(ignore);

    linearSolver->setProblem(A, x, b);
    linearSolver->preprocess();
    linearSolver->solve();
  };
}

/**
 * \brief Wrap an iterative solver as a LinearCorrection.
 *
 * On each invocation the returned callable performs these steps:
 * - fetches the solver's iteration step, which must be a
 *   Dune::Solvers::LinearIterationStep<Matrix, Vector>;
 * - installs the caller's `ignore` field and the problem (A, x, b) on that step;
 * - runs the solver's preprocess() and solve().
 *
 * `x` is handed to the iteration step as the iterate.
 *
 * \param iterativeSolver the solver to wrap; it is shared with the returned callable
 * \throws Dune::Exception (when the callable is invoked) if the solver's
 *         iteration step is not a LinearIterationStep
 */
template<typename Matrix, typename Vector>
LinearCorrection<Matrix, Vector>
buildLinearCorrection(std::shared_ptr< Dune::Solvers::IterativeSolver<Vector> > iterativeSolver)
{
  return [=](const Matrix& A, Vector& x, const Vector& b, const Dune::Solvers::DefaultBitVector_t<Vector>& ignore) {
    using LinearIterationStep = Dune::Solvers::LinearIterationStep<Matrix, Vector>;

    // The matrix, rhs and ignore field are installed on the iteration step
    // directly, so the solver has to drive a *linear* iteration step.
    auto linearIterationStep = dynamic_cast<LinearIterationStep*>(&iterativeSolver->getIterationStep());
    if (not linearIterationStep)
      DUNE_THROW(Dune::Exception, "iterative solver must use a linear iteration step");

    // Honor the truncation requested by the caller. This matches the
    // LinearCorrection contract and the LinearSolver / LinearIterationStep
    // overloads of buildLinearCorrection in this header.
    linearIterationStep->setIgnore(ignore);
    linearIterationStep->setProblem(A, x, b);
    iterativeSolver->preprocess();
    iterativeSolver->solve();
  };
}

/**
 * \brief Wrap a linear iteration step as a LinearCorrection.
 *
 * On each invocation the returned callable performs these steps:
 * - installs `ignore` and the problem (A, x, b) on the step;
 * - calls preprocess() once;
 * - calls iterate() `nIterationSteps` times.
 *
 * \param linearIterationStep the step to wrap; it is shared with the returned callable
 * \param nIterationSteps number of iterate() calls per invocation; a value
 *        <= 0 performs no iterations
 */
template<class Matrix, class Vector>
LinearCorrection<Matrix, Vector>
buildLinearCorrection(std::shared_ptr< Dune::Solvers::LinearIterationStep<Matrix, Vector> > linearIterationStep, int nIterationSteps = 1)
{
  using BitVector = Dune::Solvers::DefaultBitVector_t<Vector>;

  return [linearIterationStep, nIterationSteps](const Matrix& A, Vector& x, const Vector& b, const BitVector& ignore) {
    auto& step = *linearIterationStep;
    step.setIgnore(ignore);
    step.setProblem(A, x, b);
    step.preprocess();

    // Count down so the number of sweeps equals nIterationSteps (none if <= 0).
    for (int remaining = nIterationSteps; remaining > 0; --remaining)
      step.iterate();
  };
}

#endif