diff --git a/src/sand-wedge.cc b/src/sand-wedge.cc
index b3dc3dd820c0d2cac3c7cc304b95757998cd0c3b..ee8807ee7741cb02a1091ead2a652dafdf647cac 100644
--- a/src/sand-wedge.cc
+++ b/src/sand-wedge.cc
@@ -95,6 +95,91 @@ void initPython() {
   Python::run("sys.path.append('" datadir "')");
 }
 
+template <class Factory, class StateUpdater, class VelocityUpdater>
+class FixedPointIterator {
+  using ScalarVector = typename StateUpdater::ScalarVector;
+  using Vector = typename Factory::Vector;
+  using Matrix = typename Factory::Matrix;
+  using ConvexProblem = typename Factory::ConvexProblem;
+  using BlockProblem = typename Factory::BlockProblem;
+  using Nonlinearity = typename ConvexProblem::NonlinearityType;
+
+public:
+  FixedPointIterator(Factory &factory, Dune::ParameterTree const &parset,
+                     std::shared_ptr<Nonlinearity> globalFriction)
+      : factory_(factory),
+        parset_(parset),
+        globalFriction_(globalFriction),
+        fixedPointMaxIterations_(parset.get<size_t>("v.fpi.maximumIterations")),
+        fixedPointTolerance_(parset.get<double>("v.fpi.tolerance")),
+        lambda_(parset.get<double>("v.fpi.lambda")),
+        velocityMaxIterations_(
+            parset.get<size_t>("v.solver.maximumIterations")),
+        velocityTolerance_(parset.get<double>("v.solver.tolerance")),
+        verbosity_(parset.get<Solver::VerbosityMode>("v.solver.verbosity")) {}
+
+  void run(std::shared_ptr<StateUpdater> stateUpdater,
+           std::shared_ptr<VelocityUpdater> velocityUpdater,
+           Matrix const &velocityMatrix, Norm<Vector> const &velocityMatrixNorm,
+           Vector const &velocityRHS, Vector &velocityIterate) {
+    auto multigridStep = factory_.getSolver();
+
+    LoopSolver<Vector> velocityProblemSolver(
+        multigridStep, velocityMaxIterations_, velocityTolerance_,
+        &velocityMatrixNorm, verbosity_, false); // absolute error
+
+    Vector previousVelocityIterate;
+
+    for (size_t fixedPointIteration = 1;
+         fixedPointIteration <= fixedPointMaxIterations_;
+         ++fixedPointIteration) {
+      Vector v_m;
+      velocityUpdater->extractOldVelocity(v_m);
+      v_m *= 1.0 - lambda_;
+      Arithmetic::addProduct(v_m, lambda_, velocityIterate);
+
+      // solve a state problem
+      stateUpdater->solve(v_m);
+      ScalarVector alpha;
+      stateUpdater->extractAlpha(alpha);
+
+      // solve a velocity problem
+      globalFriction_->updateAlpha(alpha);
+      ConvexProblem convexProblem(1.0, velocityMatrix, *globalFriction_,
+                                  velocityRHS, velocityIterate);
+      BlockProblem velocityProblem(parset_, convexProblem);
+      multigridStep->setProblem(velocityIterate, velocityProblem);
+      velocityProblemSolver.preprocess();
+      velocityProblemSolver.solve();
+
+      if (fixedPointIteration > 1) {
+        auto const velocityCorrection =
+            velocityMatrixNorm.diff(previousVelocityIterate, velocityIterate);
+        if (velocityCorrection < fixedPointTolerance_)
+          break;
+      }
+      if (fixedPointIteration == fixedPointMaxIterations_)
+        DUNE_THROW(Dune::Exception, "FPI failed to converge");
+
+      previousVelocityIterate = velocityIterate;
+    }
+    velocityUpdater->postProcess(velocityIterate);
+    velocityUpdater->postProcessRelativeQuantities();
+  }
+
+private:
+  Factory &factory_;
+  Dune::ParameterTree const &parset_;
+  std::shared_ptr<Nonlinearity> globalFriction_;
+
+  size_t fixedPointMaxIterations_;
+  double fixedPointTolerance_;
+  double lambda_;
+  size_t velocityMaxIterations_;
+  double velocityTolerance_;
+  Solver::VerbosityMode verbosity_;
+};
+
 int main(int argc, char *argv[]) {
   try {
     Dune::ParameterTree parset;
@@ -338,7 +423,6 @@ int main(int argc, char *argv[]) {
         Grid>;
     NonlinearFactory factory(parset.sub("solver.tnnmg"), refinements, *grid,
                              dirichletNodes);
-    auto multigridStep = factory.getSolver();
 
     auto velocityUpdater = initTimeStepper(
         parset.get<Config::scheme>("timeSteps.scheme"),
@@ -350,19 +434,8 @@ int main(int argc, char *argv[]) {
         parset.get<double>("boundary.friction.L"),
         parset.get<double>("boundary.friction.V0"));
 
-    Vector v_m(leafVertexCount);
-    ScalarVector alpha(leafVertexCount);
-
-    auto const timeSteps = parset.get<size_t>("timeSteps.number"),
-               maximumStateFPI = parset.get<size_t>("v.fpi.maximumIterations"),
-               maximumIterations =
-                   parset.get<size_t>("v.solver.maximumIterations");
-    auto const tau = parset.get<double>("problem.finalTime") / timeSteps,
-               tolerance = parset.get<double>("v.solver.tolerance"),
-               fixedPointTolerance = parset.get<double>("v.fpi.tolerance");
-    auto const verbosity =
-        parset.get<Solver::VerbosityMode>("v.solver.verbosity");
-    auto const lambda = parset.get<double>("v.fpi.lambda");
+    auto const timeSteps = parset.get<size_t>("timeSteps.number");
+    auto const tau = parset.get<double>("problem.finalTime") / timeSteps;
     for (size_t timeStep = 1; timeStep <= timeSteps; ++timeStep) {
       stateUpdater->nextTimeStep();
       velocityUpdater->nextTimeStep();
@@ -379,60 +452,20 @@ int main(int argc, char *argv[]) {
       velocityUpdater->setup(ell, tau, relativeTime, velocityRHS,
                              velocityIterate, velocityMatrix);
 
-      LoopSolver<Vector> velocityProblemSolver(multigridStep, maximumIterations,
-                                               tolerance, &AMNorm, verbosity,
-                                               false); // absolute error
-
-      size_t iterationCounter;
-      auto solveVelocityProblem = [&](Vector &_velocityIterate,
-                                      ScalarVector const &_alpha) {
-        myGlobalFriction->updateAlpha(_alpha);
-
-        // NIT: Do we really need to pass u here?
-        typename NonlinearFactory::ConvexProblem convexProblem(
-            1.0, velocityMatrix, *myGlobalFriction, velocityRHS,
-            _velocityIterate);
-        typename NonlinearFactory::BlockProblem velocityProblem(parset,
-                                                                convexProblem);
-        multigridStep->setProblem(_velocityIterate, velocityProblem);
-
-        velocityProblemSolver.preprocess();
-        velocityProblemSolver.solve();
-        iterationCounter = velocityProblemSolver.getResult().iterations;
-      };
-
-      Vector v_saved;
-      for (size_t stateFPI = 1; stateFPI <= maximumStateFPI; ++stateFPI) {
-        velocityUpdater->extractOldVelocity(v_m);
-        v_m *= 1.0 - lambda;
-        Arithmetic::addProduct(v_m, lambda, velocityIterate);
-
-        stateUpdater->solve(v_m);
-        stateUpdater->extractAlpha(alpha);
-
-        solveVelocityProblem(velocityIterate, alpha);
-
-        if (stateFPI > 1) {
-          double const velocityCorrection =
-              AMNorm.diff(v_saved, velocityIterate);
-          if (velocityCorrection < fixedPointTolerance)
-            break;
-        }
-        if (stateFPI == maximumStateFPI)
-          DUNE_THROW(Dune::Exception, "FPI failed to converge");
-
-        v_saved = velocityIterate;
-      }
-      velocityUpdater->postProcess(velocityIterate);
+      FixedPointIterator<NonlinearFactory, StateUpdater<ScalarVector, Vector>,
+                         TimeSteppingScheme<Vector, Matrix, Function, dims>>
+      fixedPointIterator(factory, parset, myGlobalFriction);
+      fixedPointIterator.run(stateUpdater, velocityUpdater, velocityMatrix,
+                             AMNorm, velocityRHS, velocityIterate);
 
       Vector u, ur, vr;
+      ScalarVector alpha;
       velocityUpdater->extractDisplacement(u);
-      velocityUpdater->postProcessRelativeQuantities();
       velocityUpdater->extractRelativeDisplacement(ur);
       velocityUpdater->extractRelativeVelocity(vr);
+      stateUpdater->extractAlpha(alpha);
 
       report(ur, vr, alpha);
-
       {
         BasisGridFunction<typename MyAssembler::VertexBasis, Vector>
         relativeVelocity(myAssembler.vertexBasis, vr);
diff --git a/src/state/stateupdater.hh b/src/state/stateupdater.hh
index 0ebf4774c11bf7825954ec4006a0b870645e7241..697662d32e8bf877441fed9085c7ccaa0bba087c 100644
--- a/src/state/stateupdater.hh
+++ b/src/state/stateupdater.hh
@@ -1,8 +1,10 @@
 #ifndef STATE_UPDATER_HH
 #define STATE_UPDATER_HH
 
-template <class ScalarVector, class Vector> class StateUpdater {
+template <class ScalarVectorTEMPLATE, class Vector> class StateUpdater {
 public:
+  using ScalarVector = ScalarVectorTEMPLATE;
+
   void virtual nextTimeStep() = 0;
   void virtual setup(double _tau) = 0;
   void virtual solve(Vector const &velocity_field) = 0;