diff --git a/src/foam/foam.cc b/src/foam/foam.cc
index b606d65b7f0e6d07639bbef32b2ed82680073aa5..e905374835455d82bc0085573ff4d0e35fed0760 100644
--- a/src/foam/foam.cc
+++ b/src/foam/foam.cc
@@ -445,6 +445,42 @@ int main(int argc, char *argv[]) {
 
     const auto& stateEnergyNorms = contactNetwork.stateEnergyNorms();
 
+    auto const mustRefine = [&](Updaters &coarseUpdater,
+                                Updaters &fineUpdater) {
+
+        //return false;
+      //std::cout << "Step 1" << std::endl;
+
+      std::vector<ScalarVector> coarseAlpha;
+      coarseAlpha.resize(bodyCount);
+      coarseUpdater.state_->extractAlpha(coarseAlpha);
+
+      //print(coarseAlpha, "coarseAlpha:");
+
+      std::vector<ScalarVector> fineAlpha;
+      fineAlpha.resize(bodyCount);
+      fineUpdater.state_->extractAlpha(fineAlpha);
+
+      //print(fineAlpha, "fineAlpha:");
+
+      //std::cout << "Step 3" << std::endl;
+
+      ScalarVector::field_type energyNorm = 0;
+      for (size_t i=0; i<stateEnergyNorms.size(); i++) {
+          //std::cout << "for " << i << std::endl;
+
+          //std::cout << not stateEnergyNorms[i] << std::endl;
+
+          if (coarseAlpha[i].size()==0 || fineAlpha[i].size()==0)
+              continue;
+
+          energyNorm += stateEnergyNorms[i]->diff(fineAlpha[i], coarseAlpha[i]);
+      }
+      //std::cout << "energy norm: " << energyNorm << " tol: " << refinementTolerance <<  std::endl;
+      //std::cout << "must refine: " << (energyNorm > refinementTolerance) <<  std::endl;
+      return energyNorm > refinementTolerance;
+    };
+
     std::signal(SIGXCPU, handleSignal);
     std::signal(SIGINT, handleSignal);
     std::signal(SIGTERM, handleSignal);
@@ -477,12 +513,17 @@ int main(int argc, char *argv[]) {
         stepBase(parset, contactNetwork, totalDirichletNodes, globalFriction, frictionNodes,
                  externalForces, stateEnergyNorms);
 
-    UniformTimeStepper<NonlinearFactory, std::decay_t<decltype(contactNetwork)>, Updaters, std::decay_t<decltype(stateEnergyNorms)>>
+    const auto minTau = parset.get<double>("timeSteps.minTau");
+    AdaptiveTimeStepper<NonlinearFactory, std::decay_t<decltype(contactNetwork)>, Updaters, std::decay_t<decltype(stateEnergyNorms)>>
         timeStepper(stepBase, contactNetwork, current,
-                            programState.relativeTime, programState.relativeTau);
+                            programState.relativeTime, programState.relativeTau, minTau,
+                            mustRefine);
 
-    size_t timeSteps = std::round(parset.get<double>("timeSteps.timeSteps"));
+    /*UniformTimeStepper<NonlinearFactory, std::decay_t<decltype(contactNetwork)>, Updaters, std::decay_t<decltype(stateEnergyNorms)>>
+        timeStepper(stepBase, contactNetwork, current,
+                            programState.relativeTime, programState.relativeTau);*/
 
+    size_t timeSteps = std::round(parset.get<double>("timeSteps.timeSteps"));
     while (!timeStepper.reachedEnd()) {
       programState.timeStep++;