diff --git a/src/sand-wedge.cc b/src/sand-wedge.cc
index ef45894878a68b392570b5d6187d57591a0e6861..0b77d199495ef46a6d68395aa9933e7292aac1aa 100644
--- a/src/sand-wedge.cc
+++ b/src/sand-wedge.cc
@@ -178,6 +178,62 @@ class FixedPointIterator {
   Solver::VerbosityMode verbosity_;
 };
 
+template <class Factory, class StateUpdater, class VelocityUpdater>
+class CoupledTimeStepper {
+  using Vector = typename Factory::Vector;
+  using Matrix = typename Factory::Matrix;
+  using ConvexProblem = typename Factory::ConvexProblem;
+  using Nonlinearity = typename ConvexProblem::NonlinearityType;
+
+public:
+  CoupledTimeStepper(double finalTime, Factory &factory,
+                     Dune::ParameterTree const &parset,
+                     std::shared_ptr<Nonlinearity> globalFriction,
+                     std::shared_ptr<StateUpdater> stateUpdater,
+                     std::shared_ptr<VelocityUpdater> velocityUpdater,
+                     std::function<void(double, Vector &)> externalForces)
+      : finalTime_(finalTime),
+        factory_(factory),
+        parset_(parset),
+        globalFriction_(globalFriction),
+        stateUpdater_(stateUpdater),
+        velocityUpdater_(velocityUpdater),
+        externalForces_(externalForces) {}
+
+  void step(double relativeTime, double relativeTau) {
+    stateUpdater_->nextTimeStep();
+    velocityUpdater_->nextTimeStep();
+
+    auto const newRelativeTime = relativeTime + relativeTau;
+    Vector ell;
+    externalForces_(newRelativeTime, ell);
+
+    Matrix velocityMatrix;
+    Vector velocityRHS;
+    Vector velocityIterate;
+
+    auto const tau = relativeTau * finalTime_;
+    stateUpdater_->setup(tau);
+    velocityUpdater_->setup(ell, tau, newRelativeTime, velocityRHS,
+                            velocityIterate, velocityMatrix);
+    EnergyNorm<Matrix, Vector> const velocityMatrixNorm(velocityMatrix);
+
+    FixedPointIterator<Factory, StateUpdater, VelocityUpdater>
+    fixedPointIterator(factory_, parset_, globalFriction_);
+    fixedPointIterator.run(stateUpdater_, velocityUpdater_, velocityMatrix,
+                           velocityMatrixNorm, velocityRHS, velocityIterate);
+  }
+
+private:
+  double finalTime_;
+  Factory &factory_;
+  Dune::ParameterTree const &parset_;
+  std::shared_ptr<Nonlinearity> globalFriction_;
+  std::shared_ptr<StateUpdater> stateUpdater_;
+  std::shared_ptr<VelocityUpdater> velocityUpdater_;
+  std::function<void(double, Vector &)> externalForces_;
+};
+
 int main(int argc, char *argv[]) {
   try {
     Dune::ParameterTree parset;
@@ -276,7 +332,8 @@ int main(int argc, char *argv[]) {
     myAssembler.assembleBodyForce(body.getGravityField(), gravityFunctional);
 
     // Problem formulation: right-hand side
-    auto const computeExternalForces = [&](double _relativeTime, Vector &_ell) {
+    std::function<void(double, Vector &)> computeExternalForces = [&](
+        double _relativeTime, Vector &_ell) {
       myAssembler.assembleNeumann(neumannBoundary, _ell, neumannFunction,
                                   _relativeTime);
       _ell += gravityFunctional;
@@ -430,30 +487,16 @@ int main(int argc, char *argv[]) {
         parset.get<double>("boundary.friction.L"),
         parset.get<double>("boundary.friction.V0"));
 
+    auto const finalTime = parset.get<double>("problem.finalTime");
     auto const timeSteps = parset.get<size_t>("timeSteps.number");
-    auto const tau = parset.get<double>("problem.finalTime") / timeSteps;
+    auto const relativeTau = 1.0 / timeSteps;
     for (size_t timeStep = 1; timeStep <= timeSteps; ++timeStep) {
-      stateUpdater->nextTimeStep();
-      velocityUpdater->nextTimeStep();
-
       auto const relativeTime = double(timeStep) / double(timeSteps);
-      Vector ell;
-      computeExternalForces(relativeTime, ell);
-
-      Matrix velocityMatrix;
-      Vector velocityRHS;
-      Vector velocityIterate;
-
-      stateUpdater->setup(tau);
-      velocityUpdater->setup(ell, tau, relativeTime, velocityRHS,
-                             velocityIterate, velocityMatrix);
-      EnergyNorm<Matrix, Vector> const velocityMatrixNorm(velocityMatrix);
-
-      FixedPointIterator<NonlinearFactory, StateUpdater<ScalarVector, Vector>,
+      CoupledTimeStepper<NonlinearFactory, StateUpdater<ScalarVector, Vector>,
                          TimeSteppingScheme<Vector, Matrix, Function, dims>>
-      fixedPointIterator(factory, parset, myGlobalFriction);
-      fixedPointIterator.run(stateUpdater, velocityUpdater, velocityMatrix,
-                             velocityMatrixNorm, velocityRHS, velocityIterate);
+      coupledTimeStepper(finalTime, factory, parset, myGlobalFriction,
+                         stateUpdater, velocityUpdater, computeExternalForces);
+      coupledTimeStepper.step(relativeTime, relativeTau);
 
       Vector u, ur, vr;
       ScalarVector alpha;