From efb0355ea579dd2a675ff3e00a5ed4d283781f77 Mon Sep 17 00:00:00 2001
From: podlesny <podlesny@zedat.fu-berlin.de>
Date: Sun, 17 Jan 2021 23:53:11 +0100
Subject: [PATCH] multi threading steps in adaptivetimestepper

---
 dune/tectonic/time-stepping/CMakeLists.txt    |   4 +
 .../time-stepping/adaptivetimestepper.cc      | 475 ++++++++----------
 .../time-stepping/adaptivetimestepper.hh      |  52 +-
 .../time-stepping/adaptivetimestepper_tmpl.cc |   8 +-
 dune/tectonic/time-stepping/step.hh           | 273 ++++++++++
 dune/tectonic/time-stepping/stepbase.hh       |  49 ++
 src/multi-body-problem/multi-body-problem.cc  |   9 +-
 src/strikeslip/strikeslip-2D.cfg              |   4 +-
 src/strikeslip/strikeslip.cc                  | 125 +----
 src/strikeslip/strikeslip.cfg                 |   8 +-
 10 files changed, 597 insertions(+), 410 deletions(-)
 create mode 100644 dune/tectonic/time-stepping/step.hh
 create mode 100644 dune/tectonic/time-stepping/stepbase.hh

diff --git a/dune/tectonic/time-stepping/CMakeLists.txt b/dune/tectonic/time-stepping/CMakeLists.txt
index f316e155..ce276d15 100644
--- a/dune/tectonic/time-stepping/CMakeLists.txt
+++ b/dune/tectonic/time-stepping/CMakeLists.txt
@@ -10,6 +10,8 @@ add_custom_target(tectonic_dune_time-stepping SOURCES
   rate.cc
   state.hh
   state.cc
+  step.hh
+  stepbase.hh
   updaters.hh
 )
 
@@ -19,5 +21,7 @@ install(FILES
   coupledtimestepper.hh
   rate.hh
   state.hh
+  step.hh
+  stepbase.hh
   updaters.hh
 DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/dune/tectonic)
diff --git a/dune/tectonic/time-stepping/adaptivetimestepper.cc b/dune/tectonic/time-stepping/adaptivetimestepper.cc
index f11a7673..52997a82 100644
--- a/dune/tectonic/time-stepping/adaptivetimestepper.cc
+++ b/dune/tectonic/time-stepping/adaptivetimestepper.cc
@@ -2,7 +2,9 @@
 #include "config.h"
 #endif
 
+#include <future>
 #include <thread>
+#include <chrono>
 
 #include <dune/solvers/norms/energynorm.hh>
 #include <dune/solvers/iterationsteps/multigridstep.hh>
@@ -14,80 +16,10 @@
 #include <dune/tectonic/utils/reductionfactors.hh>
 
 #include "adaptivetimestepper.hh"
+#include "step.hh"
 
 const unsigned int N_THREADS = std::thread::hardware_concurrency();
 
-template<class IterationStepType, class NormType, class ReductionFactorContainer>
-Dune::Solvers::Criterion reductionFactorCriterion(
-      IterationStepType& iterationStep,
-      const NormType& norm,
-      ReductionFactorContainer& reductionFactors)
-{
-  double normOfOldCorrection = 1;
-  auto lastIterate = std::make_shared<typename IterationStepType::Vector>(*iterationStep.getIterate());
-
-  return Dune::Solvers::Criterion(
-      [&, lastIterate, normOfOldCorrection] () mutable {
-        double normOfCorrection = norm.diff(*lastIterate, *iterationStep.getIterate());
-        double convRate = (normOfOldCorrection > 0) ? normOfCorrection / normOfOldCorrection : 0.0;
-
-        if (convRate>1.0)
-            std::cout << "Solver convergence rate of " << convRate << std::endl;
-
-        normOfOldCorrection = normOfCorrection;
-        *lastIterate = *iterationStep.getIterate();
-
-        reductionFactors.push_back(convRate);
-        return std::make_tuple(convRate < 0, Dune::formatString(" % '.5f", convRate));
-      },
-      " reductionFactor");
-}
-
-
-template<class IterationStepType, class Functional, class ReductionFactorContainer>
-Dune::Solvers::Criterion energyCriterion(
-      const IterationStepType& iterationStep,
-      const Functional& f,
-      ReductionFactorContainer& reductionFactors)
-{
-  double normOfOldCorrection = 1;
-  auto lastIterate = std::make_shared<typename IterationStepType::Vector>(*iterationStep.getIterate());
-
-  return Dune::Solvers::Criterion(
-      [&, lastIterate, normOfOldCorrection] () mutable {
-        double normOfCorrection = std::abs(f(*lastIterate) - f(*iterationStep.getIterate())); //norm.diff(*lastIterate, *iterationStep.getIterate());
-
-        double convRate = (normOfOldCorrection != 0.0) ? 1.0 - (normOfCorrection / normOfOldCorrection) : 0.0;
-
-        if (convRate>1.0)
-            std::cout << "Solver convergence rate of " << convRate << std::endl;
-
-        normOfOldCorrection = normOfCorrection;
-        *lastIterate = *iterationStep.getIterate();
-
-        reductionFactors.push_back(convRate);
-        return std::make_tuple(convRate < 0, Dune::formatString(" % '.5f", convRate));
-      },
-      " reductionFactor");
-}
-
-template <class ReductionFactorContainer>
-void updateReductionFactors(ReductionFactorContainer& reductionFactors) {
-    const size_t s = reductionFactors.size();
-
-    //print(reductionFactors, "reduction factors: ");
-
-    if (s>allReductionFactors.size()) {
-        allReductionFactors.resize(s);
-    }
-
-    for (size_t i=0; i<reductionFactors.size(); i++) {
-        allReductionFactors[i].push_back(reductionFactors[i]);
-    }
-
-    reductionFactors.clear();
-}
-
 void IterationRegister::registerCount(FixedPointIterationCounter count) {
   totalCount += count;
 }
@@ -101,239 +33,284 @@ void IterationRegister::reset() {
   finalCount = FixedPointIterationCounter();
 }
 
+/*
+ * Implementation: AdaptiveTimeStepper
+ */
 template <class Factory, class ContactNetwork, class Updaters, class ErrorNorms>
 AdaptiveTimeStepper<Factory, ContactNetwork, Updaters, ErrorNorms>::AdaptiveTimeStepper(
-        Dune::ParameterTree const &parset,
+        const StepBase& stepBase,
         ContactNetwork& contactNetwork,
-        const IgnoreVector& ignoreNodes,
-        GlobalFriction& globalFriction,
-        const std::vector<const BitVector*>& bodywiseNonmortarBoundaries,
         Updaters &current,
         double relativeTime,
         double relativeTau,
-        ExternalForces& externalForces,
-        const ErrorNorms& errorNorms,
         std::function<bool(Updaters &, Updaters &)> mustRefine)
     : relativeTime_(relativeTime),
       relativeTau_(relativeTau),
-      finalTime_(parset.get<double>("problem.finalTime")),
-      parset_(parset),
+      stepBase_(stepBase),
       contactNetwork_(contactNetwork),
-      ignoreNodes_(ignoreNodes),
-      globalFriction_(globalFriction),
-      bodywiseNonmortarBoundaries_(bodywiseNonmortarBoundaries),
       current_(current),
       R1_(),
-      externalForces_(externalForces),
-      mustRefine_(mustRefine),
-      errorNorms_(errorNorms) {}
+      mustRefine_(mustRefine) {}
 
 template <class Factory, class ContactNetwork, class Updaters, class ErrorNorms>
 bool AdaptiveTimeStepper<Factory, ContactNetwork, Updaters, ErrorNorms>::reachedEnd() {
   return relativeTime_ >= 1.0;
 }
 
+// compute C and R2 in parallel
+// returns number of coarsenings done
 template <class Factory, class ContactNetwork, class Updaters, class ErrorNorms>
-auto AdaptiveTimeStepper<Factory, ContactNetwork, Updaters, ErrorNorms>::makeLinearSolver() const {
-    // make linear solver for linear correction in TNNMGStep
-    using Norm =  EnergyNorm<Matrix, Vector>;
-    using Preconditioner = MultilevelPatchPreconditioner<ContactNetwork, Matrix, Vector>;
-    using LinearSolver = typename Dune::Solvers::LoopSolver<Vector>;
+int AdaptiveTimeStepper<Factory, ContactNetwork, Updaters, ErrorNorms>::coarsen() {
+    using Step = Step<Factory, ContactNetwork, Updaters, ErrorNorms>;
 
-    const auto& preconditionerParset = parset_.sub("solver.tnnmg.linear.preconditioner");
+    int coarseCount = 1; // one coarsening step was already done in determineStrategy()
 
-    Dune::BitSetVector<1> activeLevels(contactNetwork_.nLevels(), true);
-    Preconditioner preconditioner(preconditionerParset, contactNetwork_, activeLevels);
-    preconditioner.setPatchDepth(preconditionerParset.template get<size_t>("patchDepth"));
-    preconditioner.build();
+    UpdatersWithCount C;
+    UpdatersWithCount R2;
 
-    auto cgStep = std::make_shared<Dune::Solvers::CGStep<Matrix, Vector>>();
-    cgStep->setPreconditioner(preconditioner);
+    const auto& currentNBodyAssembler = contactNetwork_.nBodyAssembler();
 
-    Norm norm(*cgStep);
+    while (relativeTime_ + relativeTau_ <= 1.0) {
+      std::cout << "tau: " << relativeTau_ << std::endl;
 
-    return std::make_shared<LinearSolver>(cgStep, parset_.template get<int>("solver.tnnmg.main.multi"), parset_.template get<double>("solver.tnnmg.linear.tolerance"), norm, Solver::QUIET);
-}
+      setDeformation(current_);
+      auto C_Step = Step(stepBase_, current_, currentNBodyAssembler, relativeTime_, 2 * relativeTau_, iterationRegister_);
+      C_Step.run(Step::Mode::newThread); //newThread
 
-template <class Factory, class ContactNetwork, class Updaters, class ErrorNorms>
-IterationRegister AdaptiveTimeStepper<Factory, ContactNetwork, Updaters, ErrorNorms>::advance() {
-  /*
-    |     C     | We check here if making the step R1 of size tau is a
-    |  R1 | R2  | good idea. To check if we can coarsen, we compare
-    |F1|F2|     | the result of (R1+R2) with C, i.e. two steps of size
-                  tau with one of size 2*tau. To check if we need to
-    refine, we compare the result of (F1+F2) with R1, i.e. two steps
-    of size tau/2 with one of size tau. The method makes multiple
-    coarsening/refining attempts, with coarsening coming first. */
+      //updateReductionFactors(reductionFactors);
+      std::cout << "AdaptiveTimeStepper C computed!" << std::endl << std::endl;
 
-  //std::cout << "AdaptiveTimeStepper::advance()" << std::endl;
+      /*using ScalarVector = typename Updaters::StateUpdater::ScalarVector;
+      std::vector<ScalarVector> cAlpha(contactNetwork_.nBodies());
+      C.updaters.state_->extractAlpha(cAlpha);
+      print(cAlpha, "cAlpha: ");*/
 
-  // patch preconditioner only needs to be computed once per advance()
-  // make linear solver for linear correction in TNNMGStep
-  using Norm =  EnergyNorm<Matrix, Vector>;
-  using Preconditioner = MultilevelPatchPreconditioner<ContactNetwork, Matrix, Vector>;
-  using LinearSolver = typename Dune::Solvers::LoopSolver<Vector>;
+      setDeformation(R1_.updaters);
+      //auto R2_linearSolver = makeLinearSolver();
+      auto&& nBodyAssembler = step(currentNBodyAssembler);
+      auto R2_Step = Step(stepBase_, R1_.updaters, nBodyAssembler, relativeTime_ + relativeTau_, relativeTau_, iterationRegister_);
+      R2_Step.run(Step::Mode::newThread); //newThread
 
-  /*const auto& preconditionerParset = parset_.sub("solver.tnnmg.preconditioner");
+      //updateReductionFactors(reductionFactors);
+      std::cout << "AdaptiveTimeStepper R2 computed!" << std::endl << std::endl;
 
-  Dune::BitSetVector<1> activeLevels(contactNetwork_.nLevels(), true);
-  Preconditioner preconditioner(preconditionerParset, contactNetwork_, activeLevels);
-  preconditioner.setPatchDepth(preconditionerParset.template get<size_t>("patchDepth"));
-  preconditioner.build();
+      C = C_Step.get();
+      R2 = R2_Step.get();
 
-  auto cgStep = std::make_shared<Dune::Solvers::CGStep<Matrix, Vector>>();
-  cgStep->setPreconditioner(preconditioner);
+      /*std::vector<ScalarVector> rAlpha(contactNetwork_.nBodies());
+      R2.updaters.state_->extractAlpha(rAlpha);
+      print(rAlpha, "rAlpha: ");*/
 
-  Norm norm(*cgStep);
+      if (mustRefine_(C.updaters, R2.updaters))
+        break;
 
-  auto linearSolver = std::make_shared<LinearSolver>(cgStep, parset_.template get<int>("solver.tnnmg.main.multi"), parset_.template get<double>("solver.tnnmg.preconditioner.basesolver.tolerance"), norm, Solver::QUIET);
-  */
+      relativeTau_ *= 2;
+      R1_ = C;
 
-  // set multigrid solver
-  auto smoother = TruncatedBlockGSStep<Matrix, Vector>();
+      coarseCount++;
+    }
 
-  using TransferOperator = NBodyContactTransfer<ContactNetwork, Vector>;
-  using TransferOperators = std::vector<std::shared_ptr<TransferOperator>>;
+    current_ = R1_.updaters;
+    R1_ = R2;
 
-  TransferOperators transfer(contactNetwork_.nLevels()-1);
-  for (size_t i=0; i<transfer.size(); i++) {
-      transfer[i] = std::make_shared<TransferOperator>();
-      transfer[i]->setup(contactNetwork_, i, i+1);
-  }
+    return coarseCount;
+}
 
-  // Remove any recompute filed so that initially the full transferoperator is assembled
-  for (size_t i=0; i<transfer.size(); i++)
-      std::dynamic_pointer_cast<TruncatedMGTransfer<Vector> >(transfer[i])->setRecomputeBitField(nullptr);
+// compute F1 and F2 sequentially
+// returns number of refinements done
+template <class Factory, class ContactNetwork, class Updaters, class ErrorNorms>
+int AdaptiveTimeStepper<Factory, ContactNetwork, Updaters, ErrorNorms>::refine(UpdatersWithCount& R2) {
+    using Step = Step<Factory, ContactNetwork, Updaters, ErrorNorms>;
 
-  auto linearMultigridStep = std::make_shared<Dune::Solvers::MultigridStep<Matrix, Vector> >();
-  linearMultigridStep->setMGType(1, 3, 3);
-  linearMultigridStep->setSmoother(smoother);
-  linearMultigridStep->setTransferOperators(transfer);
+    int refineCount = 1; // one refinement step was already done in determineStrategy()
 
-  Norm norm(*linearMultigridStep);
+    UpdatersWithCount F1;
+    UpdatersWithCount F2;
 
-  auto linearSolver = std::make_shared<LinearSolver>(linearMultigridStep, parset_.template get<int>("solver.tnnmg.main.multi"), parset_.template get<double>("solver.tnnmg.preconditioner.basesolver.tolerance"), norm, Solver::QUIET);
+    const auto& currentNBodyAssembler = contactNetwork_.nBodyAssembler();
 
-  Vector x;
-  x.resize(contactNetwork_.nBodyAssembler().totalHasObstacle_.size());
-  x = 0;
-  dynamic_cast<Dune::Solvers::IterationStep<Vector>*>(linearMultigridStep.get())->setProblem(x);
-  //dynamic_cast<Dune::Solvers::IterationStep<Vector>*>(cgStep.get())->setProblem(x);
+    while (true) {
+        setDeformation(current_);
+        //auto F1_linearSolver = makeLinearSolver();
+        auto F1_Step = Step(stepBase_, current_, currentNBodyAssembler, relativeTime_, relativeTau_ / 2.0, iterationRegister_);
+        F1_Step.run(Step::Mode::sameThread);
+        F1 = F1_Step.get();
+
+        //updateReductionFactors(reductionFactors);
+        std::cout << "AdaptiveTimeStepper F1 computed!" << std::endl << std::endl;
+
+        setDeformation(F1.updaters);
+        //auto F2_linearSolver = makeLinearSolver();
+        auto&& nBodyAssembler = step(currentNBodyAssembler);
+        auto F2_Step = Step(stepBase_, F1.updaters, nBodyAssembler, relativeTime_ + relativeTau_ / 2.0,
+                  relativeTau_ / 2.0, iterationRegister_);
+        F2_Step.run(Step::Mode::sameThread);
+        F2 = F2_Step.get();
+        //updateReductionFactors(reductionFactors);
+
+        if (!mustRefine_(R1_.updaters, F2.updaters)) {
+          std::cout << "Sufficiently refined!" << std::endl;
+          break;
+        }
+
+        relativeTau_ /= 2.0;
+        R1_ = F1;
+        R2 = F2;
+
+        refineCount++;
+    }
 
-  const auto& currentNBodyAssembler = contactNetwork_.nBodyAssembler();
+    current_ = R1_.updaters;
+    R1_ = R2;
 
-   std::vector<double> reductionFactors;
-   linearSolver->addCriterion(reductionFactorCriterion(*linearMultigridStep, norm, reductionFactors));
-   //linearSolver->addCriterion(reductionFactorCriterion(*cgStep, norm, reductionFactors));
+    return refineCount;
+}
 
-  auto refine = [&](UpdatersWithCount& F1, UpdatersWithCount& F2){
-      setDeformation(current_);
-      F1 = step(current_, currentNBodyAssembler, linearSolver, relativeTime_, relativeTau_ / 2.0);
-      updateReductionFactors(reductionFactors);
-      std::cout << "AdaptiveTimeStepper F1 computed!" << std::endl << std::endl;
 
-      setDeformation(F1.updaters);
-      auto&& nBodyAssembler = step(currentNBodyAssembler);
-      F2 = step(F1.updaters, nBodyAssembler, linearSolver, relativeTime_ + relativeTau_ / 2.0,
-                relativeTau_ / 2.0);
-      updateReductionFactors(reductionFactors);
-  };
+/*
+ * determines how to adapt time step, returns
+ * -1: coarsen
+ *  0: keep
+ *  1: refine
+ */
+template <class Factory, class ContactNetwork, class Updaters, class ErrorNorms>
+int AdaptiveTimeStepper<Factory, ContactNetwork, Updaters, ErrorNorms>::determineStrategy(UpdatersWithCount& R2) {
+    using Step = Step<Factory, ContactNetwork, Updaters, ErrorNorms>;
 
-  auto coarsen = [&](UpdatersWithCount& C, UpdatersWithCount& R2){
-      setDeformation(current_);
-      C = step(current_, currentNBodyAssembler, linearSolver, relativeTime_, 2 * relativeTau_).get();
-      updateReductionFactors(reductionFactors);
-      std::cout << "AdaptiveTimeStepper C computed!" << std::endl << std::endl;
+    int strategy = 0;
 
-      /*using ScalarVector = typename Updaters::StateUpdater::ScalarVector;
-      std::vector<ScalarVector> cAlpha(contactNetwork_.nBodies());
-      C.updaters.state_->extractAlpha(cAlpha);
-      print(cAlpha, "cAlpha: ");*/
+    const auto& currentNBodyAssembler = contactNetwork_.nBodyAssembler();
 
-      setDeformation(R1_.updaters);
-      auto&& nBodyAssembler = step(currentNBodyAssembler);
-      R2 = step(R1_.updaters, nBodyAssembler, linearSolver, relativeTime_ + relativeTau_, relativeTau_).get();
-      updateReductionFactors(reductionFactors);
-      std::cout << "AdaptiveTimeStepper R2 computed!" << std::endl << std::endl;
-  };
+    UpdatersWithCount C;
+    UpdatersWithCount F1;
+    UpdatersWithCount F2;
 
-  auto canCoarsen = [&]{
-    std::cout << N_THREADS << " concurrent threads are supported." << std::endl;
+    //if (relativeTime_ + relativeTau_ > 1.0) {
+    //  return false;
+    //}
 
-    if (R1_.updaters == Updaters()) {
-      //setDeformation(current_);
-      R1_ = step(current_, currentNBodyAssembler, linearSolver, relativeTime_, relativeTau_);
-      updateReductionFactors(reductionFactors);
-    }
+    setDeformation(current_);
+    //auto C_linearSolver = makeLinearSolver();
+    auto C_Step = Step(stepBase_, current_, currentNBodyAssembler, relativeTime_, 2 * relativeTau_, iterationRegister_);
+    C_Step.run(Step::Mode::newThread); // newThread
+    //updateReductionFactors(reductionFactors);
+    std::cout << "AdaptiveTimeStepper C computed!" << std::endl << std::endl;
 
-    if (N_THREADS<3) {
+    setDeformation(R1_.updaters);
+    //auto R2_linearSolver = makeLinearSolver();
+    auto&& nBodyAssembler = step(currentNBodyAssembler);
+    auto R2_Step = Step(stepBase_, R1_.updaters, nBodyAssembler, relativeTime_ + relativeTau_, relativeTau_, iterationRegister_);
+    R2_Step.run(Step::Mode::newThread); //newThread
 
-    } else {
+    //updateReductionFactors(reductionFactors);
+    std::cout << "AdaptiveTimeStepper R2 computed!" << std::endl << std::endl;
 
+    if (N_THREADS < 3) {
+      C = C_Step.get();
+      R2 = R2_Step.get();
+
+      if (!mustRefine_(C.updaters, R2.updaters)) {
+          strategy = -1;
+      }
     }
 
-    return ;
-  };
+    if (strategy>-1) {
+        setDeformation(current_);
+        //auto F1_linearSolver = makeLinearSolver();
+        auto F1_Step = Step(stepBase_, current_, currentNBodyAssembler, relativeTime_, relativeTau_ / 2.0, iterationRegister_);
+        F1_Step.run(Step::Mode::newThread); //newThread
+        //updateReductionFactors(reductionFactors);
+        std::cout << "AdaptiveTimeStepper F1 computed!" << std::endl << std::endl;
+
+        if (N_THREADS > 2) {
+          C = C_Step.get();
+          R2 = R2_Step.get();
+
+          if (!mustRefine_(C.updaters, R2.updaters)) {
+              strategy = -1;
+          }
+        }
+
+        F1 = F1_Step.get();
+
+        if (strategy>-1) {
+            setDeformation(F1.updaters);
+            //auto F2_linearSolver = makeLinearSolver();
+            auto&& nBodyAssembler = step(currentNBodyAssembler);
+            auto F2_Step = Step(stepBase_, F1.updaters, nBodyAssembler, relativeTime_ + relativeTau_ / 2.0,
+                                relativeTau_ / 2.0, iterationRegister_);
+            F2_Step.run(Step::Mode::sameThread);
+            F2 = F2_Step.get();
+            //updateReductionFactors(reductionFactors);
+
+            if (mustRefine_(R1_.updaters, F2.updaters)) {
+                strategy = 1;
+            }
+        }
+    }
 
-  //std::cout << "AdaptiveTimeStepper Step 1" << std::endl;
+    switch (strategy) {
+        case -1:
+            relativeTau_ *= 2;
+            R1_ = C;
+            break;
+        case 0:
+            current_ = R1_.updaters;
+            R1_ = R2;
+            break;
+        case 1:
+            relativeTau_ /= 2.0;
+            R1_ = F1;
+            R2 = F2;
+            break;
+    }
 
-  size_t coarseningCount = 0;
-  size_t refineCount = 0;
+    return strategy;
+}
 
-  bool didCoarsen = false;
-  iterationRegister_.reset();
-  UpdatersWithCount R2;
-  UpdatersWithCount C;
+template <class Factory, class ContactNetwork, class Updaters, class ErrorNorms>
+IterationRegister AdaptiveTimeStepper<Factory, ContactNetwork, Updaters, ErrorNorms>::advance() {
+  /*
+    |     C     | We check here if making the step R1 of size tau is a
+    |  R1 | R2  | good idea. To check if we can coarsen, we compare
+    |F1|F2|     | the result of (R1+R2) with C, i.e. two steps of size
+                  tau with one of size 2*tau. To check if we need to
+    refine, we compare the result of (F1+F2) with R1, i.e. two steps
+    of size tau/2 with one of size tau. The method makes multiple
+    coarsening/refining attempts, with coarsening coming first. */
 
-  while (relativeTime_ + relativeTau_ <= 1.0) {
-    std::cout << "tau: " << relativeTau_ << std::endl;
+  //std::cout << "AdaptiveTimeStepper::advance()" << std::endl;
 
-    coarsen(C, R2);
+  using Step = Step<Factory, ContactNetwork, Updaters, ErrorNorms>;
 
-    /*std::vector<ScalarVector> rAlpha(contactNetwork_.nBodies());
-    R2.updaters.state_->extractAlpha(rAlpha);
-    print(rAlpha, "rAlpha: ");*/
+  std::cout << N_THREADS << " concurrent threads are supported." << std::endl;
 
-    if (mustRefine_(C.updaters, R2.updaters))
-      break;
+  if (R1_.updaters == Updaters()) {
+      //setDeformation(current_);
+      //auto R1_linearSolver = makeLinearSolver();
+      auto R1_Step = Step(stepBase_, current_, contactNetwork_.nBodyAssembler(), relativeTime_, relativeTau_, iterationRegister_);
+      R1_Step.run(Step::Mode::sameThread);
+      R1_ = R1_Step.get();
+   }
 
-    didCoarsen = true;
-    relativeTau_ *= 2;
-    R1_ = C;
+  iterationRegister_.reset();
+  UpdatersWithCount R2;
+  int strat = determineStrategy(R2);
 
-    coarseningCount++;
+  // coarsen
+  if (strat<0) {
+      int coarseningCount = coarsen();
+      std::cout << " done with coarseningCount: " << coarseningCount << std::endl;
   }
 
-  UpdatersWithCount F1;
-  UpdatersWithCount F2;
-  if (!didCoarsen) {
-    while (true) {
-      refine(F1, F2);
-
-      if (!mustRefine_(R1_.updaters, F2.updaters)) {
-        std::cout << "Sufficiently refined!" << std::endl;
-        break;
-      }
-
-      relativeTau_ /= 2.0;
-      R1_ = F1;
-      R2 = F2;
-
-      refineCount++;
-    }
+  // refine
+  if (strat>0) {
+      int refineCount = refine(R2);
+      std::cout << " done with refineCount: " << refineCount << std::endl;
   }
 
-  std::cout << "AdaptiveTimeStepper::advance() ...";
-
   iterationRegister_.registerFinalCount(R1_.count);
   relativeTime_ += relativeTau_;
-  current_ = R1_.updaters;
-
-  //UpdatersWithCount emptyR1;
-  //R1_ = emptyR1;
-  R1_ = R2;
-
-  std::cout << " done with coarseningCount: " << coarseningCount << " and refineCount: " << refineCount << std::endl;
 
   return iterationRegister_;
 }
@@ -362,32 +339,4 @@ AdaptiveTimeStepper<Factory, ContactNetwork, Updaters, ErrorNorms>::step(const N
     return nBodyAssembler;
 }
 
-template <class Factory, class ContactNetwork, class Updaters, class ErrorNorms>
-template <class LinearSolver>
-typename AdaptiveTimeStepper<Factory, ContactNetwork, Updaters, ErrorNorms>::UpdatersWithCount
-AdaptiveTimeStepper<Factory, ContactNetwork, Updaters, ErrorNorms>::step(
-    const Updaters& oldUpdaters, const NBodyAssembler& nBodyAssembler, std::shared_ptr<LinearSolver>& linearSolver, double rTime, double rTau) {
-
-  auto doStep = [&]{
-    UpdatersWithCount newUpdatersAndCount = {oldUpdaters.clone(), {}};
-
-    MyCoupledTimeStepper coupledTimeStepper(finalTime_, parset_, nBodyAssembler,
-                                          ignoreNodes_, globalFriction_, bodywiseNonmortarBoundaries_,
-                                          newUpdatersAndCount.updaters, errorNorms_, externalForces_);
-
-    newUpdatersAndCount.count = coupledTimeStepper.step(linearSolver, rTime, rTau);
-    iterationRegister_.registerCount(newUpdatersAndCount.count);
-
-    return newUpdatersAndCount;
-  };
-
-  return doStep();
-  /*using ScalarVector = typename Updaters::StateUpdater::ScalarVector;
-  std::vector<ScalarVector> alpha(contactNetwork_.nBodies());
-  newUpdatersAndCount.updaters.state_->extractAlpha(alpha);
-  print(alpha, "step alpha: ");
-  */
-
-}
-
 #include "adaptivetimestepper_tmpl.cc"
diff --git a/dune/tectonic/time-stepping/adaptivetimestepper.hh b/dune/tectonic/time-stepping/adaptivetimestepper.hh
index fa00188d..2d11dd9d 100644
--- a/dune/tectonic/time-stepping/adaptivetimestepper.hh
+++ b/dune/tectonic/time-stepping/adaptivetimestepper.hh
@@ -2,9 +2,14 @@
 #define SRC_TIME_STEPPING_ADAPTIVETIMESTEPPER_HH
 
 #include <fstream>
+#include <future>
+#include <thread>
 
 //#include "../multi-threading/task.hh"
+#include "../spatial-solving/contact/nbodycontacttransfer.hh"
+
 #include "coupledtimestepper.hh"
+#include "stepbase.hh"
 
 struct IterationRegister {
   void registerCount(FixedPointIterationCounter count);
@@ -16,20 +21,23 @@ struct IterationRegister {
   FixedPointIterationCounter finalCount;
 };
 
+template <class Updaters>
+struct UpdatersWithCount {
+  Updaters updaters;
+  FixedPointIterationCounter count;
+};
+
 template <class Factory, class ContactNetwork, class Updaters, class ErrorNorms>
 class AdaptiveTimeStepper {
-  struct UpdatersWithCount {
-    Updaters updaters;
-    FixedPointIterationCounter count;
-  };
+  using UpdatersWithCount = UpdatersWithCount<Updaters>;
 
+  using StepBase = StepBase<Factory, ContactNetwork, Updaters, ErrorNorms>;
+
+  using NBodyAssembler = typename ContactNetwork::NBodyAssembler;
   using Vector = typename Factory::Vector;
   using Matrix = typename Factory::Matrix;
-  using IgnoreVector = typename Factory::BitVector;
-  //using ConvexProblem = typename Factory::ConvexProblem;
-  //using Nonlinearity = typename Factory::Nonlinearity;
 
-  using NBodyAssembler = typename ContactNetwork::NBodyAssembler;
+  using IgnoreVector = typename Factory::BitVector;
 
   using MyCoupledTimeStepper = CoupledTimeStepper<Factory, NBodyAssembler, Updaters, ErrorNorms>;
 
@@ -38,23 +46,15 @@ class AdaptiveTimeStepper {
   using ExternalForces = typename MyCoupledTimeStepper::ExternalForces;
 
 public:
-  AdaptiveTimeStepper(
-                      Dune::ParameterTree const &parset,
+  AdaptiveTimeStepper(const StepBase& stepBase,
                       ContactNetwork& contactNetwork,
-                      const IgnoreVector& ignoreNodes,
-                      GlobalFriction& globalFriction,
-                      const std::vector<const BitVector*>& bodywiseNonmortarBoundaries,
                       Updaters &current,
                       double relativeTime,
                       double relativeTau,
-                      ExternalForces& externalForces,
-                      const ErrorNorms& errorNorms,
                       std::function<bool(Updaters &, Updaters &)> mustRefine);
 
   bool reachedEnd();
 
-  auto makeLinearSolver() const;
-
   IterationRegister advance();
 
   double relativeTime_;
@@ -65,25 +65,19 @@ class AdaptiveTimeStepper {
 
   NBodyAssembler step(const NBodyAssembler& oldNBodyAssembler) const;
 
+  int refine(UpdatersWithCount& R2);
 
-  template <class LinearSolver>
-  UpdatersWithCount step(const Updaters& oldUpdaters, const NBodyAssembler& nBodyAssembler,
-                         std::shared_ptr<LinearSolver>& linearSolver,
-                         double rTime, double rTau);
+  int coarsen();
 
-  double finalTime_;
-  Dune::ParameterTree const &parset_;
-  ContactNetwork& contactNetwork_;
-  const IgnoreVector& ignoreNodes_;
+  int determineStrategy(UpdatersWithCount& R2);
 
-  GlobalFriction& globalFriction_;
-  const std::vector<const BitVector*>& bodywiseNonmortarBoundaries_;
+  const StepBase& stepBase_;
+  ContactNetwork& contactNetwork_;
 
   Updaters &current_;
   UpdatersWithCount R1_;
-  ExternalForces& externalForces_;
+
   std::function<bool(Updaters &, Updaters &)> mustRefine_;
-  const ErrorNorms& errorNorms_;
 
   IterationRegister iterationRegister_;
 };
diff --git a/dune/tectonic/time-stepping/adaptivetimestepper_tmpl.cc b/dune/tectonic/time-stepping/adaptivetimestepper_tmpl.cc
index 64155c96..a82a96df 100644
--- a/dune/tectonic/time-stepping/adaptivetimestepper_tmpl.cc
+++ b/dune/tectonic/time-stepping/adaptivetimestepper_tmpl.cc
@@ -2,6 +2,9 @@
 #error MY_DIM unset
 #endif
 
+#include <future>
+#include <thread>
+
 #include "../explicitgrid.hh"
 #include "../explicitvectors.hh"
 
@@ -37,6 +40,7 @@ using MySolverFactory = SolverFactory<MyFunctional, BitVector>;
 
 template class AdaptiveTimeStepper<MySolverFactory, MyContactNetwork, MyUpdaters, ErrorNorms>;
 
-template typename AdaptiveTimeStepper<MySolverFactory, MyContactNetwork, MyUpdaters, ErrorNorms>::UpdatersWithCount
+/*
+template std::packaged_task<typename AdaptiveTimeStepper<MySolverFactory, MyContactNetwork, MyUpdaters, ErrorNorms>::UpdatersWithCount()>
 AdaptiveTimeStepper<MySolverFactory, MyContactNetwork, MyUpdaters, ErrorNorms>::step<LinearSolver>(
-        const MyUpdaters&, const MyNBodyAssembler&, std::shared_ptr<LinearSolver>&, double, double);
+        const MyUpdaters&, const MyNBodyAssembler&, std::shared_ptr<LinearSolver>&, double, double); */
diff --git a/dune/tectonic/time-stepping/step.hh b/dune/tectonic/time-stepping/step.hh
new file mode 100644
index 00000000..ece7019c
--- /dev/null
+++ b/dune/tectonic/time-stepping/step.hh
@@ -0,0 +1,273 @@
+#ifndef DUNE_TECTONIC_TIME_STEPPING_STEP_HH
+#define DUNE_TECTONIC_TIME_STEPPING_STEP_HH
+
+
+#include <future>
+#include <thread>
+#include <chrono>
+#include <functional>
+
+#include <dune/solvers/norms/energynorm.hh>
+#include <dune/solvers/iterationsteps/multigridstep.hh>
+#include <dune/solvers/iterationsteps/cgstep.hh>
+#include <dune/solvers/solvers/loopsolver.hh>
+
+#include "../spatial-solving/preconditioners/multilevelpatchpreconditioner.hh"
+
+#include <dune/tectonic/utils/reductionfactors.hh>
+
+#include "stepbase.hh"
+#include "adaptivetimestepper.hh"
+
+template<class IterationStepType, class NormType, class ReductionFactorContainer>
+Dune::Solvers::Criterion reductionFactorCriterion(
+      IterationStepType& iterationStep,
+      const NormType& norm,
+      ReductionFactorContainer& reductionFactors)
+{
+  double normOfOldCorrection = 1;
+  auto lastIterate = std::make_shared<typename IterationStepType::Vector>(*iterationStep.getIterate());
+
+  return Dune::Solvers::Criterion(
+      [&, lastIterate, normOfOldCorrection] () mutable {
+        double normOfCorrection = norm.diff(*lastIterate, *iterationStep.getIterate());
+        double convRate = (normOfOldCorrection > 0) ? normOfCorrection / normOfOldCorrection : 0.0;
+
+        if (convRate>1.0)
+            std::cout << "Solver convergence rate of " << convRate << std::endl;
+
+        normOfOldCorrection = normOfCorrection;
+        *lastIterate = *iterationStep.getIterate();
+
+        reductionFactors.push_back(convRate);
+        return std::make_tuple(convRate < 0, Dune::formatString(" % '.5f", convRate));
+      },
+      " reductionFactor");
+}
+
+
+template<class IterationStepType, class Functional, class ReductionFactorContainer>
+Dune::Solvers::Criterion energyCriterion(
+      const IterationStepType& iterationStep,
+      const Functional& f,
+      ReductionFactorContainer& reductionFactors)
+{
+  double normOfOldCorrection = 1;
+  auto lastIterate = std::make_shared<typename IterationStepType::Vector>(*iterationStep.getIterate());
+
+  return Dune::Solvers::Criterion(
+      [&, lastIterate, normOfOldCorrection] () mutable {
+        double normOfCorrection = std::abs(f(*lastIterate) - f(*iterationStep.getIterate())); //norm.diff(*lastIterate, *iterationStep.getIterate());
+
+        double convRate = (normOfOldCorrection != 0.0) ? 1.0 - (normOfCorrection / normOfOldCorrection) : 0.0;
+
+        if (convRate>1.0)
+            std::cout << "Solver convergence rate of " << convRate << std::endl;
+
+        normOfOldCorrection = normOfCorrection;
+        *lastIterate = *iterationStep.getIterate();
+
+        reductionFactors.push_back(convRate);
+        return std::make_tuple(convRate < 0, Dune::formatString(" % '.5f", convRate));
+      },
+      " reductionFactor");
+}
+
+template <class ReductionFactorContainer>
+void updateReductionFactors(ReductionFactorContainer& reductionFactors) {
+    const size_t s = reductionFactors.size();
+
+    //print(reductionFactors, "reduction factors: ");
+
+    if (s>allReductionFactors.size()) {
+        allReductionFactors.resize(s);
+    }
+
+    for (size_t i=0; i<reductionFactors.size(); i++) {
+        allReductionFactors[i].push_back(reductionFactors[i]);
+    }
+
+    reductionFactors.clear();
+}
+
+
+template <class Factory, class ContactNetwork, class Updaters, class ErrorNorms>
+class Step : protected StepBase<Factory, ContactNetwork, Updaters, ErrorNorms> {
+public:
+    enum Mode { sameThread, newThread };
+
+private:
+    using Base = StepBase<Factory, ContactNetwork, Updaters, ErrorNorms>;
+    using NBodyAssembler = typename ContactNetwork::NBodyAssembler;
+    using UpdatersWithCount = UpdatersWithCount<Updaters>;
+
+    const Updaters& oldUpdaters_;
+    const NBodyAssembler& nBodyAssembler_;
+
+    double relativeTime_;
+    double relativeTau_;
+
+    IterationRegister& iterationRegister_;
+
+    std::packaged_task<UpdatersWithCount()> task_;
+    std::future<UpdatersWithCount> future_;
+    std::thread thread_;
+
+    Mode mode_;
+
+public:
+    Step(const Base& stepFactory, const Updaters& oldUpdaters, const NBodyAssembler& nBodyAssembler, double rTime, double rTau, IterationRegister& iterationRegister) :
+        Base(stepFactory.parset_, stepFactory.contactNetwork_, stepFactory.ignoreNodes_, stepFactory.globalFriction_,
+             stepFactory.bodywiseNonmortarBoundaries_, stepFactory.externalForces_, stepFactory.errorNorms_),
+        oldUpdaters_(oldUpdaters),
+        nBodyAssembler_(nBodyAssembler),
+        relativeTime_(rTime),
+        relativeTau_(rTau),
+        iterationRegister_(iterationRegister) {}
+
+    UpdatersWithCount doStep() {
+        // make linear solver for linear correction in TNNMGStep
+        using Vector = typename Factory::Vector;
+        using Matrix = typename Factory::Matrix;
+
+        /* old, pre multi threading, was unused
+        using Norm =  EnergyNorm<Matrix, Vector>;
+        using Preconditioner = MultilevelPatchPreconditioner<ContactNetwork, Matrix, Vector>;
+        using LinearSolver = typename Dune::Solvers::LoopSolver<Vector>;
+
+        const auto& preconditionerParset = parset_.sub("solver.tnnmg.linear.preconditioner");
+
+        Dune::BitSetVector<1> activeLevels(contactNetwork_.nLevels(), true);
+        Preconditioner preconditioner(preconditionerParset, contactNetwork_, activeLevels);
+        preconditioner.setPatchDepth(preconditionerParset.template get<size_t>("patchDepth"));
+        preconditioner.build();
+
+        auto cgStep = std::make_shared<Dune::Solvers::CGStep<Matrix, Vector>>();
+        cgStep->setPreconditioner(preconditioner);
+
+        Norm norm(*cgStep);
+
+        return std::make_shared<LinearSolver>(cgStep, parset_.template get<int>("solver.tnnmg.main.multi"), parset_.template get<double>("solver.tnnmg.linear.tolerance"), norm, Solver::QUIET);
+        */
+
+        // patch preconditioner only needs to be computed once per advance()
+        // make linear solver for linear correction in TNNMGStep
+        using Norm =  EnergyNorm<Matrix, Vector>;
+        using Preconditioner = MultilevelPatchPreconditioner<ContactNetwork, Matrix, Vector>;
+        using LinearSolver = typename Dune::Solvers::LoopSolver<Vector>;
+
+        /*const auto& preconditionerParset = parset_.sub("solver.tnnmg.preconditioner");
+
+        Dune::BitSetVector<1> activeLevels(contactNetwork_.nLevels(), true);
+        Preconditioner preconditioner(preconditionerParset, contactNetwork_, activeLevels);
+        preconditioner.setPatchDepth(preconditionerParset.template get<size_t>("patchDepth"));
+        preconditioner.build();
+
+        auto cgStep = std::make_shared<Dune::Solvers::CGStep<Matrix, Vector>>();
+        cgStep->setPreconditioner(preconditioner);
+
+        Norm norm(*cgStep);
+
+        auto linearSolver = std::make_shared<LinearSolver>(cgStep, parset_.template get<int>("solver.tnnmg.main.multi"), parset_.template get<double>("solver.tnnmg.preconditioner.basesolver.tolerance"), norm, Solver::QUIET);
+        */
+
+        // set multigrid solver
+        auto smoother = TruncatedBlockGSStep<Matrix, Vector>();
+
+        // transfer operators need to be recomputed on change due to setDeformation()
+        using TransferOperator = NBodyContactTransfer<ContactNetwork, Vector>;
+        using TransferOperators = std::vector<std::shared_ptr<TransferOperator>>;
+
+        TransferOperators transfer(this->contactNetwork_.nLevels()-1);
+        for (size_t i=0; i<transfer.size(); i++) {
+            transfer[i] = std::make_shared<TransferOperator>();
+            transfer[i]->setup(this->contactNetwork_, i, i+1);
+        }
+
+        // Remove any recompute filed so that initially the full transferoperator is assembled
+        for (size_t i=0; i<transfer.size(); i++)
+            std::dynamic_pointer_cast<TruncatedMGTransfer<Vector> >(transfer[i])->setRecomputeBitField(nullptr);
+
+        auto linearMultigridStep = std::make_shared<Dune::Solvers::MultigridStep<Matrix, Vector> >();
+        linearMultigridStep->setMGType(1, 3, 3);
+        linearMultigridStep->setSmoother(smoother);
+        linearMultigridStep->setTransferOperators(transfer);
+
+        Norm norm(*linearMultigridStep);
+        auto linearSolver = std::make_shared<LinearSolver>(linearMultigridStep, this->parset_.template get<int>("solver.tnnmg.main.multi"), this->parset_.template get<double>("solver.tnnmg.preconditioner.basesolver.tolerance"), norm, Solver::QUIET);
+
+        Vector x;
+        x.resize(nBodyAssembler_.totalHasObstacle_.size());
+        x = 0;
+
+        linearSolver->getIterationStep().setProblem(x);
+      //dynamic_cast<Dune::Solvers::IterationStep<Vector>*>(linearMultigridStep.get())->setProblem(x);
+      //dynamic_cast<Dune::Solvers::IterationStep<Vector>*>(cgStep.get())->setProblem(x);
+
+      //std::vector<double> reductionFactors;
+      //linearSolver->addCriterion(reductionFactorCriterion(linearSolver->getIterationStep(), linearSolver->getErrorNorm(), reductionFactors));
+      //linearSolver->addCriterion(reductionFactorCriterion(*cgStep, norm, reductionFactors));
+
+      UpdatersWithCount newUpdatersAndCount = {oldUpdaters_.clone(), {}};
+
+      typename Base::MyCoupledTimeStepper coupledTimeStepper(
+        this->finalTime_, this->parset_, nBodyAssembler_,
+        this->ignoreNodes_, this->globalFriction_, this->bodywiseNonmortarBoundaries_,
+        newUpdatersAndCount.updaters, this->errorNorms_, this->externalForces_);
+
+      newUpdatersAndCount.count = coupledTimeStepper.step(linearSolver, relativeTime_, relativeTau_);
+      iterationRegister_.registerCount(newUpdatersAndCount.count);
+
+      //updateReductionFactors(reductionFactors);
+
+      return newUpdatersAndCount;
+    }
+
+   /* auto simple = [&]{
+        std::cout << "starting task... " << std::endl;
+        UpdatersWithCount newUpdatersAndCount = {oldUpdaters.clone(), {}};
+        std::this_thread::sleep_for(std::chrono::milliseconds(10000));
+        std::cout << "finishing task... " << std::endl;
+        return newUpdatersAndCount;
+    }*/
+
+    void run(Mode mode = sameThread) {
+        mode_ = mode;
+        task_ = std::packaged_task<UpdatersWithCount()>( [this]{ return doStep(); });
+        future_ = task_.get_future();
+
+        if (mode == Mode::sameThread) {
+            task_();
+        } else {
+            thread_ = std::thread(std::move(task_));
+        }
+    }
+
+    auto get() {
+        std::cout << "Step::get() called" << std::endl;
+        if (thread_.joinable()) {
+            thread_.join();
+            std::cout << "Thread joined " << std::endl;
+        }
+        return future_.get();
+    }
+
+    ~Step() {
+        if (thread_.joinable()) {
+            thread_.join();
+            std::cout << "Destructor:: Thread joined " << std::endl;
+        }
+    }
+};
+
+template <class Factory, class ContactNetwork, class Updaters, class ErrorNorms>
+Step<Factory, ContactNetwork, Updaters, ErrorNorms>
+buildStep(const StepBase<Factory, ContactNetwork, Updaters, ErrorNorms>& base,
+          const Updaters& oldUpdaters,
+          const typename ContactNetwork::NBodyAssembler& nBodyAssembler,
+          double rTime, double rTau) {
+
+    return Step(base, oldUpdaters, nBodyAssembler, rTime, rTau);
+}
+
+#endif
diff --git a/dune/tectonic/time-stepping/stepbase.hh b/dune/tectonic/time-stepping/stepbase.hh
new file mode 100644
index 00000000..72df88c9
--- /dev/null
+++ b/dune/tectonic/time-stepping/stepbase.hh
@@ -0,0 +1,49 @@
+#ifndef DUNE_TECTONIC_TIME_STEPPING_STEPBASE_HH
+#define DUNE_TECTONIC_TIME_STEPPING_STEPBASE_HH
+
+#include "coupledtimestepper.hh"
+
+template <class Factory, class ContactNetwork, class Updaters, class ErrorNorms>
+class StepBase {
+protected:
+  using NBodyAssembler = typename ContactNetwork::NBodyAssembler;
+  using IgnoreVector = typename Factory::BitVector;
+
+  using MyCoupledTimeStepper = CoupledTimeStepper<Factory, NBodyAssembler, Updaters, ErrorNorms>;
+
+  using GlobalFriction = typename MyCoupledTimeStepper::GlobalFriction;
+  using BitVector = typename MyCoupledTimeStepper::BitVector;
+  using ExternalForces = typename MyCoupledTimeStepper::ExternalForces;
+
+public:
+  StepBase(
+    Dune::ParameterTree const &parset,
+    ContactNetwork& contactNetwork,
+    const IgnoreVector& ignoreNodes,
+    GlobalFriction& globalFriction,
+    const std::vector<const BitVector*>& bodywiseNonmortarBoundaries,
+    ExternalForces& externalForces,
+    const ErrorNorms& errorNorms) :
+        parset_(parset),
+        finalTime_(parset_.get<double>("problem.finalTime")),
+        contactNetwork_(contactNetwork),
+        ignoreNodes_(ignoreNodes),
+        globalFriction_(globalFriction),
+        bodywiseNonmortarBoundaries_(bodywiseNonmortarBoundaries),
+        externalForces_(externalForces),
+        errorNorms_(errorNorms) {}
+
+  Dune::ParameterTree const &parset_;
+  double finalTime_;
+
+  ContactNetwork& contactNetwork_;
+  const IgnoreVector& ignoreNodes_;
+
+  GlobalFriction& globalFriction_;
+  const std::vector<const BitVector*>& bodywiseNonmortarBoundaries_;
+
+  ExternalForces& externalForces_;
+  const ErrorNorms& errorNorms_;
+};
+
+#endif
diff --git a/src/multi-body-problem/multi-body-problem.cc b/src/multi-body-problem/multi-body-problem.cc
index d7a0009a..84a580dd 100644
--- a/src/multi-body-problem/multi-body-problem.cc
+++ b/src/multi-body-problem/multi-body-problem.cc
@@ -74,6 +74,7 @@
 #include <dune/tectonic/time-stepping/adaptivetimestepper.hh>
 #include <dune/tectonic/time-stepping/rate.hh>
 #include <dune/tectonic/time-stepping/state.hh>
+#include <dune/tectonic/time-stepping/stepbase.hh>
 #include <dune/tectonic/time-stepping/updaters.hh>
 
 #include <dune/tectonic/utils/debugutils.hh>
@@ -341,10 +342,14 @@ int main(int argc, char *argv[]) {
     typename ContactNetwork::ExternalForces externalForces;
     contactNetwork.externalForces(externalForces);
 
+    StepBase<NonlinearFactory, std::decay_t<decltype(contactNetwork)>, Updaters, std::decay_t<decltype(stateEnergyNorms)>>
+        stepBase(parset, contactNetwork, totalDirichletNodes, globalFriction, frictionNodes,
+                 externalForces, stateEnergyNorms);
+
     AdaptiveTimeStepper<NonlinearFactory, std::decay_t<decltype(contactNetwork)>, Updaters, std::decay_t<decltype(stateEnergyNorms)>>
-        adaptiveTimeStepper(parset, contactNetwork, totalDirichletNodes, globalFriction, frictionNodes, current,
+        adaptiveTimeStepper(stepBase, contactNetwork, current,
                             programState.relativeTime, programState.relativeTau,
-                            externalForces, stateEnergyNorms, mustRefine);
+                            mustRefine);
 
     size_t timeSteps = parset.get<size_t>("timeSteps.timeSteps");
 
diff --git a/src/strikeslip/strikeslip-2D.cfg b/src/strikeslip/strikeslip-2D.cfg
index 5fb2c390..e360fd0b 100644
--- a/src/strikeslip/strikeslip-2D.cfg
+++ b/src/strikeslip/strikeslip-2D.cfg
@@ -1,9 +1,9 @@
 # -*- mode:conf -*-
 [body0]
-smallestDiameter = 0.05 # 2e-3 [m]
+smallestDiameter = 0.02 # 2e-3 [m]
 
 [body1]
-smallestDiameter = 0.05  # 2e-3 [m]
+smallestDiameter = 0.02  # 2e-3 [m]
 
 [timeSteps]
 refinementTolerance = 5e-6 # 1e-5
diff --git a/src/strikeslip/strikeslip.cc b/src/strikeslip/strikeslip.cc
index 4cdf3f81..eb17ed23 100644
--- a/src/strikeslip/strikeslip.cc
+++ b/src/strikeslip/strikeslip.cc
@@ -56,6 +56,8 @@
 
 #include <dune/tectonic/factories/strikeslipfactory.hh>
 
+
+#include <dune/tectonic/io/io-handler.hh>
 #include <dune/tectonic/io/hdf5-writer.hh>
 #include <dune/tectonic/io/hdf5/restart-io.hh>
 #include <dune/tectonic/io/vtk.hh>
@@ -167,29 +169,13 @@ int main(int argc, char *argv[]) {
 
     using MyProgramState = ProgramState<Vector, ScalarVector>;
     MyProgramState programState(nVertices);
-    auto const firstRestart = parset.get<size_t>("io.restarts.first");
-    auto const restartSpacing = parset.get<size_t>("io.restarts.spacing");
-    auto const writeRestarts = parset.get<bool>("io.restarts.write");
-    auto const writeData = parset.get<bool>("io.data.write");
-    bool const handleRestarts = writeRestarts or firstRestart > 0;
-
-    auto restartFile = handleRestarts
-                           ? std::make_unique<HDF5::File>(
-                                 "restarts.h5",
-                                 writeRestarts ? HDF5::Access::READWRITE
-                                               : HDF5::Access::READONLY)
-                           : nullptr;
-
-
-    auto restartIO = handleRestarts
-                         ? std::make_unique<RestartIO<MyProgramState>>(
-                               *restartFile, nVertices)
-                         : nullptr;
-
-    if (firstRestart > 0) // automatically adjusts the time and timestep
-      restartIO->read(firstRestart, programState);
-    else
-     programState.setupInitialConditions(parset, contactNetwork);
+
+    IOHandler<Assembler, ContactNetwork> ioHandler(parset.sub("io"), contactNetwork);
+
+    bool restartRead = ioHandler.read(programState);
+    if (!restartRead) {
+      programState.setupInitialConditions(parset, contactNetwork);
+    }
 
     auto& nBodyAssembler = contactNetwork.nBodyAssembler();
     for (size_t i=0; i<bodyCount; i++) {
@@ -206,90 +192,9 @@ int main(int argc, char *argv[]) {
     auto& globalFriction = contactNetwork.globalFriction();
     globalFriction.updateAlpha(programState.alpha);
 
-    using MyVertexBasis = typename Assembler::VertexBasis;
-    using MyCellBasis = typename Assembler::CellBasis;
-    std::vector<Vector> vertexCoordinates(bodyCount);
-    std::vector<const MyVertexBasis* > vertexBases(bodyCount);
-    std::vector<const MyCellBasis* > cellBases(bodyCount);
-
-    auto& wPatches = blocksFactory.weakPatches();
-    std::vector<std::vector<const ConvexPolyhedron<LocalVector>*>> weakPatches(bodyCount);
-
-
-    for (size_t i=0; i<bodyCount; i++) {
-      const auto& body = contactNetwork.body(i);
-      vertexBases[i] = &(body->assembler()->vertexBasis);
-      cellBases[i] = &(body->assembler()->cellBasis);
-
-      weakPatches[i].resize(1);
-      weakPatches[i][0] = wPatches[i].get();
-
-      auto& vertexCoords = vertexCoordinates[i];
-      vertexCoords.resize(nVertices[i]);
-
-      auto hostLeafView = body->grid()->hostGrid().leafGridView();
-      Dune::MultipleCodimMultipleGeomTypeMapper<
-          LeafGridView, Dune::MCMGVertexLayout> const vertexMapper(hostLeafView, Dune::mcmgVertexLayout());
-      for (auto &&v : vertices(hostLeafView))
-        vertexCoords[vertexMapper.index(v)] = geoToPoint(v.geometry());
-    }
-
-    typename ContactNetwork::BoundaryPatches frictionBoundaries;
-    contactNetwork.boundaryPatches("friction", frictionBoundaries);
-
-    std::vector<const typename ContactNetwork::BoundaryPatch*> frictionPatches;
-    contactNetwork.frictionPatches(frictionPatches);
-
-
     IterationRegister iterationCount;
 
-    auto const report = [&](bool initial = false) {
-        auto dataFile =
-            writeData ? std::make_unique<HDF5::File>("output.h5") : nullptr;
-
-        auto dataWriter =
-            writeData ? std::make_unique<
-                            HDF5Writer<MyProgramState, MyVertexBasis, DefLeafGridView>>(
-                            *dataFile, vertexCoordinates, vertexBases,
-                            frictionPatches, weakPatches)
-                      : nullptr;
-
-      if (writeData) {
-        dataWriter->reportSolution(programState, contactNetwork, globalFriction);
-        if (!initial)
-          dataWriter->reportIterations(programState, iterationCount);
-        dataFile->flush();
-      }
-
-      if (writeRestarts and !initial and
-          programState.timeStep % restartSpacing == 0) {
-        restartIO->write(programState);
-        restartFile->flush();
-      }
-
-      if (parset.get<bool>("io.printProgress"))
-        std::cout << "timeStep = " << std::setw(6) << programState.timeStep
-                  << ", time = " << std::setw(12) << programState.relativeTime
-                  << ", tau = " << std::setw(12) << programState.relativeTau
-                  << std::endl;
-
-      if (parset.get<bool>("io.vtk.write")) {
-        std::vector<ScalarVector> stress(bodyCount);
-
-        for (size_t i=0; i<bodyCount; i++) {
-          const auto& body = contactNetwork.body(i);
-          body->assembler()->assembleVonMisesStress(body->data()->getYoungModulus(),
-                                           body->data()->getPoissonRatio(),
-                                           programState.u[i], stress[i]);
-
-        }
-
-        const MyVTKWriter<MyVertexBasis, MyCellBasis> vtkWriter(cellBases, vertexBases, "../debug_print/vtk/");
-        vtkWriter.write(programState.timeStep, programState.u, programState.v,
-                        programState.alpha, stress);
-      }
-    };
-    report(true);
+    ioHandler.write(programState, contactNetwork, globalFriction, iterationCount, true);
 
     // -------------------
     // Set up TNNMG solver
@@ -428,10 +333,14 @@ int main(int argc, char *argv[]) {
     typename ContactNetwork::ExternalForces externalForces;
     contactNetwork.externalForces(externalForces);
 
+    StepBase<NonlinearFactory, std::decay_t<decltype(contactNetwork)>, Updaters, std::decay_t<decltype(stateEnergyNorms)>>
+        stepBase(parset, contactNetwork, totalDirichletNodes, globalFriction, frictionNodes,
+                 externalForces, stateEnergyNorms);
+
     AdaptiveTimeStepper<NonlinearFactory, std::decay_t<decltype(contactNetwork)>, Updaters, std::decay_t<decltype(stateEnergyNorms)>>
-        adaptiveTimeStepper(parset, contactNetwork, totalDirichletNodes, globalFriction, frictionNodes, current,
+        adaptiveTimeStepper(stepBase, contactNetwork, current,
                             programState.relativeTime, programState.relativeTau,
-                            externalForces, stateEnergyNorms, mustRefine);
+                            mustRefine);
 
     size_t timeSteps = parset.get<size_t>("timeSteps.timeSteps");
 
@@ -456,7 +365,7 @@ int main(int argc, char *argv[]) {
 
       contactNetwork.setDeformation(programState.u);
 
-      report();
+      ioHandler.write(programState, contactNetwork, globalFriction, iterationCount, false);
 
       if (programState.timeStep==timeSteps) {
         std::cout << "limit of timeSteps reached!" << std::endl;
diff --git a/src/strikeslip/strikeslip.cfg b/src/strikeslip/strikeslip.cfg
index 294e3ea3..bc5ed7cb 100644
--- a/src/strikeslip/strikeslip.cfg
+++ b/src/strikeslip/strikeslip.cfg
@@ -55,11 +55,11 @@ sigmaN          = 200.0      # [Pa]
 finalVelocity   = 1e-4     # [m/s]
 
 [io]
-data.write      = true
+data.write      = false
 printProgress   = true
 restarts.first  = 0
-restarts.spacing= 20
-restarts.write  = false #true
+restarts.spacing= 1
+restarts.write  = true #true
 vtk.write       = true
 
 [problem]
@@ -73,7 +73,7 @@ relativeTau = 2e-4 # 1e-6
 
 [timeSteps]
 scheme = newmark
-timeSteps = 5000
+timeSteps = 5
 
 [u0.solver]
 maximumIterations = 100
-- 
GitLab