Skip to content
Snippets Groups Projects
Commit 1a3a36cf authored by Elias Pipping's avatar Elias Pipping
Browse files

[Cleanup] Modularise FixedPointIterator

parent 1e54a73f
No related branches found
No related tags found
No related merge requests found
......@@ -4,6 +4,7 @@ common_sources = \
assemblers.cc \
boundary_writer.cc \
enumparser.cc \
fixedpointiterator.cc \
friction_writer.cc \
solverfactory.cc \
state.cc \
......
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
#include <dune/common/exceptions.hh>
#include <dune/solvers/common/arithmetic.hh>
#include <dune/solvers/solvers/loopsolver.hh>
#include "enums.hh"
#include "enumparser.hh"
#include "fixedpointiterator.hh"
template <class Factory, class StateUpdater, class VelocityUpdater>
FixedPointIterator<Factory, StateUpdater, VelocityUpdater>::FixedPointIterator(
Factory &factory, Dune::ParameterTree const &parset,
std::shared_ptr<Nonlinearity> globalFriction)
: factory_(factory),
parset_(parset),
globalFriction_(globalFriction),
fixedPointMaxIterations_(parset.get<size_t>("v.fpi.maximumIterations")),
fixedPointTolerance_(parset.get<double>("v.fpi.tolerance")),
lambda_(parset.get<double>("v.fpi.lambda")),
velocityMaxIterations_(parset.get<size_t>("v.solver.maximumIterations")),
velocityTolerance_(parset.get<double>("v.solver.tolerance")),
verbosity_(parset.get<Solver::VerbosityMode>("v.solver.verbosity")) {}
template <class Factory, class StateUpdater, class VelocityUpdater>
int FixedPointIterator<Factory, StateUpdater, VelocityUpdater>::run(
std::shared_ptr<StateUpdater> stateUpdater,
std::shared_ptr<VelocityUpdater> velocityUpdater,
Matrix const &velocityMatrix, Norm<Vector> const &velocityMatrixNorm,
Vector const &velocityRHS, Vector &velocityIterate) {
auto multigridStep = factory_.getSolver();
LoopSolver<Vector> velocityProblemSolver(
multigridStep, velocityMaxIterations_, velocityTolerance_,
&velocityMatrixNorm, verbosity_, false); // absolute error
Vector previousVelocityIterate = velocityIterate;
size_t fixedPointIteration;
for (fixedPointIteration = 1; fixedPointIteration <= fixedPointMaxIterations_;
++fixedPointIteration) {
Vector v_m;
velocityUpdater->extractOldVelocity(v_m);
v_m *= 1.0 - lambda_;
Arithmetic::addProduct(v_m, lambda_, velocityIterate);
// solve a state problem
stateUpdater->solve(v_m);
ScalarVector alpha;
stateUpdater->extractAlpha(alpha);
// solve a velocity problem
globalFriction_->updateAlpha(alpha);
ConvexProblem convexProblem(1.0, velocityMatrix, *globalFriction_,
velocityRHS, velocityIterate);
BlockProblem velocityProblem(parset_, convexProblem);
multigridStep->setProblem(velocityIterate, velocityProblem);
velocityProblemSolver.preprocess();
velocityProblemSolver.solve();
if (velocityMatrixNorm.diff(previousVelocityIterate, velocityIterate) <
fixedPointTolerance_)
break;
previousVelocityIterate = velocityIterate;
}
if (fixedPointIteration == fixedPointMaxIterations_)
DUNE_THROW(Dune::Exception, "FPI failed to converge");
velocityUpdater->postProcess(velocityIterate);
velocityUpdater->postProcessRelativeQuantities();
return fixedPointIteration;
}
#include "fixedpointiterator_tmpl.cc"
#ifndef SRC_FIXEDPOINTITERATOR_HH
#define SRC_FIXEDPOINTITERATOR_HH
#include <memory>
#include <dune/common/parametertree.hh>
#include <dune/solvers/norms/norm.hh>
#include <dune/solvers/solvers/solver.hh>
template <class Factory, class StateUpdater, class VelocityUpdater>
class FixedPointIterator {
using ScalarVector = typename StateUpdater::ScalarVector;
using Vector = typename Factory::Vector;
using Matrix = typename Factory::Matrix;
using ConvexProblem = typename Factory::ConvexProblem;
using BlockProblem = typename Factory::BlockProblem;
using Nonlinearity = typename ConvexProblem::NonlinearityType;
public:
FixedPointIterator(Factory &factory, Dune::ParameterTree const &parset,
std::shared_ptr<Nonlinearity> globalFriction);
int run(std::shared_ptr<StateUpdater> stateUpdater,
std::shared_ptr<VelocityUpdater> velocityUpdater,
Matrix const &velocityMatrix, Norm<Vector> const &velocityMatrixNorm,
Vector const &velocityRHS, Vector &velocityIterate);
private:
Factory &factory_;
Dune::ParameterTree const &parset_;
std::shared_ptr<Nonlinearity> globalFriction_;
size_t fixedPointMaxIterations_;
double fixedPointTolerance_;
double lambda_;
size_t velocityMaxIterations_;
double velocityTolerance_;
Solver::VerbosityMode verbosity_;
};
#endif
#ifndef DIM
#error DIM unset
#endif
#include <dune/common/function.hh>
#include <dune/tnnmg/problem-classes/convexproblem.hh>
#include <dune/tectonic/globalfriction.hh>
#include <dune/tectonic/myblockproblem.hh>
#include "explicitgrid.hh"
#include "explicitvectors.hh"
#include "solverfactory.hh"
#include "state/stateupdater.hh"
#include "timestepping.hh"
using Function = Dune::VirtualFunction<double, double>;
using Factory = SolverFactory<
DIM, MyBlockProblem<ConvexProblem<GlobalFriction<Matrix, Vector>, Matrix>>,
Grid>;
using MyStateUpdater = StateUpdater<ScalarVector, Vector>;
using MyVelocityUpdater = TimeSteppingScheme<Vector, Matrix, Function, DIM>;
template class FixedPointIterator<Factory, MyStateUpdater, MyVelocityUpdater>;
......@@ -70,6 +70,7 @@
#include "tobool.hh"
#include "enumparser.hh"
#include "enums.hh"
#include "fixedpointiterator.hh"
#include "friction_writer.hh"
#include "sand-wedge-data/mybody.hh"
#include "sand-wedge-data/mygeometry.hh"
......@@ -92,92 +93,6 @@ void initPython() {
Python::run("sys.path.append('" datadir "')");
}
template <class Factory, class StateUpdater, class VelocityUpdater>
class FixedPointIterator {
using ScalarVector = typename StateUpdater::ScalarVector;
using Vector = typename Factory::Vector;
using Matrix = typename Factory::Matrix;
using ConvexProblem = typename Factory::ConvexProblem;
using BlockProblem = typename Factory::BlockProblem;
using Nonlinearity = typename ConvexProblem::NonlinearityType;
public:
FixedPointIterator(Factory &factory, Dune::ParameterTree const &parset,
std::shared_ptr<Nonlinearity> globalFriction)
: factory_(factory),
parset_(parset),
globalFriction_(globalFriction),
fixedPointMaxIterations_(parset.get<size_t>("v.fpi.maximumIterations")),
fixedPointTolerance_(parset.get<double>("v.fpi.tolerance")),
lambda_(parset.get<double>("v.fpi.lambda")),
velocityMaxIterations_(
parset.get<size_t>("v.solver.maximumIterations")),
velocityTolerance_(parset.get<double>("v.solver.tolerance")),
verbosity_(parset.get<Solver::VerbosityMode>("v.solver.verbosity")) {}
int run(std::shared_ptr<StateUpdater> stateUpdater,
std::shared_ptr<VelocityUpdater> velocityUpdater,
Matrix const &velocityMatrix, Norm<Vector> const &velocityMatrixNorm,
Vector const &velocityRHS, Vector &velocityIterate) {
auto multigridStep = factory_.getSolver();
LoopSolver<Vector> velocityProblemSolver(
multigridStep, velocityMaxIterations_, velocityTolerance_,
&velocityMatrixNorm, verbosity_, false); // absolute error
Vector previousVelocityIterate = velocityIterate;
size_t fixedPointIteration;
for (fixedPointIteration = 1;
fixedPointIteration <= fixedPointMaxIterations_;
++fixedPointIteration) {
Vector v_m;
velocityUpdater->extractOldVelocity(v_m);
v_m *= 1.0 - lambda_;
Arithmetic::addProduct(v_m, lambda_, velocityIterate);
// solve a state problem
stateUpdater->solve(v_m);
ScalarVector alpha;
stateUpdater->extractAlpha(alpha);
// solve a velocity problem
globalFriction_->updateAlpha(alpha);
ConvexProblem convexProblem(1.0, velocityMatrix, *globalFriction_,
velocityRHS, velocityIterate);
BlockProblem velocityProblem(parset_, convexProblem);
multigridStep->setProblem(velocityIterate, velocityProblem);
velocityProblemSolver.preprocess();
velocityProblemSolver.solve();
if (velocityMatrixNorm.diff(previousVelocityIterate, velocityIterate) <
fixedPointTolerance_)
break;
previousVelocityIterate = velocityIterate;
}
if (fixedPointIteration == fixedPointMaxIterations_)
DUNE_THROW(Dune::Exception, "FPI failed to converge");
velocityUpdater->postProcess(velocityIterate);
velocityUpdater->postProcessRelativeQuantities();
return fixedPointIteration;
}
private:
Factory &factory_;
Dune::ParameterTree const &parset_;
std::shared_ptr<Nonlinearity> globalFriction_;
size_t fixedPointMaxIterations_;
double fixedPointTolerance_;
double lambda_;
size_t velocityMaxIterations_;
double velocityTolerance_;
Solver::VerbosityMode verbosity_;
};
template <class Factory, class StateUpdater, class VelocityUpdater>
class CoupledTimeStepper {
using Vector = typename Factory::Vector;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment