diff --git a/src/adaptivetimestepper.hh b/src/adaptivetimestepper.hh
new file mode 100644
index 0000000000000000000000000000000000000000..0d191337af74ef9ea619c6cffc0c0d2f5d85d02f
--- /dev/null
+++ b/src/adaptivetimestepper.hh
@@ -0,0 +1,142 @@
+#include "coupledtimestepper.hh"
+
+template <typename T1, typename T2>
+std::pair<T1, T2> clonePair(std::pair<T1, T2> in) {
+  return { in.first->clone(), in.second->clone() };
+}
+
+template <class Factory, class UpdaterPair> class AdaptiveTimeStepper {
+  using StateUpdater = typename UpdaterPair::first_type::element_type;
+  using VelocityUpdater = typename UpdaterPair::second_type::element_type;
+  using Vector = typename Factory::Vector;
+  using ConvexProblem = typename Factory::ConvexProblem;
+  using Nonlinearity = typename ConvexProblem::NonlinearityType;
+
+  using MyCoupledTimeStepper =
+      CoupledTimeStepper<Factory, StateUpdater, VelocityUpdater>;
+
+public:
+  AdaptiveTimeStepper(
+      Factory &factory, Dune::ParameterTree const &parset,
+      std::shared_ptr<Nonlinearity> globalFriction, UpdaterPair &current,
+      std::function<void(double, Vector &)> externalForces,
+      std::function<bool(UpdaterPair &, UpdaterPair &)> mustRefine)
+      : finalTime_(parset.get<double>("problem.finalTime")),
+        relativeTime_(0.0),
+        relativeTau_(1e-6), // FIXME (not really important, though)
+        factory_(factory),
+        parset_(parset),
+        globalFriction_(globalFriction),
+        current_(current),
+        R1_(clonePair(current_)),
+        externalForces_(externalForces),
+        mustRefine_(mustRefine),
+        iterationWriter_("iterations", std::fstream::out) {
+    MyCoupledTimeStepper coupledTimeStepper(finalTime_, factory_, parset_,
+                                            globalFriction_, R1_.first,
+                                            R1_.second, externalForces_);
+    stepAndReport("R1", coupledTimeStepper, relativeTime_, relativeTau_);
+    iterationWriter_ << std::endl;
+  }
+
+  // FIXME
+  bool reachedEnd() { return relativeTime_ > 1.0 - 1e-10; }
+
+  bool coarsen() {
+    bool didCoarsen = false;
+
+    // FIXME: for a constant function, e.g., we will not only overstep but
+    // diverge
+    while (true) {
+      R2_ = clonePair(R1_);
+      {
+        MyCoupledTimeStepper coupledTimeStepper(finalTime_, factory_, parset_,
+                                                globalFriction_, R2_.first,
+                                                R2_.second, externalForces_);
+        stepAndReport("R2", coupledTimeStepper, relativeTime_ + relativeTau_,
+                      relativeTau_);
+      }
+
+      UpdaterPair C = clonePair(current_);
+      {
+        MyCoupledTimeStepper coupledTimeStepper(finalTime_, factory_, parset_,
+                                                globalFriction_, C.first,
+                                                C.second, externalForces_);
+
+        stepAndReport("C", coupledTimeStepper, relativeTime_,
+                      2.0 * relativeTau_);
+      }
+
+      if (!mustRefine_(C, R2_)) {
+        R2_ = { nullptr, nullptr };
+        R1_ = C;
+        relativeTau_ *= 2.0;
+        didCoarsen = true;
+      } else {
+        break;
+      }
+    }
+    return didCoarsen;
+  }
+
+  void refine() {
+    while (true) {
+      UpdaterPair F2 = clonePair(current_);
+      UpdaterPair F1;
+      {
+        MyCoupledTimeStepper coupledTimeStepper(finalTime_, factory_, parset_,
+                                                globalFriction_, F2.first,
+                                                F2.second, externalForces_);
+        stepAndReport("F1", coupledTimeStepper, relativeTime_,
+                      relativeTau_ / 2.0);
+
+        F1 = clonePair(F2);
+        stepAndReport("F2", coupledTimeStepper,
+                      relativeTime_ + relativeTau_ / 2.0, relativeTau_ / 2.0);
+      }
+
+      if (!mustRefine_(R1_, F2)) {
+        break;
+      } else {
+        R1_ = F1;
+        R2_ = F2;
+        relativeTau_ /= 2.0;
+      }
+    }
+  }
+
+  void advance() {
+    auto const didCoarsen = coarsen();
+    if (!didCoarsen)
+      refine();
+
+    iterationWriter_ << std::endl;
+
+    current_ = R1_;
+    R1_ = R2_;
+    relativeTime_ += relativeTau_;
+  }
+
+  double getRelativeTime() { return relativeTime_; }
+  double getRelativeTau() { return relativeTau_; }
+
+private:
+  void stepAndReport(std::string type, MyCoupledTimeStepper &stepper,
+                     double rTime, double rTau) {
+    iterationWriter_ << type << " " << stepper.step(rTime, rTau) << " "
+                     << std::flush;
+  }
+
+  double finalTime_;
+  double relativeTime_;
+  double relativeTau_;
+  Factory &factory_;
+  Dune::ParameterTree const &parset_;
+  std::shared_ptr<Nonlinearity> globalFriction_;
+  UpdaterPair &current_;
+  UpdaterPair R1_;
+  UpdaterPair R2_;
+  std::function<void(double, Vector &)> externalForces_;
+  std::function<bool(UpdaterPair &, UpdaterPair &)> mustRefine_;
+  std::fstream iterationWriter_;
+};
diff --git a/src/sand-wedge.cc b/src/sand-wedge.cc
index efd781564e8c62db700c1f148b036083619e29d2..7b9ecf0f2b537fa8496a0fdc57d158c236db022f 100644
--- a/src/sand-wedge.cc
+++ b/src/sand-wedge.cc
@@ -66,6 +66,7 @@
 #include <dune/tectonic/myblockproblem.hh>
 #include <dune/tectonic/globalfriction.hh>
 
+#include "adaptivetimestepper.hh" // FIXME
 #include "assemblers.hh"
 #include "tobool.hh"
 #include "coupledtimestepper.hh"
@@ -85,11 +86,6 @@
 
 size_t const dims = DIM;
 
-template <typename T1, typename T2>
-std::pair<T1, T2> clonePair(std::pair<T1, T2> in) {
-  return { in.first->clone(), in.second->clone() };
-}
-
 void initPython() {
   Python::start();
 
@@ -360,18 +356,10 @@ int main(int argc, char *argv[]) {
                         u_initial, ur_initial, v_initial, vr_initial,
                         a_initial));
 
-    auto const finalTime = parset.get<double>("problem.finalTime");
-    double relativeTime = 0.0;
-
-    double relativeTau = 1e-6; // FIXME (not really important, though)
-
-    using MyStateUpdater = StateUpdater<ScalarVector, Vector>;
-    using MyTimeStepper = TimeSteppingScheme<Vector, Matrix, Function, dims>;
-
     auto const refinementTolerance =
         parset.get<double>("timeSteps.refinementTolerance");
-    auto const mustRefine = [&](UpdaterPair coarseUpdater,
-                                UpdaterPair fineUpdater) {
+    auto const mustRefine = [&](UpdaterPair &coarseUpdater,
+                                UpdaterPair &fineUpdater) {
       ScalarVector coarseAlpha;
       coarseUpdater.first->extractAlpha(coarseAlpha);
 
@@ -383,97 +371,13 @@ int main(int argc, char *argv[]) {
 
     size_t timeStep = 1;
 
-    std::fstream iterationWriter("iterations", std::fstream::out);
-
-    UpdaterPair R1 = clonePair(current);
-    {
-      CoupledTimeStepper<NonlinearFactory, MyStateUpdater, MyTimeStepper>
-      coupledTimeStepper(finalTime, factory, parset, myGlobalFriction, R1.first,
-                         R1.second, computeExternalForces);
-      iterationWriter << "R1 ";
-      auto const iterations =
-          coupledTimeStepper.step(relativeTime, relativeTau);
-      iterationWriter << iterations << std::endl;
-    }
-
-    UpdaterPair R2;
-
-    while (relativeTime < 1.0 - 1e-10) {
-      bool didCoarsen = false;
-
-      while (true) {
-        R2 = clonePair(R1);
-
-        {
-          CoupledTimeStepper<NonlinearFactory, MyStateUpdater, MyTimeStepper>
-          coupledTimeStepper(finalTime, factory, parset, myGlobalFriction,
-                             R2.first, R2.second, computeExternalForces);
-          iterationWriter << "R2 ";
-          auto const iterations =
-              coupledTimeStepper.step(relativeTime + relativeTau, relativeTau);
-          iterationWriter << iterations << " " << std::flush;
-        }
-
-        UpdaterPair C = clonePair(current);
-
-        {
-          CoupledTimeStepper<NonlinearFactory, MyStateUpdater, MyTimeStepper>
-          coupledTimeStepper(finalTime, factory, parset, myGlobalFriction,
-                             C.first, C.second, computeExternalForces);
-          iterationWriter << "C ";
-          auto const iterations =
-              coupledTimeStepper.step(relativeTime, 2.0 * relativeTau);
-          iterationWriter << iterations << " " << std::flush;
-        }
-
-        if (!mustRefine(C, R2)) {
-          R2 = { nullptr, nullptr };
-          R1 = C;
-          relativeTau *= 2.0;
-          didCoarsen = true;
-        } else {
-          break;
-        }
-      }
-
-      if (!didCoarsen) {
-        while (true) {
-          UpdaterPair F2 = clonePair(current);
-          UpdaterPair F1;
-
-          {
-            CoupledTimeStepper<NonlinearFactory, MyStateUpdater, MyTimeStepper>
-            coupledTimeStepper(finalTime, factory, parset, myGlobalFriction,
-                               F2.first, F2.second, computeExternalForces);
-            iterationWriter << "F1 ";
-            auto const iterationsF1 =
-                coupledTimeStepper.step(relativeTime, relativeTau / 2.0);
-            iterationWriter << iterationsF1 << " " << std::flush;
-
-            F1 = clonePair(F2);
-
-            iterationWriter << "F2 ";
-            auto const iterationsF2 = coupledTimeStepper.step(
-                relativeTime + relativeTau / 2.0, relativeTau / 2.0);
-            iterationWriter << iterationsF2 << " " << std::flush;
-          }
-
-          if (!mustRefine(R1, F2)) {
-            break;
-          } else {
-            R1 = F1;
-            R2 = F2;
-            relativeTau /= 2.0;
-          }
-        }
-      }
-      iterationWriter << std::endl;
-
-      reportTimeStep(relativeTime, relativeTau);
-
-      current = R1;
-      R1 = R2;
-      relativeTime += relativeTau;
+    AdaptiveTimeStepper<NonlinearFactory, UpdaterPair> adaptiveTimeStepper(
+        factory, parset, myGlobalFriction, current, computeExternalForces,
+        mustRefine);
+    while (!adaptiveTimeStepper.reachedEnd()) {
+      adaptiveTimeStepper.advance();
+      reportTimeStep(adaptiveTimeStepper.getRelativeTime(),
+                     adaptiveTimeStepper.getRelativeTau());
 
       Vector u, ur, vr;
       ScalarVector alpha;
@@ -501,7 +405,6 @@ int main(int argc, char *argv[]) {
       timeStep++;
     }
     timeStepWriter.close();
-    iterationWriter.close();
     Python::stop();
   }
   catch (Dune::Exception &e) {