Skip to content
Snippets Groups Projects
Commit 3e797be6 authored by Elias Pipping's avatar Elias Pipping
Browse files

[Cleanup] Introduce FixedPointIterator class

parent b227426f
No related branches found
No related tags found
No related merge requests found
......@@ -95,6 +95,91 @@ void initPython() {
Python::run("sys.path.append('" datadir "')");
}
template <class Factory, class StateUpdater, class VelocityUpdater>
class FixedPointIterator {
using ScalarVector = typename StateUpdater::ScalarVector;
using Vector = typename Factory::Vector;
using Matrix = typename Factory::Matrix;
using ConvexProblem = typename Factory::ConvexProblem;
using BlockProblem = typename Factory::BlockProblem;
using Nonlinearity = typename ConvexProblem::NonlinearityType;
public:
FixedPointIterator(Factory &factory, Dune::ParameterTree const &parset,
std::shared_ptr<Nonlinearity> globalFriction)
: factory_(factory),
parset_(parset),
globalFriction_(globalFriction),
fixedPointMaxIterations_(parset.get<size_t>("v.fpi.maximumIterations")),
fixedPointTolerance_(parset.get<double>("v.fpi.tolerance")),
lambda_(parset.get<double>("v.fpi.lambda")),
velocityMaxIterations_(
parset.get<size_t>("v.solver.maximumIterations")),
velocityTolerance_(parset.get<double>("v.solver.tolerance")),
verbosity_(parset.get<Solver::VerbosityMode>("v.solver.verbosity")) {}
void run(std::shared_ptr<StateUpdater> stateUpdater,
std::shared_ptr<VelocityUpdater> velocityUpdater,
Matrix const &velocityMatrix, Norm<Vector> const &velocityMatrixNorm,
Vector const &velocityRHS, Vector &velocityIterate) {
auto multigridStep = factory_.getSolver();
LoopSolver<Vector> velocityProblemSolver(
multigridStep, velocityMaxIterations_, velocityTolerance_,
&velocityMatrixNorm, verbosity_, false); // absolute error
Vector previousVelocityIterate;
for (size_t fixedPointIteration = 1;
fixedPointIteration <= fixedPointMaxIterations_;
++fixedPointIteration) {
Vector v_m;
velocityUpdater->extractOldVelocity(v_m);
v_m *= 1.0 - lambda_;
Arithmetic::addProduct(v_m, lambda_, velocityIterate);
// solve a state problem
stateUpdater->solve(v_m);
ScalarVector alpha;
stateUpdater->extractAlpha(alpha);
// solve a velocity problem
globalFriction_->updateAlpha(alpha);
ConvexProblem convexProblem(1.0, velocityMatrix, *globalFriction_,
velocityRHS, velocityIterate);
BlockProblem velocityProblem(parset_, convexProblem);
multigridStep->setProblem(velocityIterate, velocityProblem);
velocityProblemSolver.preprocess();
velocityProblemSolver.solve();
if (fixedPointIteration > 1) {
auto const velocityCorrection =
velocityMatrixNorm.diff(previousVelocityIterate, velocityIterate);
if (velocityCorrection < fixedPointTolerance_)
break;
}
if (fixedPointIteration == fixedPointMaxIterations_)
DUNE_THROW(Dune::Exception, "FPI failed to converge");
previousVelocityIterate = velocityIterate;
}
velocityUpdater->postProcess(velocityIterate);
velocityUpdater->postProcessRelativeQuantities();
}
private:
Factory &factory_;
Dune::ParameterTree const &parset_;
std::shared_ptr<Nonlinearity> globalFriction_;
size_t fixedPointMaxIterations_;
double fixedPointTolerance_;
double lambda_;
size_t velocityMaxIterations_;
double velocityTolerance_;
Solver::VerbosityMode verbosity_;
};
int main(int argc, char *argv[]) {
try {
Dune::ParameterTree parset;
......@@ -338,7 +423,6 @@ int main(int argc, char *argv[]) {
Grid>;
NonlinearFactory factory(parset.sub("solver.tnnmg"), refinements, *grid,
dirichletNodes);
auto multigridStep = factory.getSolver();
auto velocityUpdater = initTimeStepper(
parset.get<Config::scheme>("timeSteps.scheme"),
......@@ -350,19 +434,8 @@ int main(int argc, char *argv[]) {
parset.get<double>("boundary.friction.L"),
parset.get<double>("boundary.friction.V0"));
Vector v_m(leafVertexCount);
ScalarVector alpha(leafVertexCount);
auto const timeSteps = parset.get<size_t>("timeSteps.number"),
maximumStateFPI = parset.get<size_t>("v.fpi.maximumIterations"),
maximumIterations =
parset.get<size_t>("v.solver.maximumIterations");
auto const tau = parset.get<double>("problem.finalTime") / timeSteps,
tolerance = parset.get<double>("v.solver.tolerance"),
fixedPointTolerance = parset.get<double>("v.fpi.tolerance");
auto const verbosity =
parset.get<Solver::VerbosityMode>("v.solver.verbosity");
auto const lambda = parset.get<double>("v.fpi.lambda");
auto const timeSteps = parset.get<size_t>("timeSteps.number");
auto const tau = parset.get<double>("problem.finalTime") / timeSteps;
for (size_t timeStep = 1; timeStep <= timeSteps; ++timeStep) {
stateUpdater->nextTimeStep();
velocityUpdater->nextTimeStep();
......@@ -379,60 +452,20 @@ int main(int argc, char *argv[]) {
velocityUpdater->setup(ell, tau, relativeTime, velocityRHS,
velocityIterate, velocityMatrix);
LoopSolver<Vector> velocityProblemSolver(multigridStep, maximumIterations,
tolerance, &AMNorm, verbosity,
false); // absolute error
size_t iterationCounter;
auto solveVelocityProblem = [&](Vector &_velocityIterate,
ScalarVector const &_alpha) {
myGlobalFriction->updateAlpha(_alpha);
// NIT: Do we really need to pass u here?
typename NonlinearFactory::ConvexProblem convexProblem(
1.0, velocityMatrix, *myGlobalFriction, velocityRHS,
_velocityIterate);
typename NonlinearFactory::BlockProblem velocityProblem(parset,
convexProblem);
multigridStep->setProblem(_velocityIterate, velocityProblem);
velocityProblemSolver.preprocess();
velocityProblemSolver.solve();
iterationCounter = velocityProblemSolver.getResult().iterations;
};
Vector v_saved;
for (size_t stateFPI = 1; stateFPI <= maximumStateFPI; ++stateFPI) {
velocityUpdater->extractOldVelocity(v_m);
v_m *= 1.0 - lambda;
Arithmetic::addProduct(v_m, lambda, velocityIterate);
stateUpdater->solve(v_m);
stateUpdater->extractAlpha(alpha);
solveVelocityProblem(velocityIterate, alpha);
if (stateFPI > 1) {
double const velocityCorrection =
AMNorm.diff(v_saved, velocityIterate);
if (velocityCorrection < fixedPointTolerance)
break;
}
if (stateFPI == maximumStateFPI)
DUNE_THROW(Dune::Exception, "FPI failed to converge");
v_saved = velocityIterate;
}
velocityUpdater->postProcess(velocityIterate);
FixedPointIterator<NonlinearFactory, StateUpdater<ScalarVector, Vector>,
TimeSteppingScheme<Vector, Matrix, Function, dims>>
fixedPointIterator(factory, parset, myGlobalFriction);
fixedPointIterator.run(stateUpdater, velocityUpdater, velocityMatrix,
AMNorm, velocityRHS, velocityIterate);
Vector u, ur, vr;
ScalarVector alpha;
velocityUpdater->extractDisplacement(u);
velocityUpdater->postProcessRelativeQuantities();
velocityUpdater->extractRelativeDisplacement(ur);
velocityUpdater->extractRelativeVelocity(vr);
stateUpdater->extractAlpha(alpha);
report(ur, vr, alpha);
{
BasisGridFunction<typename MyAssembler::VertexBasis, Vector>
relativeVelocity(myAssembler.vertexBasis, vr);
......
#ifndef STATE_UPDATER_HH
#define STATE_UPDATER_HH
template <class ScalarVector, class Vector> class StateUpdater {
template <class ScalarVectorTEMPLATE, class Vector> class StateUpdater {
public:
using ScalarVector = ScalarVectorTEMPLATE;
void virtual nextTimeStep() = 0;
void virtual setup(double _tau) = 0;
void virtual solve(Vector const &velocity_field) = 0;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment