diff --git a/src/sand-wedge.cc b/src/sand-wedge.cc
index 2665eaa50690af7c2beed91e7897958a5244012c..efd781564e8c62db700c1f148b036083619e29d2 100644
--- a/src/sand-wedge.cc
+++ b/src/sand-wedge.cc
@@ -370,6 +370,16 @@ int main(int argc, char *argv[]) {
 
     auto const refinementTolerance =
         parset.get<double>("timeSteps.refinementTolerance");
+    auto const mustRefine = [&](UpdaterPair coarseUpdater,
+                                UpdaterPair fineUpdater) {
+      ScalarVector coarseAlpha;
+      coarseUpdater.first->extractAlpha(coarseAlpha);
+
+      ScalarVector fineAlpha;
+      fineUpdater.first->extractAlpha(fineAlpha);
+
+      return stateEnergyNorm.diff(fineAlpha, coarseAlpha) > refinementTolerance;
+    };
 
     size_t timeStep = 1;
 
@@ -394,7 +404,6 @@ int main(int argc, char *argv[]) {
       while (true) {
         R2 = clonePair(R1);
 
-        ScalarVector alphaR2;
         {
           CoupledTimeStepper<NonlinearFactory, MyStateUpdater, MyTimeStepper>
           coupledTimeStepper(finalTime, factory, parset, myGlobalFriction,
@@ -404,11 +413,9 @@ int main(int argc, char *argv[]) {
               coupledTimeStepper.step(relativeTime + relativeTau, relativeTau);
           iterationWriter << iterations << " " << std::flush;
         }
-        R2.first->extractAlpha(alphaR2);
 
         UpdaterPair C = clonePair(current);
 
-        ScalarVector alphaC;
         {
           CoupledTimeStepper<NonlinearFactory, MyStateUpdater, MyTimeStepper>
           coupledTimeStepper(finalTime, factory, parset, myGlobalFriction,
@@ -418,11 +425,8 @@ int main(int argc, char *argv[]) {
               coupledTimeStepper.step(relativeTime, 2.0 * relativeTau);
           iterationWriter << iterations << " " << std::flush;
         }
-        C.first->extractAlpha(alphaC);
-
-        auto const coarseningError = stateEnergyNorm.diff(alphaC, alphaR2);
 
-        if (coarseningError < refinementTolerance) {
+        if (!mustRefine(C, R2)) {
           R2 = { nullptr, nullptr };
           R1 = C;
           relativeTau *= 2.0;
@@ -433,12 +437,10 @@ int main(int argc, char *argv[]) {
       }
 
       if (!didCoarsen) {
-        ScalarVector alphaR1;
         while (true) {
           UpdaterPair F2 = clonePair(current);
           UpdaterPair F1;
 
-          ScalarVector alphaF2;
           {
             CoupledTimeStepper<NonlinearFactory, MyStateUpdater, MyTimeStepper>
             coupledTimeStepper(finalTime, factory, parset, myGlobalFriction,
@@ -455,11 +457,8 @@ int main(int argc, char *argv[]) {
                 relativeTime + relativeTau / 2.0, relativeTau / 2.0);
             iterationWriter << iterationsF2 << " " << std::flush;
           }
-          F2.first->extractAlpha(alphaF2);
-          R1.first->extractAlpha(alphaR1);
-          auto const refinementError = stateEnergyNorm.diff(alphaR1, alphaF2);
 
-          if (refinementError < refinementTolerance) {
+          if (!mustRefine(R1, F2)) {
             break;
           } else {
             R1 = F1;