diff --git a/src/sand-wedge.cc b/src/sand-wedge.cc index 2665eaa50690af7c2beed91e7897958a5244012c..efd781564e8c62db700c1f148b036083619e29d2 100644 --- a/src/sand-wedge.cc +++ b/src/sand-wedge.cc @@ -370,6 +370,16 @@ int main(int argc, char *argv[]) { auto const refinementTolerance = parset.get<double>("timeSteps.refinementTolerance"); + auto const mustRefine = [&](UpdaterPair coarseUpdater, + UpdaterPair fineUpdater) { + ScalarVector coarseAlpha; + coarseUpdater.first->extractAlpha(coarseAlpha); + + ScalarVector fineAlpha; + fineUpdater.first->extractAlpha(fineAlpha); + + return stateEnergyNorm.diff(fineAlpha, coarseAlpha) > refinementTolerance; + }; size_t timeStep = 1; @@ -394,7 +404,6 @@ int main(int argc, char *argv[]) { while (true) { R2 = clonePair(R1); - ScalarVector alphaR2; { CoupledTimeStepper<NonlinearFactory, MyStateUpdater, MyTimeStepper> coupledTimeStepper(finalTime, factory, parset, myGlobalFriction, @@ -404,11 +413,9 @@ int main(int argc, char *argv[]) { coupledTimeStepper.step(relativeTime + relativeTau, relativeTau); iterationWriter << iterations << " " << std::flush; } - R2.first->extractAlpha(alphaR2); UpdaterPair C = clonePair(current); - ScalarVector alphaC; { CoupledTimeStepper<NonlinearFactory, MyStateUpdater, MyTimeStepper> coupledTimeStepper(finalTime, factory, parset, myGlobalFriction, @@ -418,11 +425,8 @@ int main(int argc, char *argv[]) { coupledTimeStepper.step(relativeTime, 2.0 * relativeTau); iterationWriter << iterations << " " << std::flush; } - C.first->extractAlpha(alphaC); - - auto const coarseningError = stateEnergyNorm.diff(alphaC, alphaR2); - if (coarseningError < refinementTolerance) { + if (!mustRefine(C, R2)) { R2 = { nullptr, nullptr }; R1 = C; relativeTau *= 2.0; @@ -433,12 +437,10 @@ int main(int argc, char *argv[]) { } if (!didCoarsen) { - ScalarVector alphaR1; while (true) { UpdaterPair F2 = clonePair(current); UpdaterPair F1; - ScalarVector alphaF2; { CoupledTimeStepper<NonlinearFactory, MyStateUpdater, MyTimeStepper> coupledTimeStepper(finalTime, factory, parset, myGlobalFriction, @@ -455,11 +457,8 @@ int main(int argc, char *argv[]) { relativeTime + relativeTau / 2.0, relativeTau / 2.0); iterationWriter << iterationsF2 << " " << std::flush; } - F2.first->extractAlpha(alphaF2); - R1.first->extractAlpha(alphaR1); - auto const refinementError = stateEnergyNorm.diff(alphaR1, alphaF2); - if (refinementError < refinementTolerance) { + if (!mustRefine(R1, F2)) { break; } else { R1 = F1;