diff --git a/src/adaptivetimestepper.hh b/src/adaptivetimestepper.hh index 9720efa64260adc0d027c8a3e0fc3505b6a5a110..2636cc1e70e1513ad1a1dd421b307f234f045cf3 100644 --- a/src/adaptivetimestepper.hh +++ b/src/adaptivetimestepper.hh @@ -1,29 +1,21 @@ #include "coupledtimestepper.hh" -template <typename T1, typename T2> -std::pair<T1, T2> clonePair(std::pair<T1, T2> in) { - return { in.first->clone(), in.second->clone() }; -} - -template <class Factory, class UpdaterPair, class ErrorNorm> +template <class Factory, class Updaters, class ErrorNorm> class AdaptiveTimeStepper { - using StateUpdater = typename UpdaterPair::first_type::element_type; - using RateUpdater = typename UpdaterPair::second_type::element_type; using Vector = typename Factory::Vector; using ConvexProblem = typename Factory::ConvexProblem; using Nonlinearity = typename ConvexProblem::NonlinearityType; - using MyCoupledTimeStepper = - CoupledTimeStepper<Factory, StateUpdater, RateUpdater, ErrorNorm>; + using MyCoupledTimeStepper = CoupledTimeStepper<Factory, Updaters, ErrorNorm>; public: - AdaptiveTimeStepper( - Factory &factory, Dune::ParameterTree const &parset, - std::shared_ptr<Nonlinearity> globalFriction, UpdaterPair ¤t, - double relativeTime, double relativeTau, - std::function<void(double, Vector &)> externalForces, - ErrorNorm const &errorNorm, - std::function<bool(UpdaterPair &, UpdaterPair &)> mustRefine) + AdaptiveTimeStepper(Factory &factory, Dune::ParameterTree const &parset, + std::shared_ptr<Nonlinearity> globalFriction, + Updaters ¤t, double relativeTime, + double relativeTau, + std::function<void(double, Vector &)> externalForces, + ErrorNorm const &errorNorm, + std::function<bool(Updaters &, Updaters &)> mustRefine) : finalTime_(parset.get<double>("problem.finalTime")), relativeTime_(relativeTime), relativeTau_(relativeTau), @@ -31,14 +23,14 @@ class AdaptiveTimeStepper { parset_(parset), globalFriction_(globalFriction), current_(current), - R1_(clonePair(current_)), + R1_(current_.clone()), externalForces_(externalForces), mustRefine_(mustRefine), iterationWriter_("iterations", std::fstream::out), errorNorm_(errorNorm) { - MyCoupledTimeStepper coupledTimeStepper( - finalTime_, factory_, parset_, globalFriction_, R1_.first, R1_.second, - errorNorm, externalForces_); + MyCoupledTimeStepper coupledTimeStepper(finalTime_, factory_, parset_, + globalFriction_, R1_, errorNorm, + externalForces_); stepAndReport("R1", coupledTimeStepper, relativeTime_, relativeTau_); iterationWriter_ << std::endl; } @@ -49,20 +41,20 @@ class AdaptiveTimeStepper { bool didCoarsen = false; while (relativeTime_ + relativeTau_ < 1.0) { - R2_ = clonePair(R1_); + R2_ = R1_.clone(); { - MyCoupledTimeStepper coupledTimeStepper( - finalTime_, factory_, parset_, globalFriction_, R2_.first, - R2_.second, errorNorm_, externalForces_); + MyCoupledTimeStepper coupledTimeStepper(finalTime_, factory_, parset_, + globalFriction_, R2_, + errorNorm_, externalForces_); stepAndReport("R2", coupledTimeStepper, relativeTime_ + relativeTau_, relativeTau_); } - UpdaterPair C = clonePair(current_); + Updaters C = current_.clone(); { - MyCoupledTimeStepper coupledTimeStepper( - finalTime_, factory_, parset_, globalFriction_, C.first, C.second, - errorNorm_, externalForces_); + MyCoupledTimeStepper coupledTimeStepper(finalTime_, factory_, parset_, + globalFriction_, C, errorNorm_, + externalForces_); stepAndReport("C", coupledTimeStepper, relativeTime_, 2.0 * relativeTau_); @@ -82,16 +74,16 @@ class AdaptiveTimeStepper { void refine() { while (true) { - UpdaterPair F2 = clonePair(current_); - UpdaterPair F1; + Updaters F2 = current_.clone(); + Updaters F1; { - MyCoupledTimeStepper coupledTimeStepper( - finalTime_, factory_, parset_, globalFriction_, F2.first, F2.second, - errorNorm_, externalForces_); + MyCoupledTimeStepper coupledTimeStepper(finalTime_, factory_, parset_, + globalFriction_, F2, errorNorm_, + externalForces_); stepAndReport("F1", coupledTimeStepper, relativeTime_, relativeTau_ / 2.0); - F1 = clonePair(F2); + F1 = F2.clone(); stepAndReport("F2", coupledTimeStepper, relativeTime_ + relativeTau_ / 2.0, relativeTau_ / 2.0); } @@ -133,11 +125,11 @@ class AdaptiveTimeStepper { Factory &factory_; Dune::ParameterTree const &parset_; std::shared_ptr<Nonlinearity> globalFriction_; - UpdaterPair ¤t_; - UpdaterPair R1_; - UpdaterPair R2_; + Updaters ¤t_; + Updaters R1_; + Updaters R2_; std::function<void(double, Vector &)> externalForces_; - std::function<bool(UpdaterPair &, UpdaterPair &)> mustRefine_; + std::function<bool(Updaters &, Updaters &)> mustRefine_; std::fstream iterationWriter_; ErrorNorm const &errorNorm_; }; diff --git a/src/coupledtimestepper.cc b/src/coupledtimestepper.cc index 1c12c5c99c0e4269853a23fc9e23845c4eb4da73..3be5824daf2f02c27e89fad45a2c76fa90fd3dbd 100644 --- a/src/coupledtimestepper.cc +++ b/src/coupledtimestepper.cc @@ -4,30 +4,26 @@ #include "coupledtimestepper.hh" -template <class Factory, class StateUpdater, class RateUpdater, class ErrorNorm> -CoupledTimeStepper<Factory, StateUpdater, RateUpdater, ErrorNorm>:: - CoupledTimeStepper(double finalTime, Factory &factory, - Dune::ParameterTree const &parset, - std::shared_ptr<Nonlinearity> globalFriction, - std::shared_ptr<StateUpdater> stateUpdater, - std::shared_ptr<RateUpdater> velocityUpdater, - ErrorNorm const &errorNorm, - std::function<void(double, Vector &)> externalForces) +template <class Factory, class Updaters, class ErrorNorm> +CoupledTimeStepper<Factory, Updaters, ErrorNorm>::CoupledTimeStepper( + double finalTime, Factory &factory, Dune::ParameterTree const &parset, + std::shared_ptr<Nonlinearity> globalFriction, Updaters updaters, + ErrorNorm const &errorNorm, + std::function<void(double, Vector &)> externalForces) : finalTime_(finalTime), factory_(factory), parset_(parset), globalFriction_(globalFriction), - stateUpdater_(stateUpdater), - velocityUpdater_(velocityUpdater), + updaters_(updaters), externalForces_(externalForces), errorNorm_(errorNorm) {} -template <class Factory, class StateUpdater, class RateUpdater, class ErrorNorm> -FixedPointIterationCounter CoupledTimeStepper< - Factory, StateUpdater, RateUpdater, ErrorNorm>::step(double relativeTime, - double relativeTau) { - stateUpdater_->nextTimeStep(); - velocityUpdater_->nextTimeStep(); +template <class Factory, class Updaters, class ErrorNorm> +FixedPointIterationCounter +CoupledTimeStepper<Factory, Updaters, ErrorNorm>::step(double relativeTime, + double relativeTau) { + updaters_.state_->nextTimeStep(); + updaters_.rate_->nextTimeStep(); auto const newRelativeTime = relativeTime + relativeTau; Vector ell; @@ -38,14 +34,13 @@ FixedPointIterationCounter CoupledTimeStepper< Vector velocityIterate; auto const tau = relativeTau * finalTime_; - stateUpdater_->setup(tau); - velocityUpdater_->setup(ell, tau, newRelativeTime, velocityRHS, - velocityIterate, velocityMatrix); - FixedPointIterator<Factory, StateUpdater, RateUpdater, ErrorNorm> - fixedPointIterator(factory_, parset_, globalFriction_, errorNorm_); - auto const iterations = - fixedPointIterator.run(stateUpdater_, velocityUpdater_, velocityMatrix, - velocityRHS, velocityIterate); + updaters_.state_->setup(tau); + updaters_.rate_->setup(ell, tau, newRelativeTime, velocityRHS, + velocityIterate, velocityMatrix); + FixedPointIterator<Factory, Updaters, ErrorNorm> fixedPointIterator( + factory_, parset_, globalFriction_, errorNorm_); + auto const iterations = fixedPointIterator.run(updaters_, velocityMatrix, + velocityRHS, velocityIterate); return iterations; } diff --git a/src/coupledtimestepper.hh b/src/coupledtimestepper.hh index 98c4b0e971ccb44084a6a12dbd4ac334c3f0fd7f..8873e6757281e68c52dc0455f9a7264c43543042 100644 --- a/src/coupledtimestepper.hh +++ b/src/coupledtimestepper.hh @@ -7,7 +7,7 @@ #include <dune/common/parametertree.hh> #include "fixedpointiterator.hh" -template <class Factory, class StateUpdater, class RateUpdater, class ErrorNorm> +template <class Factory, class Updaters, class ErrorNorm> class CoupledTimeStepper { using Vector = typename Factory::Vector; using Matrix = typename Factory::Matrix; @@ -18,9 +18,7 @@ class CoupledTimeStepper { CoupledTimeStepper(double finalTime, Factory &factory, Dune::ParameterTree const &parset, std::shared_ptr<Nonlinearity> globalFriction, - std::shared_ptr<StateUpdater> stateUpdater, - std::shared_ptr<RateUpdater> velocityUpdater, - ErrorNorm const &errorNorm, + Updaters updaters, ErrorNorm const &errorNorm, std::function<void(double, Vector &)> externalForces); FixedPointIterationCounter step(double relativeTime, double relativeTau); @@ -30,8 +28,7 @@ class CoupledTimeStepper { Factory &factory_; Dune::ParameterTree const &parset_; std::shared_ptr<Nonlinearity> globalFriction_; - std::shared_ptr<StateUpdater> stateUpdater_; - std::shared_ptr<RateUpdater> velocityUpdater_; + Updaters updaters_; std::function<void(double, Vector &)> externalForces_; ErrorNorm const &errorNorm_; }; diff --git a/src/coupledtimestepper_tmpl.cc b/src/coupledtimestepper_tmpl.cc index 059922fd0f0208c4eb4f1b48da1703f5a6d9d5dd..f10309dbeae6a546f5475f06b6f2a41f9b3c9dac 100644 --- a/src/coupledtimestepper_tmpl.cc +++ b/src/coupledtimestepper_tmpl.cc @@ -16,6 +16,7 @@ #include "rate/rateupdater.hh" #include "solverfactory.hh" #include "state/stateupdater.hh" +#include "updaters.hh" using Function = Dune::VirtualFunction<double, double>; using Factory = SolverFactory< @@ -24,8 +25,8 @@ using Factory = SolverFactory< Grid>; using MyStateUpdater = StateUpdater<ScalarVector, Vector>; using MyRateUpdater = RateUpdater<Vector, Matrix, Function, MY_DIM>; +using MyUpdaters = Updaters<MyRateUpdater, MyStateUpdater>; using ErrorNorm = EnergyNorm<ScalarMatrix, ScalarVector>; -template class CoupledTimeStepper<Factory, MyStateUpdater, MyRateUpdater, - ErrorNorm>; +template class CoupledTimeStepper<Factory, MyUpdaters, ErrorNorm>; diff --git a/src/fixedpointiterator.cc b/src/fixedpointiterator.cc index 67caf40c3b6fc068ff29f05d0c29fcd05430391b..92c5c6cf6fbd0aac6a93679313c66663cc70eb7a 100644 --- a/src/fixedpointiterator.cc +++ b/src/fixedpointiterator.cc @@ -13,11 +13,10 @@ #include "fixedpointiterator.hh" -template <class Factory, class StateUpdater, class RateUpdater, class ErrorNorm> -FixedPointIterator<Factory, StateUpdater, RateUpdater, ErrorNorm>:: - FixedPointIterator(Factory &factory, Dune::ParameterTree const &parset, - std::shared_ptr<Nonlinearity> globalFriction, - ErrorNorm const &errorNorm) +template <class Factory, class Updaters, class ErrorNorm> +FixedPointIterator<Factory, Updaters, ErrorNorm>::FixedPointIterator( + Factory &factory, Dune::ParameterTree const &parset, + std::shared_ptr<Nonlinearity> globalFriction, ErrorNorm const &errorNorm) : step_(factory.getStep()), parset_(parset), globalFriction_(globalFriction), @@ -29,12 +28,12 @@ FixedPointIterator<Factory, StateUpdater, RateUpdater, ErrorNorm>:: verbosity_(parset.get<Solver::VerbosityMode>("v.solver.verbosity")), errorNorm_(errorNorm) {} -template <class Factory, class StateUpdater, class RateUpdater, class ErrorNorm> -FixedPointIterationCounter -FixedPointIterator<Factory, StateUpdater, RateUpdater, ErrorNorm>::run( - std::shared_ptr<StateUpdater> stateUpdater, - std::shared_ptr<RateUpdater> velocityUpdater, Matrix const &velocityMatrix, - Vector const &velocityRHS, Vector &velocityIterate) { +template <class Factory, class Updaters, class ErrorNorm> +FixedPointIterationCounter FixedPointIterator< + Factory, Updaters, ErrorNorm>::run(Updaters updaters, + Matrix const &velocityMatrix, + Vector const &velocityRHS, + Vector &velocityIterate) { EnergyNorm<Matrix, Vector> energyNorm(velocityMatrix); LoopSolver<Vector> velocityProblemSolver(step_.get(), velocityMaxIterations_, velocityTolerance_, &energyNorm, @@ -43,7 +42,7 @@ FixedPointIterator<Factory, StateUpdater, RateUpdater, ErrorNorm>::run( size_t fixedPointIteration; size_t multigridIterations = 0; ScalarVector alpha; - stateUpdater->extractAlpha(alpha); + updaters.state_->extractAlpha(alpha); for (fixedPointIteration = 0; fixedPointIteration < fixedPointMaxIterations_; ++fixedPointIteration) { // solve a velocity problem @@ -58,14 +57,14 @@ FixedPointIterator<Factory, StateUpdater, RateUpdater, ErrorNorm>::run( multigridIterations += velocityProblemSolver.getResult().iterations; Vector v_m; - velocityUpdater->extractOldVelocity(v_m); + updaters.rate_->extractOldVelocity(v_m); v_m *= 1.0 - lambda_; Arithmetic::addProduct(v_m, lambda_, velocityIterate); // solve a state problem - stateUpdater->solve(v_m); + updaters.state_->solve(v_m); ScalarVector newAlpha; - stateUpdater->extractAlpha(newAlpha); + updaters.state_->extractAlpha(newAlpha); if (lambda_ < 1e-12 or errorNorm_.diff(alpha, newAlpha) < fixedPointTolerance_) { @@ -77,7 +76,7 @@ FixedPointIterator<Factory, StateUpdater, RateUpdater, ErrorNorm>::run( if (fixedPointIteration == fixedPointMaxIterations_) DUNE_THROW(Dune::Exception, "FPI failed to converge"); - velocityUpdater->postProcess(velocityIterate); + updaters.rate_->postProcess(velocityIterate); return { fixedPointIteration, multigridIterations }; } diff --git a/src/fixedpointiterator.hh b/src/fixedpointiterator.hh index 7543b12e27f52e9b25f7033d798e557862f5e07d..9a4a336fc5087fd026db4d7a20d35d90e801c483 100644 --- a/src/fixedpointiterator.hh +++ b/src/fixedpointiterator.hh @@ -16,9 +16,9 @@ struct FixedPointIterationCounter { std::ostream &operator<<(std::ostream &stream, FixedPointIterationCounter const &fpic); -template <class Factory, class StateUpdater, class RateUpdater, class ErrorNorm> +template <class Factory, class Updaters, class ErrorNorm> class FixedPointIterator { - using ScalarVector = typename StateUpdater::ScalarVector; + using ScalarVector = typename Updaters::StateUpdater::ScalarVector; using Vector = typename Factory::Vector; using Matrix = typename Factory::Matrix; using ConvexProblem = typename Factory::ConvexProblem; @@ -30,8 +30,7 @@ class FixedPointIterator { std::shared_ptr<Nonlinearity> globalFriction, ErrorNorm const &errorNorm_); - FixedPointIterationCounter run(std::shared_ptr<StateUpdater> stateUpdater, - std::shared_ptr<RateUpdater> velocityUpdater, + FixedPointIterationCounter run(Updaters updaters, Matrix const &velocityMatrix, Vector const &velocityRHS, Vector &velocityIterate); diff --git a/src/fixedpointiterator_tmpl.cc b/src/fixedpointiterator_tmpl.cc index c39287d6a800a80e22da26f430af6b702b8a3536..5975e197a9412e32eaee6fcb4a814b71a614bf9d 100644 --- a/src/fixedpointiterator_tmpl.cc +++ b/src/fixedpointiterator_tmpl.cc @@ -16,6 +16,7 @@ #include "rate/rateupdater.hh" #include "solverfactory.hh" #include "state/stateupdater.hh" +#include "updaters.hh" using Function = Dune::VirtualFunction<double, double>; using Factory = SolverFactory< @@ -24,8 +25,8 @@ using Factory = SolverFactory< Grid>; using MyStateUpdater = StateUpdater<ScalarVector, Vector>; using MyRateUpdater = RateUpdater<Vector, Matrix, Function, MY_DIM>; +using MyUpdaters = Updaters<MyRateUpdater, MyStateUpdater>; using ErrorNorm = EnergyNorm<ScalarMatrix, ScalarVector>; -template class FixedPointIterator<Factory, MyStateUpdater, MyRateUpdater, - ErrorNorm>; +template class FixedPointIterator<Factory, MyUpdaters, ErrorNorm>; diff --git a/src/sand-wedge.cc b/src/sand-wedge.cc index c5559c6dcfcce7197ac19150707201a108473685..6c7588e859572b6ce76a327467761eaef898737a 100644 --- a/src/sand-wedge.cc +++ b/src/sand-wedge.cc @@ -70,6 +70,7 @@ #include "sand-wedge-data/weakpatch.hh" #include "solverfactory.hh" #include "state.hh" +#include "updaters.hh" #include "vtk.hh" size_t const dims = MY_DIM; @@ -291,33 +292,32 @@ int main(int argc, char *argv[]) { Grid>; NonlinearFactory factory(parset.sub("solver.tnnmg"), *grid, dirichletNodes); - using UpdaterPair = - std::pair<std::shared_ptr<StateUpdater<ScalarVector, Vector>>, - std::shared_ptr<RateUpdater<Vector, Matrix, Function, dims>>>; - UpdaterPair current( + using MyUpdater = Updaters<RateUpdater<Vector, Matrix, Function, dims>, + StateUpdater<ScalarVector, Vector>>; + MyUpdater current( + initRateUpdater(parset.get<Config::scheme>("timeSteps.scheme"), + velocityDirichletFunction, dirichletNodes, matrices, + programState.u, programState.v, programState.a), initStateUpdater<ScalarVector, Vector>( parset.get<Config::stateModel>("boundary.friction.stateModel"), programState.alpha, frictionalNodes, parset.get<double>("boundary.friction.L"), - parset.get<double>("boundary.friction.V0")), - initRateUpdater(parset.get<Config::scheme>("timeSteps.scheme"), - velocityDirichletFunction, dirichletNodes, matrices, - programState.u, programState.v, programState.a)); + parset.get<double>("boundary.friction.V0"))); auto const refinementTolerance = parset.get<double>("timeSteps.refinementTolerance"); - auto const mustRefine = [&](UpdaterPair &coarseUpdater, - UpdaterPair &fineUpdater) { + auto const mustRefine = [&](MyUpdater &coarseUpdater, + MyUpdater &fineUpdater) { ScalarVector coarseAlpha; - coarseUpdater.first->extractAlpha(coarseAlpha); + coarseUpdater.state_->extractAlpha(coarseAlpha); ScalarVector fineAlpha; - fineUpdater.first->extractAlpha(fineAlpha); + fineUpdater.state_->extractAlpha(fineAlpha); return stateEnergyNorm.diff(fineAlpha, coarseAlpha) > refinementTolerance; }; - AdaptiveTimeStepper<NonlinearFactory, UpdaterPair, + AdaptiveTimeStepper<NonlinearFactory, MyUpdater, EnergyNorm<ScalarMatrix, ScalarVector>> adaptiveTimeStepper(factory, parset, myGlobalFriction, current, programState.relativeTime, programState.relativeTau, @@ -329,10 +329,10 @@ int main(int argc, char *argv[]) { programState.relativeTime = adaptiveTimeStepper.getRelativeTime(); programState.relativeTau = adaptiveTimeStepper.getRelativeTau(); - current.second->extractDisplacement(programState.u); - current.second->extractVelocity(programState.v); - current.second->extractAcceleration(programState.a); - current.first->extractAlpha(programState.alpha); + current.rate_->extractDisplacement(programState.u); + current.rate_->extractVelocity(programState.v); + current.rate_->extractAcceleration(programState.a); + current.state_->extractAlpha(programState.alpha); report(); } diff --git a/src/updaters.hh b/src/updaters.hh new file mode 100644 index 0000000000000000000000000000000000000000..6d27f445b301eacb9d8b73eb48fa9f6cb4764da7 --- /dev/null +++ b/src/updaters.hh @@ -0,0 +1,23 @@ +#ifndef SRC_UPDATERS_HH +#define SRC_UPDATERS_HH + +template <class RateUpdaterTEMPLATE, class StateUpdaterTEMPLATE> +struct Updaters { + using RateUpdater = RateUpdaterTEMPLATE; + using StateUpdater = StateUpdaterTEMPLATE; + + Updaters() {} + + Updaters(std::shared_ptr<RateUpdaterTEMPLATE> rateUpdater, + std::shared_ptr<StateUpdaterTEMPLATE> stateUpdater) + : rate_(rateUpdater), state_(stateUpdater) {} + + Updaters<RateUpdater, StateUpdater> clone() const { + return Updaters<RateUpdater, StateUpdater>(rate_->clone(), state_->clone()); + } + + std::shared_ptr<RateUpdaterTEMPLATE> rate_; + std::shared_ptr<StateUpdaterTEMPLATE> state_; +}; + +#endif