// fixedpointiterator.cc
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include <dune/common/exceptions.hh>

#include <dune/solvers/common/arithmetic.hh>
#include <dune/solvers/norms/energynorm.hh>
#include <dune/solvers/solvers/loopsolver.hh>

#include "enums.hh"
#include "enumparser.hh"

#include "fixedpointiterator.hh"

// Constructs the fixed-point iterator.
//
// factory         Supplies the solver step via factory.getStep(); the result
//                 is stored in step_.
// parset          Parameter tree. It is stored in parset_ and the following
//                 keys are read from it here:
//                   v.fpi.maximumIterations    -> fixedPointMaxIterations_
//                   v.fpi.tolerance            -> fixedPointTolerance_
//                   v.fpi.lambda               -> lambda_
//                   v.solver.maximumIterations -> velocityMaxIterations_
//                   v.solver.tolerance         -> velocityTolerance_
//                   v.solver.verbosity         -> verbosity_
// globalFriction  Shared nonlinearity, stored in globalFriction_.
// errorNorm       Norm stored in errorNorm_.
//
// NOTE(review): whether parset_ and errorNorm_ hold copies or references is
// determined by the member declarations in fixedpointiterator.hh, which are
// not visible here -- confirm lifetime requirements against the header.
template <class Factory, class StateUpdater, class RateUpdater, class ErrorNorm>
FixedPointIterator<Factory, StateUpdater, RateUpdater, ErrorNorm>::
    FixedPointIterator(Factory &factory, Dune::ParameterTree const &parset,
                       std::shared_ptr<Nonlinearity> globalFriction,
                       ErrorNorm const &errorNorm)
    : step_(factory.getStep()),
      parset_(parset),
      globalFriction_(globalFriction),
      fixedPointMaxIterations_(parset.get<size_t>("v.fpi.maximumIterations")),
      fixedPointTolerance_(parset.get<double>("v.fpi.tolerance")),
      lambda_(parset.get<double>("v.fpi.lambda")),
      velocityMaxIterations_(parset.get<size_t>("v.solver.maximumIterations")),
      velocityTolerance_(parset.get<double>("v.solver.tolerance")),
      verbosity_(parset.get<Solver::VerbosityMode>("v.solver.verbosity")),
      errorNorm_(errorNorm) {}
// Runs the fixed-point iteration that alternates between a velocity problem
// and a state problem until the state quantity alpha stops changing.
//
// Each pass
//   1. pushes the current alpha into globalFriction_ and solves the velocity
//      problem for velocityIterate with a LoopSolver around step_,
//   2. forms v_m = (1 - lambda_) * oldVelocity + lambda_ * velocityIterate,
//   3. lets stateUpdater solve the state problem for v_m and extracts the
//      resulting alpha,
//   4. stops if lambda_ < 1e-12 or if errorNorm_.diff(alpha, newAlpha) is
//      below fixedPointTolerance_; otherwise continues with the new alpha.
//
// stateUpdater     Provides extractAlpha() and solve().
// velocityUpdater  Provides extractOldVelocity() and postProcess().
// velocityMatrix   Matrix of the velocity problem; also defines the energy
//                  norm used by the inner solver.
// velocityRHS      Right-hand side of the velocity problem.
// velocityIterate  In/out: initial guess on entry, final velocity (after
//                  velocityUpdater->postProcess) on return.
//
// Returns the number of fixed-point iterations performed together with the
// accumulated iteration count of the inner velocity solver.
//
// Throws Dune::Exception if the stopping criterion is not met within
// fixedPointMaxIterations_ iterations.
template <class Factory, class StateUpdater, class RateUpdater, class ErrorNorm>
FixedPointIterationCounter
FixedPointIterator<Factory, StateUpdater, RateUpdater, ErrorNorm>::run(
    std::shared_ptr<StateUpdater> stateUpdater,
    std::shared_ptr<RateUpdater> velocityUpdater, Matrix const &velocityMatrix,
    Vector const &velocityRHS, Vector &velocityIterate) {
  EnergyNorm<Matrix, Vector> energyNorm(velocityMatrix);
  LoopSolver<Vector> velocityProblemSolver(step_.get(), velocityMaxIterations_,
                                           velocityTolerance_, &energyNorm,
                                           verbosity_, false); // absolute error

  size_t fixedPointIteration;
  size_t multigridIterations = 0;
  // Tracked separately from the counter: the counter alone cannot tell
  // "converged on the last permitted iteration" from "ran out of iterations".
  bool converged = false;
  ScalarVector alpha;
  stateUpdater->extractAlpha(alpha);
  for (fixedPointIteration = 0; fixedPointIteration < fixedPointMaxIterations_;
       ++fixedPointIteration) {
    // solve a velocity problem
    globalFriction_->updateAlpha(alpha);
    ConvexProblem convexProblem(1.0, velocityMatrix, *globalFriction_,
                                velocityRHS, velocityIterate);
    BlockProblem velocityProblem(parset_, convexProblem);
    step_->setProblem(velocityIterate, velocityProblem);
    velocityProblemSolver.preprocess();
    velocityProblemSolver.solve();

    multigridIterations += velocityProblemSolver.getResult().iterations;

    // relaxed velocity: v_m = (1 - lambda) * v_old + lambda * v_new
    Vector v_m;
    velocityUpdater->extractOldVelocity(v_m);
    v_m *= 1.0 - lambda_;
    Arithmetic::addProduct(v_m, lambda_, velocityIterate);

    // solve a state problem
    stateUpdater->solve(v_m);
    ScalarVector newAlpha;
    stateUpdater->extractAlpha(newAlpha);

    if (lambda_ < 1e-12 or
        errorNorm_.diff(alpha, newAlpha) < fixedPointTolerance_) {
      // Count the iteration that just finished, then leave the loop.
      fixedPointIteration++;
      converged = true;
      break;
    }
    alpha = newAlpha;
  }
  if (!converged)
    DUNE_THROW(Dune::Exception, "FPI failed to converge");

  velocityUpdater->postProcess(velocityIterate);

  return { fixedPointIteration, multigridIterations };
}

// Writes a FixedPointIterationCounter to the stream in the form
// "(iterations,multigridIterations)" and returns the stream for chaining.
std::ostream &operator<<(std::ostream &stream,
                         FixedPointIterationCounter const &fpic) {
  stream << "(";
  stream << fpic.iterations;
  stream << ",";
  stream << fpic.multigridIterations;
  stream << ")";
  return stream;
}

#include "fixedpointiterator_tmpl.cc"