diff --git a/src/sand-wedge.cc b/src/sand-wedge.cc
index 20406ca4199e9d7df55534bed6453f1702a345cf..2665eaa50690af7c2beed91e7897958a5244012c 100644
--- a/src/sand-wedge.cc
+++ b/src/sand-wedge.cc
@@ -85,6 +85,11 @@
 
 size_t const dims = DIM;
 
+template <typename T1, typename T2>
+std::pair<T1, T2> clonePair(std::pair<T1, T2> in) {
+  return { in.first->clone(), in.second->clone() };
+}
+
 void initPython() {
   Python::start();
 
@@ -341,15 +346,19 @@ int main(int argc, char *argv[]) {
     NonlinearFactory factory(parset.sub("solver.tnnmg"), refinements, *grid,
                              dirichletNodes);
 
-    auto velocityUpdater = initTimeStepper(
-        parset.get<Config::scheme>("timeSteps.scheme"),
-        velocityDirichletFunction, dirichletNodes, M, A, C, u_initial,
-        ur_initial, v_initial, vr_initial, a_initial);
-    auto stateUpdater = initStateUpdater<ScalarVector, Vector>(
-        parset.get<Config::stateModel>("boundary.friction.stateModel"),
-        alpha_initial, frictionalNodes,
-        parset.get<double>("boundary.friction.L"),
-        parset.get<double>("boundary.friction.V0"));
+    using UpdaterPair = std::pair<
+        std::shared_ptr<StateUpdater<ScalarVector, Vector>>,
+        std::shared_ptr<TimeSteppingScheme<Vector, Matrix, Function, dims>>>;
+    UpdaterPair current(
+        initStateUpdater<ScalarVector, Vector>(
+            parset.get<Config::stateModel>("boundary.friction.stateModel"),
+            alpha_initial, frictionalNodes,
+            parset.get<double>("boundary.friction.L"),
+            parset.get<double>("boundary.friction.V0")),
+        initTimeStepper(parset.get<Config::scheme>("timeSteps.scheme"),
+                        velocityDirichletFunction, dirichletNodes, M, A, C,
+                        u_initial, ur_initial, v_initial, vr_initial,
+                        a_initial));
 
     auto const finalTime = parset.get<double>("problem.finalTime");
     double relativeTime = 0.0;
@@ -366,66 +375,56 @@ int main(int argc, char *argv[]) {
 
     std::fstream iterationWriter("iterations", std::fstream::out);
 
-    auto stateUpdaterR1 = stateUpdater->clone();
-    auto velocityUpdaterR1 = velocityUpdater->clone();
+    UpdaterPair R1 = clonePair(current);
     {
       CoupledTimeStepper<NonlinearFactory, MyStateUpdater, MyTimeStepper>
-      coupledTimeStepper(finalTime, factory, parset, myGlobalFriction,
-                         stateUpdaterR1, velocityUpdaterR1,
-                         computeExternalForces);
+      coupledTimeStepper(finalTime, factory, parset, myGlobalFriction, R1.first,
+                         R1.second, computeExternalForces);
       iterationWriter << "R1 ";
       auto const iterations =
           coupledTimeStepper.step(relativeTime, relativeTau);
       iterationWriter << iterations << std::endl;
     }
 
-    std::shared_ptr<MyStateUpdater> stateUpdaterR2 = nullptr;
-    std::shared_ptr<MyTimeStepper> velocityUpdaterR2 = nullptr;
+    UpdaterPair R2;
 
     while (relativeTime < 1.0 - 1e-10) {
       bool didCoarsen = false;
 
       while (true) {
-        stateUpdaterR2 = stateUpdaterR1->clone();
-        velocityUpdaterR2 = velocityUpdaterR1->clone();
+        R2 = clonePair(R1);
 
         ScalarVector alphaR2;
         {
           CoupledTimeStepper<NonlinearFactory, MyStateUpdater, MyTimeStepper>
           coupledTimeStepper(finalTime, factory, parset, myGlobalFriction,
-                             stateUpdaterR2, velocityUpdaterR2,
-                             computeExternalForces);
+                             R2.first, R2.second, computeExternalForces);
           iterationWriter << "R2 ";
           auto const iterations =
               coupledTimeStepper.step(relativeTime + relativeTau, relativeTau);
           iterationWriter << iterations << " " << std::flush;
         }
-        stateUpdaterR2->extractAlpha(alphaR2);
+        R2.first->extractAlpha(alphaR2);
 
-        auto stateUpdaterC = stateUpdater->clone();
-        auto velocityUpdaterC = velocityUpdater->clone();
+        UpdaterPair C = clonePair(current);
 
         ScalarVector alphaC;
         {
           CoupledTimeStepper<NonlinearFactory, MyStateUpdater, MyTimeStepper>
           coupledTimeStepper(finalTime, factory, parset, myGlobalFriction,
-                             stateUpdaterC, velocityUpdaterC,
-                             computeExternalForces);
+                             C.first, C.second, computeExternalForces);
           iterationWriter << "C ";
           auto const iterations =
               coupledTimeStepper.step(relativeTime, 2.0 * relativeTau);
           iterationWriter << iterations << " " << std::flush;
         }
-        stateUpdaterC->extractAlpha(alphaC);
+        C.first->extractAlpha(alphaC);
 
         auto const coarseningError = stateEnergyNorm.diff(alphaC, alphaR2);
 
         if (coarseningError < refinementTolerance) {
-          stateUpdaterR2 = nullptr;
-          velocityUpdaterR2 = nullptr;
-
-          stateUpdaterR1 = stateUpdaterC;
-          velocityUpdaterR1 = velocityUpdaterC;
+          R2 = { nullptr, nullptr };
+          R1 = C;
           relativeTau *= 2.0;
           didCoarsen = true;
         } else {
@@ -436,42 +435,35 @@ int main(int argc, char *argv[]) {
       if (!didCoarsen) {
         ScalarVector alphaR1;
         while (true) {
-          auto stateUpdaterF2 = stateUpdater->clone();
-          auto velocityUpdaterF2 = velocityUpdater->clone();
-
-          std::shared_ptr<MyStateUpdater> stateUpdaterF1;
-          std::shared_ptr<MyTimeStepper> velocityUpdaterF1;
+          UpdaterPair F2 = clonePair(current);
+          UpdaterPair F1;
 
           ScalarVector alphaF2;
           {
             CoupledTimeStepper<NonlinearFactory, MyStateUpdater, MyTimeStepper>
             coupledTimeStepper(finalTime, factory, parset, myGlobalFriction,
-                               stateUpdaterF2, velocityUpdaterF2,
-                               computeExternalForces);
+                               F2.first, F2.second, computeExternalForces);
             iterationWriter << "F1 ";
             auto const iterationsF1 =
                 coupledTimeStepper.step(relativeTime, relativeTau / 2.0);
             iterationWriter << iterationsF1 << " " << std::flush;
 
-            stateUpdaterF1 = stateUpdaterF2->clone();
-            velocityUpdaterF1 = velocityUpdaterF2->clone();
+            F1 = clonePair(F2);
 
             iterationWriter << "F2 ";
             auto const iterationsF2 = coupledTimeStepper.step(
                 relativeTime + relativeTau / 2.0, relativeTau / 2.0);
             iterationWriter << iterationsF2 << " " << std::flush;
           }
-          stateUpdaterF2->extractAlpha(alphaF2);
-          stateUpdaterR1->extractAlpha(alphaR1);
+          F2.first->extractAlpha(alphaF2);
+          R1.first->extractAlpha(alphaR1);
           auto const refinementError = stateEnergyNorm.diff(alphaR1, alphaF2);
 
           if (refinementError < refinementTolerance) {
             break;
           } else {
-            stateUpdaterR1 = stateUpdaterF1;
-            velocityUpdaterR1 = velocityUpdaterF1;
-            stateUpdaterR2 = stateUpdaterF2;
-            velocityUpdaterR2 = velocityUpdaterF2;
+            R1 = F1;
+            R2 = F2;
             relativeTau /= 2.0;
           }
         }
@@ -480,18 +472,16 @@ int main(int argc, char *argv[]) {
 
       reportTimeStep(relativeTime, relativeTau);
 
-      stateUpdater = stateUpdaterR1;
-      velocityUpdater = velocityUpdaterR1;
-      stateUpdaterR1 = stateUpdaterR2;
-      velocityUpdaterR1 = velocityUpdaterR2;
+      current = R1;
+      R1 = R2;
       relativeTime += relativeTau;
 
       Vector u, ur, vr;
       ScalarVector alpha;
-      velocityUpdater->extractDisplacement(u);
-      velocityUpdater->extractRelativeDisplacement(ur);
-      velocityUpdater->extractRelativeVelocity(vr);
-      stateUpdater->extractAlpha(alpha);
+      current.second->extractDisplacement(u);
+      current.second->extractRelativeDisplacement(ur);
+      current.second->extractRelativeVelocity(vr);
+      current.first->extractAlpha(alpha);
 
       report(ur, vr, alpha);
       {