From 3cb3ac5d844fe2b5426f6fa866aa86a895aaaf2e Mon Sep 17 00:00:00 2001 From: Elias Pipping <elias.pipping@fu-berlin.de> Date: Wed, 11 Feb 2015 17:41:07 +0100 Subject: [PATCH] [Cleanup] Add Updaters class --- src/adaptivetimestepper.hh | 70 +++++++++++++++------------------- src/coupledtimestepper.cc | 45 ++++++++++------------ src/coupledtimestepper.hh | 9 ++--- src/coupledtimestepper_tmpl.cc | 5 ++- src/fixedpointiterator.cc | 31 ++++++++------- src/fixedpointiterator.hh | 7 ++-- src/fixedpointiterator_tmpl.cc | 5 ++- src/sand-wedge.cc | 34 ++++++++--------- src/updaters.hh | 23 +++++++++++ 9 files changed, 118 insertions(+), 111 deletions(-) create mode 100644 src/updaters.hh diff --git a/src/adaptivetimestepper.hh b/src/adaptivetimestepper.hh index 9720efa6..2636cc1e 100644 --- a/src/adaptivetimestepper.hh +++ b/src/adaptivetimestepper.hh @@ -1,29 +1,21 @@ #include "coupledtimestepper.hh" -template <typename T1, typename T2> -std::pair<T1, T2> clonePair(std::pair<T1, T2> in) { - return { in.first->clone(), in.second->clone() }; -} - -template <class Factory, class UpdaterPair, class ErrorNorm> +template <class Factory, class Updaters, class ErrorNorm> class AdaptiveTimeStepper { - using StateUpdater = typename UpdaterPair::first_type::element_type; - using RateUpdater = typename UpdaterPair::second_type::element_type; using Vector = typename Factory::Vector; using ConvexProblem = typename Factory::ConvexProblem; using Nonlinearity = typename ConvexProblem::NonlinearityType; - using MyCoupledTimeStepper = - CoupledTimeStepper<Factory, StateUpdater, RateUpdater, ErrorNorm>; + using MyCoupledTimeStepper = CoupledTimeStepper<Factory, Updaters, ErrorNorm>; public: - AdaptiveTimeStepper( - Factory &factory, Dune::ParameterTree const &parset, - std::shared_ptr<Nonlinearity> globalFriction, UpdaterPair ¤t, - double relativeTime, double relativeTau, - std::function<void(double, Vector &)> externalForces, - ErrorNorm const &errorNorm, - std::function<bool(UpdaterPair &, UpdaterPair &)> mustRefine) + AdaptiveTimeStepper(Factory &factory, Dune::ParameterTree const &parset, + std::shared_ptr<Nonlinearity> globalFriction, + Updaters ¤t, double relativeTime, + double relativeTau, + std::function<void(double, Vector &)> externalForces, + ErrorNorm const &errorNorm, + std::function<bool(Updaters &, Updaters &)> mustRefine) : finalTime_(parset.get<double>("problem.finalTime")), relativeTime_(relativeTime), relativeTau_(relativeTau), @@ -31,14 +23,14 @@ class AdaptiveTimeStepper { parset_(parset), globalFriction_(globalFriction), current_(current), - R1_(clonePair(current_)), + R1_(current_.clone()), externalForces_(externalForces), mustRefine_(mustRefine), iterationWriter_("iterations", std::fstream::out), errorNorm_(errorNorm) { - MyCoupledTimeStepper coupledTimeStepper( - finalTime_, factory_, parset_, globalFriction_, R1_.first, R1_.second, - errorNorm, externalForces_); + MyCoupledTimeStepper coupledTimeStepper(finalTime_, factory_, parset_, + globalFriction_, R1_, errorNorm, + externalForces_); stepAndReport("R1", coupledTimeStepper, relativeTime_, relativeTau_); iterationWriter_ << std::endl; } @@ -49,20 +41,20 @@ class AdaptiveTimeStepper { bool didCoarsen = false; while (relativeTime_ + relativeTau_ < 1.0) { - R2_ = clonePair(R1_); + R2_ = R1_.clone(); { - MyCoupledTimeStepper coupledTimeStepper( - finalTime_, factory_, parset_, globalFriction_, R2_.first, - R2_.second, errorNorm_, externalForces_); + MyCoupledTimeStepper coupledTimeStepper(finalTime_, factory_, parset_, + globalFriction_, R2_, + errorNorm_, externalForces_); stepAndReport("R2", coupledTimeStepper, relativeTime_ + relativeTau_, relativeTau_); } - UpdaterPair C = clonePair(current_); + Updaters C = current_.clone(); { - MyCoupledTimeStepper coupledTimeStepper( - finalTime_, factory_, parset_, globalFriction_, C.first, C.second, - errorNorm_, externalForces_); + MyCoupledTimeStepper coupledTimeStepper(finalTime_, factory_, parset_, + globalFriction_, C, errorNorm_, + externalForces_); stepAndReport("C", coupledTimeStepper, relativeTime_, 2.0 * relativeTau_); @@ -82,16 +74,16 @@ class AdaptiveTimeStepper { void refine() { while (true) { - UpdaterPair F2 = clonePair(current_); - UpdaterPair F1; + Updaters F2 = current_.clone(); + Updaters F1; { - MyCoupledTimeStepper coupledTimeStepper( - finalTime_, factory_, parset_, globalFriction_, F2.first, F2.second, - errorNorm_, externalForces_); + MyCoupledTimeStepper coupledTimeStepper(finalTime_, factory_, parset_, + globalFriction_, F2, errorNorm_, + externalForces_); stepAndReport("F1", coupledTimeStepper, relativeTime_, relativeTau_ / 2.0); - F1 = clonePair(F2); + F1 = F2.clone(); stepAndReport("F2", coupledTimeStepper, relativeTime_ + relativeTau_ / 2.0, relativeTau_ / 2.0); } @@ -133,11 +125,11 @@ class AdaptiveTimeStepper { Factory &factory_; Dune::ParameterTree const &parset_; std::shared_ptr<Nonlinearity> globalFriction_; - UpdaterPair ¤t_; - UpdaterPair R1_; - UpdaterPair R2_; + Updaters ¤t_; + Updaters R1_; + Updaters R2_; std::function<void(double, Vector &)> externalForces_; - std::function<bool(UpdaterPair &, UpdaterPair &)> mustRefine_; + std::function<bool(Updaters &, Updaters &)> mustRefine_; std::fstream iterationWriter_; ErrorNorm const &errorNorm_; }; diff --git a/src/coupledtimestepper.cc b/src/coupledtimestepper.cc index 1c12c5c9..3be5824d 100644 --- a/src/coupledtimestepper.cc +++ b/src/coupledtimestepper.cc @@ -4,30 +4,26 @@ #include "coupledtimestepper.hh" -template <class Factory, class StateUpdater, class RateUpdater, class ErrorNorm> -CoupledTimeStepper<Factory, StateUpdater, RateUpdater, ErrorNorm>:: - CoupledTimeStepper(double finalTime, Factory &factory, - Dune::ParameterTree const &parset, - std::shared_ptr<Nonlinearity> globalFriction, - std::shared_ptr<StateUpdater> stateUpdater, - std::shared_ptr<RateUpdater> velocityUpdater, - ErrorNorm const &errorNorm, - std::function<void(double, Vector &)> externalForces) +template <class Factory, class Updaters, class ErrorNorm> +CoupledTimeStepper<Factory, Updaters, ErrorNorm>::CoupledTimeStepper( + double finalTime, Factory &factory, Dune::ParameterTree const &parset, + std::shared_ptr<Nonlinearity> globalFriction, Updaters updaters, + ErrorNorm const &errorNorm, + std::function<void(double, Vector &)> externalForces) : finalTime_(finalTime), factory_(factory), parset_(parset), globalFriction_(globalFriction), - stateUpdater_(stateUpdater), - velocityUpdater_(velocityUpdater), + updaters_(updaters), externalForces_(externalForces), errorNorm_(errorNorm) {} -template <class Factory, class StateUpdater, class RateUpdater, class ErrorNorm> -FixedPointIterationCounter CoupledTimeStepper< - Factory, StateUpdater, RateUpdater, ErrorNorm>::step(double relativeTime, - double relativeTau) { - stateUpdater_->nextTimeStep(); - velocityUpdater_->nextTimeStep(); +template <class Factory, class Updaters, class ErrorNorm> +FixedPointIterationCounter +CoupledTimeStepper<Factory, Updaters, ErrorNorm>::step(double relativeTime, + double relativeTau) { + updaters_.state_->nextTimeStep(); + updaters_.rate_->nextTimeStep(); auto const newRelativeTime = relativeTime + relativeTau; Vector ell; @@ -38,14 +34,13 @@ FixedPointIterationCounter CoupledTimeStepper< Vector velocityIterate; auto const tau = relativeTau * finalTime_; - stateUpdater_->setup(tau); - velocityUpdater_->setup(ell, tau, newRelativeTime, velocityRHS, - velocityIterate, velocityMatrix); - FixedPointIterator<Factory, StateUpdater, RateUpdater, ErrorNorm> - fixedPointIterator(factory_, parset_, globalFriction_, errorNorm_); - auto const iterations = - fixedPointIterator.run(stateUpdater_, velocityUpdater_, velocityMatrix, - velocityRHS, velocityIterate); + updaters_.state_->setup(tau); + updaters_.rate_->setup(ell, tau, newRelativeTime, velocityRHS, + velocityIterate, velocityMatrix); + FixedPointIterator<Factory, Updaters, ErrorNorm> fixedPointIterator( + factory_, parset_, globalFriction_, errorNorm_); + auto const iterations = fixedPointIterator.run(updaters_, velocityMatrix, + velocityRHS, velocityIterate); return iterations; } diff --git a/src/coupledtimestepper.hh b/src/coupledtimestepper.hh index 98c4b0e9..8873e675 100644 --- a/src/coupledtimestepper.hh +++ b/src/coupledtimestepper.hh @@ -7,7 +7,7 @@ #include <dune/common/parametertree.hh> #include "fixedpointiterator.hh" -template <class Factory, class StateUpdater, class RateUpdater, class ErrorNorm> +template <class Factory, class Updaters, class ErrorNorm> class CoupledTimeStepper { using Vector = typename Factory::Vector; using Matrix = typename Factory::Matrix; @@ -18,9 +18,7 @@ class CoupledTimeStepper { CoupledTimeStepper(double finalTime, Factory &factory, Dune::ParameterTree const &parset, std::shared_ptr<Nonlinearity> globalFriction, - std::shared_ptr<StateUpdater> stateUpdater, - std::shared_ptr<RateUpdater> velocityUpdater, - ErrorNorm const &errorNorm, + Updaters updaters, ErrorNorm const &errorNorm, std::function<void(double, Vector &)> externalForces); FixedPointIterationCounter step(double relativeTime, double relativeTau); @@ -30,8 +28,7 @@ class CoupledTimeStepper { Factory &factory_; Dune::ParameterTree const &parset_; std::shared_ptr<Nonlinearity> globalFriction_; - std::shared_ptr<StateUpdater> stateUpdater_; - std::shared_ptr<RateUpdater> velocityUpdater_; + Updaters updaters_; std::function<void(double, Vector &)> externalForces_; ErrorNorm const &errorNorm_; }; diff --git a/src/coupledtimestepper_tmpl.cc b/src/coupledtimestepper_tmpl.cc index 059922fd..f10309db 100644 --- a/src/coupledtimestepper_tmpl.cc +++ b/src/coupledtimestepper_tmpl.cc @@ -16,6 +16,7 @@ #include "rate/rateupdater.hh" #include "solverfactory.hh" #include "state/stateupdater.hh" +#include "updaters.hh" using Function = Dune::VirtualFunction<double, double>; using Factory = SolverFactory< @@ -24,8 +25,8 @@ using Factory = SolverFactory< Grid>; using MyStateUpdater = StateUpdater<ScalarVector, Vector>; using MyRateUpdater = RateUpdater<Vector, Matrix, Function, MY_DIM>; +using MyUpdaters = Updaters<MyRateUpdater, MyStateUpdater>; using ErrorNorm = EnergyNorm<ScalarMatrix, ScalarVector>; -template class CoupledTimeStepper<Factory, MyStateUpdater, MyRateUpdater, - ErrorNorm>; +template class CoupledTimeStepper<Factory, MyUpdaters, ErrorNorm>; diff --git a/src/fixedpointiterator.cc b/src/fixedpointiterator.cc index 67caf40c..92c5c6cf 100644 --- a/src/fixedpointiterator.cc +++ b/src/fixedpointiterator.cc @@ -13,11 +13,10 @@ #include "fixedpointiterator.hh" -template <class Factory, class StateUpdater, class RateUpdater, class ErrorNorm> -FixedPointIterator<Factory, StateUpdater, RateUpdater, ErrorNorm>:: - FixedPointIterator(Factory &factory, Dune::ParameterTree const &parset, - std::shared_ptr<Nonlinearity> globalFriction, - ErrorNorm const &errorNorm) +template <class Factory, class Updaters, class ErrorNorm> +FixedPointIterator<Factory, Updaters, ErrorNorm>::FixedPointIterator( + Factory &factory, Dune::ParameterTree const &parset, + std::shared_ptr<Nonlinearity> globalFriction, ErrorNorm const &errorNorm) : step_(factory.getStep()), parset_(parset), globalFriction_(globalFriction), @@ -29,12 +28,12 @@ FixedPointIterator<Factory, StateUpdater, RateUpdater, ErrorNorm>:: verbosity_(parset.get<Solver::VerbosityMode>("v.solver.verbosity")), errorNorm_(errorNorm) {} -template <class Factory, class StateUpdater, class RateUpdater, class ErrorNorm> -FixedPointIterationCounter -FixedPointIterator<Factory, StateUpdater, RateUpdater, ErrorNorm>::run( - std::shared_ptr<StateUpdater> stateUpdater, - std::shared_ptr<RateUpdater> velocityUpdater, Matrix const &velocityMatrix, - Vector const &velocityRHS, Vector &velocityIterate) { +template <class Factory, class Updaters, class ErrorNorm> +FixedPointIterationCounter FixedPointIterator< + Factory, Updaters, ErrorNorm>::run(Updaters updaters, + Matrix const &velocityMatrix, + Vector const &velocityRHS, + Vector &velocityIterate) { EnergyNorm<Matrix, Vector> energyNorm(velocityMatrix); LoopSolver<Vector> velocityProblemSolver(step_.get(), velocityMaxIterations_, velocityTolerance_, &energyNorm, @@ -43,7 +42,7 @@ FixedPointIterator<Factory, StateUpdater, RateUpdater, ErrorNorm>::run( size_t fixedPointIteration; size_t multigridIterations = 0; ScalarVector alpha; - stateUpdater->extractAlpha(alpha); + updaters.state_->extractAlpha(alpha); for (fixedPointIteration = 0; fixedPointIteration < fixedPointMaxIterations_; ++fixedPointIteration) { // solve a velocity problem @@ -58,14 +57,14 @@ FixedPointIterator<Factory, StateUpdater, RateUpdater, ErrorNorm>::run( multigridIterations += velocityProblemSolver.getResult().iterations; Vector v_m; - velocityUpdater->extractOldVelocity(v_m); + updaters.rate_->extractOldVelocity(v_m); v_m *= 1.0 - lambda_; Arithmetic::addProduct(v_m, lambda_, velocityIterate); // solve a state problem - stateUpdater->solve(v_m); + updaters.state_->solve(v_m); ScalarVector newAlpha; - stateUpdater->extractAlpha(newAlpha); + updaters.state_->extractAlpha(newAlpha); if (lambda_ < 1e-12 or errorNorm_.diff(alpha, newAlpha) < fixedPointTolerance_) { @@ -77,7 +76,7 @@ FixedPointIterator<Factory, StateUpdater, RateUpdater, ErrorNorm>::run( if (fixedPointIteration == fixedPointMaxIterations_) DUNE_THROW(Dune::Exception, "FPI failed to converge"); - velocityUpdater->postProcess(velocityIterate); + updaters.rate_->postProcess(velocityIterate); return { fixedPointIteration, multigridIterations }; } diff --git a/src/fixedpointiterator.hh b/src/fixedpointiterator.hh index 7543b12e..9a4a336f 100644 --- a/src/fixedpointiterator.hh +++ b/src/fixedpointiterator.hh @@ -16,9 +16,9 @@ struct FixedPointIterationCounter { std::ostream &operator<<(std::ostream &stream, FixedPointIterationCounter const &fpic); -template <class Factory, class StateUpdater, class RateUpdater, class ErrorNorm> +template <class Factory, class Updaters, class ErrorNorm> class FixedPointIterator { - using ScalarVector = typename StateUpdater::ScalarVector; + using ScalarVector = typename Updaters::StateUpdater::ScalarVector; using Vector = typename Factory::Vector; using Matrix = typename Factory::Matrix; using ConvexProblem = typename Factory::ConvexProblem; @@ -30,8 +30,7 @@ class FixedPointIterator { std::shared_ptr<Nonlinearity> globalFriction, ErrorNorm const &errorNorm_); - FixedPointIterationCounter run(std::shared_ptr<StateUpdater> stateUpdater, - std::shared_ptr<RateUpdater> velocityUpdater, + FixedPointIterationCounter run(Updaters updaters, Matrix const &velocityMatrix, Vector const &velocityRHS, Vector &velocityIterate); diff --git a/src/fixedpointiterator_tmpl.cc b/src/fixedpointiterator_tmpl.cc index c39287d6..5975e197 100644 --- a/src/fixedpointiterator_tmpl.cc +++ b/src/fixedpointiterator_tmpl.cc @@ -16,6 +16,7 @@ #include "rate/rateupdater.hh" #include "solverfactory.hh" #include "state/stateupdater.hh" +#include "updaters.hh" using Function = Dune::VirtualFunction<double, double>; using Factory = SolverFactory< @@ -24,8 +25,8 @@ using Factory = SolverFactory< Grid>; using MyStateUpdater = StateUpdater<ScalarVector, Vector>; using MyRateUpdater = RateUpdater<Vector, Matrix, Function, MY_DIM>; +using MyUpdaters = Updaters<MyRateUpdater, MyStateUpdater>; using ErrorNorm = EnergyNorm<ScalarMatrix, ScalarVector>; -template class FixedPointIterator<Factory, MyStateUpdater, MyRateUpdater, - ErrorNorm>; +template class FixedPointIterator<Factory, MyUpdaters, ErrorNorm>; diff --git a/src/sand-wedge.cc b/src/sand-wedge.cc index c5559c6d..6c7588e8 100644 --- a/src/sand-wedge.cc +++ b/src/sand-wedge.cc @@ -70,6 +70,7 @@ #include "sand-wedge-data/weakpatch.hh" #include "solverfactory.hh" #include "state.hh" +#include "updaters.hh" #include "vtk.hh" size_t const dims = MY_DIM; @@ -291,33 +292,32 @@ int main(int argc, char *argv[]) { Grid>; NonlinearFactory factory(parset.sub("solver.tnnmg"), *grid, dirichletNodes); - using UpdaterPair = - std::pair<std::shared_ptr<StateUpdater<ScalarVector, Vector>>, - std::shared_ptr<RateUpdater<Vector, Matrix, Function, dims>>>; - UpdaterPair current( + using MyUpdater = Updaters<RateUpdater<Vector, Matrix, Function, dims>, + StateUpdater<ScalarVector, Vector>>; + MyUpdater current( + initRateUpdater(parset.get<Config::scheme>("timeSteps.scheme"), + velocityDirichletFunction, dirichletNodes, matrices, + programState.u, programState.v, programState.a), initStateUpdater<ScalarVector, Vector>( parset.get<Config::stateModel>("boundary.friction.stateModel"), programState.alpha, frictionalNodes, parset.get<double>("boundary.friction.L"), - parset.get<double>("boundary.friction.V0")), - initRateUpdater(parset.get<Config::scheme>("timeSteps.scheme"), - velocityDirichletFunction, dirichletNodes, matrices, - programState.u, programState.v, programState.a)); + parset.get<double>("boundary.friction.V0"))); auto const refinementTolerance = parset.get<double>("timeSteps.refinementTolerance"); - auto const mustRefine = [&](UpdaterPair &coarseUpdater, - UpdaterPair &fineUpdater) { + auto const mustRefine = [&](MyUpdater &coarseUpdater, + MyUpdater &fineUpdater) { ScalarVector coarseAlpha; - coarseUpdater.first->extractAlpha(coarseAlpha); + coarseUpdater.state_->extractAlpha(coarseAlpha); ScalarVector fineAlpha; - fineUpdater.first->extractAlpha(fineAlpha); + fineUpdater.state_->extractAlpha(fineAlpha); return stateEnergyNorm.diff(fineAlpha, coarseAlpha) > refinementTolerance; }; - AdaptiveTimeStepper<NonlinearFactory, UpdaterPair, + AdaptiveTimeStepper<NonlinearFactory, MyUpdater, EnergyNorm<ScalarMatrix, ScalarVector>> adaptiveTimeStepper(factory, parset, myGlobalFriction, current, programState.relativeTime, programState.relativeTau, @@ -329,10 +329,10 @@ int main(int argc, char *argv[]) { programState.relativeTime = adaptiveTimeStepper.getRelativeTime(); programState.relativeTau = adaptiveTimeStepper.getRelativeTau(); - current.second->extractDisplacement(programState.u); - current.second->extractVelocity(programState.v); - current.second->extractAcceleration(programState.a); - current.first->extractAlpha(programState.alpha); + current.rate_->extractDisplacement(programState.u); + current.rate_->extractVelocity(programState.v); + current.rate_->extractAcceleration(programState.a); + current.state_->extractAlpha(programState.alpha); report(); } diff --git a/src/updaters.hh b/src/updaters.hh new file mode 100644 index 00000000..6d27f445 --- /dev/null +++ b/src/updaters.hh @@ -0,0 +1,23 @@ +#ifndef SRC_UPDATERS_HH +#define SRC_UPDATERS_HH + +template <class RateUpdaterTEMPLATE, class StateUpdaterTEMPLATE> +struct Updaters { + using RateUpdater = RateUpdaterTEMPLATE; + using StateUpdater = StateUpdaterTEMPLATE; + + Updaters() {} + + Updaters(std::shared_ptr<RateUpdaterTEMPLATE> rateUpdater, + std::shared_ptr<StateUpdaterTEMPLATE> stateUpdater) + : rate_(rateUpdater), state_(stateUpdater) {} + + Updaters<RateUpdater, StateUpdater> clone() const { + return Updaters<RateUpdater, StateUpdater>(rate_->clone(), state_->clone()); + } + + std::shared_ptr<RateUpdaterTEMPLATE> rate_; + std::shared_ptr<StateUpdaterTEMPLATE> state_; +}; + +#endif -- GitLab