diff --git a/src/sand-wedge.cc b/src/sand-wedge.cc index 20406ca4199e9d7df55534bed6453f1702a345cf..2665eaa50690af7c2beed91e7897958a5244012c 100644 --- a/src/sand-wedge.cc +++ b/src/sand-wedge.cc @@ -85,6 +85,11 @@ size_t const dims = DIM; +template <typename T1, typename T2> +std::pair<T1, T2> clonePair(std::pair<T1, T2> in) { + return { in.first->clone(), in.second->clone() }; +} + void initPython() { Python::start(); @@ -341,15 +346,19 @@ int main(int argc, char *argv[]) { NonlinearFactory factory(parset.sub("solver.tnnmg"), refinements, *grid, dirichletNodes); - auto velocityUpdater = initTimeStepper( - parset.get<Config::scheme>("timeSteps.scheme"), - velocityDirichletFunction, dirichletNodes, M, A, C, u_initial, - ur_initial, v_initial, vr_initial, a_initial); - auto stateUpdater = initStateUpdater<ScalarVector, Vector>( - parset.get<Config::stateModel>("boundary.friction.stateModel"), - alpha_initial, frictionalNodes, - parset.get<double>("boundary.friction.L"), - parset.get<double>("boundary.friction.V0")); + using UpdaterPair = std::pair< + std::shared_ptr<StateUpdater<ScalarVector, Vector>>, + std::shared_ptr<TimeSteppingScheme<Vector, Matrix, Function, dims>>>; + UpdaterPair current( + initStateUpdater<ScalarVector, Vector>( + parset.get<Config::stateModel>("boundary.friction.stateModel"), + alpha_initial, frictionalNodes, + parset.get<double>("boundary.friction.L"), + parset.get<double>("boundary.friction.V0")), + initTimeStepper(parset.get<Config::scheme>("timeSteps.scheme"), + velocityDirichletFunction, dirichletNodes, M, A, C, + u_initial, ur_initial, v_initial, vr_initial, + a_initial)); auto const finalTime = parset.get<double>("problem.finalTime"); double relativeTime = 0.0; @@ -366,66 +375,56 @@ int main(int argc, char *argv[]) { std::fstream iterationWriter("iterations", std::fstream::out); - auto stateUpdaterR1 = stateUpdater->clone(); - auto velocityUpdaterR1 = velocityUpdater->clone(); + UpdaterPair R1 = clonePair(current); { CoupledTimeStepper<NonlinearFactory, MyStateUpdater, MyTimeStepper> - coupledTimeStepper(finalTime, factory, parset, myGlobalFriction, - stateUpdaterR1, velocityUpdaterR1, - computeExternalForces); + coupledTimeStepper(finalTime, factory, parset, myGlobalFriction, R1.first, + R1.second, computeExternalForces); iterationWriter << "R1 "; auto const iterations = coupledTimeStepper.step(relativeTime, relativeTau); iterationWriter << iterations << std::endl; } - std::shared_ptr<MyStateUpdater> stateUpdaterR2 = nullptr; - std::shared_ptr<MyTimeStepper> velocityUpdaterR2 = nullptr; + UpdaterPair R2; while (relativeTime < 1.0 - 1e-10) { bool didCoarsen = false; while (true) { - stateUpdaterR2 = stateUpdaterR1->clone(); - velocityUpdaterR2 = velocityUpdaterR1->clone(); + R2 = clonePair(R1); ScalarVector alphaR2; { CoupledTimeStepper<NonlinearFactory, MyStateUpdater, MyTimeStepper> coupledTimeStepper(finalTime, factory, parset, myGlobalFriction, - stateUpdaterR2, velocityUpdaterR2, - computeExternalForces); + R2.first, R2.second, computeExternalForces); iterationWriter << "R2 "; auto const iterations = coupledTimeStepper.step(relativeTime + relativeTau, relativeTau); iterationWriter << iterations << " " << std::flush; } - stateUpdaterR2->extractAlpha(alphaR2); + R2.first->extractAlpha(alphaR2); - auto stateUpdaterC = stateUpdater->clone(); - auto velocityUpdaterC = velocityUpdater->clone(); + UpdaterPair C = clonePair(current); ScalarVector alphaC; { CoupledTimeStepper<NonlinearFactory, MyStateUpdater, MyTimeStepper> coupledTimeStepper(finalTime, factory, parset, myGlobalFriction, - stateUpdaterC, velocityUpdaterC, - computeExternalForces); + C.first, C.second, computeExternalForces); iterationWriter << "C "; auto const iterations = coupledTimeStepper.step(relativeTime, 2.0 * relativeTau); iterationWriter << iterations << " " << std::flush; } - stateUpdaterC->extractAlpha(alphaC); + C.first->extractAlpha(alphaC); auto const coarseningError = stateEnergyNorm.diff(alphaC, alphaR2); if (coarseningError < refinementTolerance) { - stateUpdaterR2 = nullptr; - velocityUpdaterR2 = nullptr; - - stateUpdaterR1 = stateUpdaterC; - velocityUpdaterR1 = velocityUpdaterC; + R2 = { nullptr, nullptr }; + R1 = C; relativeTau *= 2.0; didCoarsen = true; } else { @@ -436,42 +435,35 @@ int main(int argc, char *argv[]) { if (!didCoarsen) { ScalarVector alphaR1; while (true) { - auto stateUpdaterF2 = stateUpdater->clone(); - auto velocityUpdaterF2 = velocityUpdater->clone(); - - std::shared_ptr<MyStateUpdater> stateUpdaterF1; - std::shared_ptr<MyTimeStepper> velocityUpdaterF1; + UpdaterPair F2 = clonePair(current); + UpdaterPair F1; ScalarVector alphaF2; { CoupledTimeStepper<NonlinearFactory, MyStateUpdater, MyTimeStepper> coupledTimeStepper(finalTime, factory, parset, myGlobalFriction, - stateUpdaterF2, velocityUpdaterF2, - computeExternalForces); + F2.first, F2.second, computeExternalForces); iterationWriter << "F1 "; auto const iterationsF1 = coupledTimeStepper.step(relativeTime, relativeTau / 2.0); iterationWriter << iterationsF1 << " " << std::flush; - stateUpdaterF1 = stateUpdaterF2->clone(); - velocityUpdaterF1 = velocityUpdaterF2->clone(); + F1 = clonePair(F2); iterationWriter << "F2 "; auto const iterationsF2 = coupledTimeStepper.step( relativeTime + relativeTau / 2.0, relativeTau / 2.0); iterationWriter << iterationsF2 << " " << std::flush; } - stateUpdaterF2->extractAlpha(alphaF2); - stateUpdaterR1->extractAlpha(alphaR1); + F2.first->extractAlpha(alphaF2); + R1.first->extractAlpha(alphaR1); auto const refinementError = stateEnergyNorm.diff(alphaR1, alphaF2); if (refinementError < refinementTolerance) { break; } else { - stateUpdaterR1 = stateUpdaterF1; - velocityUpdaterR1 = velocityUpdaterF1; - stateUpdaterR2 = stateUpdaterF2; - velocityUpdaterR2 = velocityUpdaterF2; + R1 = F1; + R2 = F2; relativeTau /= 2.0; } } @@ -480,18 +472,16 @@ int main(int argc, char *argv[]) { reportTimeStep(relativeTime, relativeTau); - stateUpdater = stateUpdaterR1; - velocityUpdater = velocityUpdaterR1; - stateUpdaterR1 = stateUpdaterR2; - velocityUpdaterR1 = velocityUpdaterR2; + current = R1; + R1 = R2; relativeTime += relativeTau; Vector u, ur, vr; ScalarVector alpha; - velocityUpdater->extractDisplacement(u); - velocityUpdater->extractRelativeDisplacement(ur); - velocityUpdater->extractRelativeVelocity(vr); - stateUpdater->extractAlpha(alpha); + current.second->extractDisplacement(u); + current.second->extractRelativeDisplacement(ur); + current.second->extractRelativeVelocity(vr); + current.first->extractAlpha(alpha); report(ur, vr, alpha); {