diff --git a/src/sand-wedge.cc b/src/sand-wedge.cc index b3dc3dd820c0d2cac3c7cc304b95757998cd0c3b..ee8807ee7741cb02a1091ead2a652dafdf647cac 100644 --- a/src/sand-wedge.cc +++ b/src/sand-wedge.cc @@ -95,6 +95,91 @@ void initPython() { Python::run("sys.path.append('" datadir "')"); } +template <class Factory, class StateUpdater, class VelocityUpdater> +class FixedPointIterator { + using ScalarVector = typename StateUpdater::ScalarVector; + using Vector = typename Factory::Vector; + using Matrix = typename Factory::Matrix; + using ConvexProblem = typename Factory::ConvexProblem; + using BlockProblem = typename Factory::BlockProblem; + using Nonlinearity = typename ConvexProblem::NonlinearityType; + +public: + FixedPointIterator(Factory &factory, Dune::ParameterTree const &parset, + std::shared_ptr<Nonlinearity> globalFriction) + : factory_(factory), + parset_(parset), + globalFriction_(globalFriction), + fixedPointMaxIterations_(parset.get<size_t>("v.fpi.maximumIterations")), + fixedPointTolerance_(parset.get<double>("v.fpi.tolerance")), + lambda_(parset.get<double>("v.fpi.lambda")), + velocityMaxIterations_( + parset.get<size_t>("v.solver.maximumIterations")), + velocityTolerance_(parset.get<double>("v.solver.tolerance")), + verbosity_(parset.get<Solver::VerbosityMode>("v.solver.verbosity")) {} + + void run(std::shared_ptr<StateUpdater> stateUpdater, + std::shared_ptr<VelocityUpdater> velocityUpdater, + Matrix const &velocityMatrix, Norm<Vector> const &velocityMatrixNorm, + Vector const &velocityRHS, Vector &velocityIterate) { + auto multigridStep = factory_.getSolver(); + + LoopSolver<Vector> velocityProblemSolver( + multigridStep, velocityMaxIterations_, velocityTolerance_, + &velocityMatrixNorm, verbosity_, false); // absolute error + + Vector previousVelocityIterate; + + for (size_t fixedPointIteration = 1; + fixedPointIteration <= fixedPointMaxIterations_; + ++fixedPointIteration) { + Vector v_m; + velocityUpdater->extractOldVelocity(v_m); + v_m *= 1.0 - lambda_; + Arithmetic::addProduct(v_m, lambda_, velocityIterate); + + // solve a state problem + stateUpdater->solve(v_m); + ScalarVector alpha; + stateUpdater->extractAlpha(alpha); + + // solve a velocity problem + globalFriction_->updateAlpha(alpha); + ConvexProblem convexProblem(1.0, velocityMatrix, *globalFriction_, + velocityRHS, velocityIterate); + BlockProblem velocityProblem(parset_, convexProblem); + multigridStep->setProblem(velocityIterate, velocityProblem); + velocityProblemSolver.preprocess(); + velocityProblemSolver.solve(); + + if (fixedPointIteration > 1) { + auto const velocityCorrection = + velocityMatrixNorm.diff(previousVelocityIterate, velocityIterate); + if (velocityCorrection < fixedPointTolerance_) + break; + } + if (fixedPointIteration == fixedPointMaxIterations_) + DUNE_THROW(Dune::Exception, "FPI failed to converge"); + + previousVelocityIterate = velocityIterate; + } + velocityUpdater->postProcess(velocityIterate); + velocityUpdater->postProcessRelativeQuantities(); + } + +private: + Factory &factory_; + Dune::ParameterTree const &parset_; + std::shared_ptr<Nonlinearity> globalFriction_; + + size_t fixedPointMaxIterations_; + double fixedPointTolerance_; + double lambda_; + size_t velocityMaxIterations_; + double velocityTolerance_; + Solver::VerbosityMode verbosity_; +}; + int main(int argc, char *argv[]) { try { Dune::ParameterTree parset; @@ -338,7 +423,6 @@ int main(int argc, char *argv[]) { Grid>; NonlinearFactory factory(parset.sub("solver.tnnmg"), refinements, *grid, dirichletNodes); - auto multigridStep = factory.getSolver(); auto velocityUpdater = initTimeStepper( parset.get<Config::scheme>("timeSteps.scheme"), @@ -350,19 +434,8 @@ int main(int argc, char *argv[]) { parset.get<double>("boundary.friction.L"), parset.get<double>("boundary.friction.V0")); - Vector v_m(leafVertexCount); - ScalarVector alpha(leafVertexCount); - - auto const timeSteps = parset.get<size_t>("timeSteps.number"), - maximumStateFPI = parset.get<size_t>("v.fpi.maximumIterations"), - maximumIterations = - parset.get<size_t>("v.solver.maximumIterations"); - auto const tau = parset.get<double>("problem.finalTime") / timeSteps, - tolerance = parset.get<double>("v.solver.tolerance"), - fixedPointTolerance = parset.get<double>("v.fpi.tolerance"); - auto const verbosity = - parset.get<Solver::VerbosityMode>("v.solver.verbosity"); - auto const lambda = parset.get<double>("v.fpi.lambda"); + auto const timeSteps = parset.get<size_t>("timeSteps.number"); + auto const tau = parset.get<double>("problem.finalTime") / timeSteps; for (size_t timeStep = 1; timeStep <= timeSteps; ++timeStep) { stateUpdater->nextTimeStep(); velocityUpdater->nextTimeStep(); @@ -379,60 +452,20 @@ int main(int argc, char *argv[]) { velocityUpdater->setup(ell, tau, relativeTime, velocityRHS, velocityIterate, velocityMatrix); - LoopSolver<Vector> velocityProblemSolver(multigridStep, maximumIterations, - tolerance, &AMNorm, verbosity, - false); // absolute error - - size_t iterationCounter; - auto solveVelocityProblem = [&](Vector &_velocityIterate, - ScalarVector const &_alpha) { - myGlobalFriction->updateAlpha(_alpha); - - // NIT: Do we really need to pass u here? - typename NonlinearFactory::ConvexProblem convexProblem( - 1.0, velocityMatrix, *myGlobalFriction, velocityRHS, - _velocityIterate); - typename NonlinearFactory::BlockProblem velocityProblem(parset, - convexProblem); - multigridStep->setProblem(_velocityIterate, velocityProblem); - - velocityProblemSolver.preprocess(); - velocityProblemSolver.solve(); - iterationCounter = velocityProblemSolver.getResult().iterations; - }; - - Vector v_saved; - for (size_t stateFPI = 1; stateFPI <= maximumStateFPI; ++stateFPI) { - velocityUpdater->extractOldVelocity(v_m); - v_m *= 1.0 - lambda; - Arithmetic::addProduct(v_m, lambda, velocityIterate); - - stateUpdater->solve(v_m); - stateUpdater->extractAlpha(alpha); - - solveVelocityProblem(velocityIterate, alpha); - - if (stateFPI > 1) { - double const velocityCorrection = - AMNorm.diff(v_saved, velocityIterate); - if (velocityCorrection < fixedPointTolerance) - break; - } - if (stateFPI == maximumStateFPI) - DUNE_THROW(Dune::Exception, "FPI failed to converge"); - - v_saved = velocityIterate; - } - velocityUpdater->postProcess(velocityIterate); + FixedPointIterator<NonlinearFactory, StateUpdater<ScalarVector, Vector>, + TimeSteppingScheme<Vector, Matrix, Function, dims>> + fixedPointIterator(factory, parset, myGlobalFriction); + fixedPointIterator.run(stateUpdater, velocityUpdater, velocityMatrix, + AMNorm, velocityRHS, velocityIterate); Vector u, ur, vr; + ScalarVector alpha; velocityUpdater->extractDisplacement(u); - velocityUpdater->postProcessRelativeQuantities(); velocityUpdater->extractRelativeDisplacement(ur); velocityUpdater->extractRelativeVelocity(vr); + stateUpdater->extractAlpha(alpha); report(ur, vr, alpha); - { BasisGridFunction<typename MyAssembler::VertexBasis, Vector> relativeVelocity(myAssembler.vertexBasis, vr); diff --git a/src/state/stateupdater.hh b/src/state/stateupdater.hh index 0ebf4774c11bf7825954ec4006a0b870645e7241..697662d32e8bf877441fed9085c7ccaa0bba087c 100644 --- a/src/state/stateupdater.hh +++ b/src/state/stateupdater.hh @@ -1,8 +1,10 @@ #ifndef STATE_UPDATER_HH #define STATE_UPDATER_HH -template <class ScalarVector, class Vector> class StateUpdater { +template <class ScalarVectorTEMPLATE, class Vector> class StateUpdater { public: + using ScalarVector = ScalarVectorTEMPLATE; + void virtual nextTimeStep() = 0; void virtual setup(double _tau) = 0; void virtual solve(Vector const &velocity_field) = 0;