diff --git a/dune/tnnmg/iterationsteps/tnnmgstep.hh b/dune/tnnmg/iterationsteps/tnnmgstep.hh
index 2d320e61846112c537271a427da76df6a393e50c..ce592ece8f3e1839a5bfb698b4e11fffeb6825ed 100644
--- a/dune/tnnmg/iterationsteps/tnnmgstep.hh
+++ b/dune/tnnmg/iterationsteps/tnnmgstep.hh
@@ -10,6 +10,7 @@
 
 #include "dune/solvers/iterationsteps/iterationstep.hh"
 #include "dune/solvers/iterationsteps/lineariterationstep.hh"
+#include <dune/solvers/solvers/iterativesolver.hh>
 
 #ifndef USE_OLD_TNNMG
 
@@ -23,7 +24,6 @@ namespace TNNMG {
  * \tparam BV Bit-vector type for marking ignored components
  */
 template<class F, class BV, class Linearization,
-                                  class LinearSolver,
                                   class DefectProjection,
                                   class LineSearchSolver>
 class TNNMGStep :
@@ -36,6 +36,7 @@ public:
   using Vector = typename F::Vector;
   using BitVector = typename Base::BitVector;
   using Functional = F;
+  using LinearSolver = Solvers::IterativeSolver<typename Linearization::ConstrainedVector, Solvers::DefaultBitVector_t<Vector> >;
 
   /** \brief Constructor
    * \param linearSolver This is a callback used to solve the constrained linearized system
@@ -104,8 +105,30 @@ public:
 
     auto emptyIgnore = ignore;
     Solvers::resizeInitialize(emptyIgnore, constrainedCorrection_, false);
-    linearSolver_->step_->setIgnore(emptyIgnore);
-    linearSolver_->step_->setProblem(A, constrainedCorrection_, r);
+
+    // Hand the linear problem to the linear solver.
+    // Currently we only support IterativeSolvers.  The IterationStep member
+    // needs to be a LinearIterationStep, so we can give it the matrix.
+    using LinearIterationStepType = Solvers::LinearIterationStep<std::decay_t<decltype(A)>,
+                                                                 typename Linearization::ConstrainedVector,
+                                                                 decltype(emptyIgnore) >;
+
+    LinearIterationStepType* linearIterationStep;
+
+    auto iterativeSolver = std::dynamic_pointer_cast<Solvers::IterativeSolver<typename Linearization::ConstrainedVector> >(linearSolver_);
+    if (iterativeSolver)
+    {
+      iterativeSolver->iterationStep_->setIgnore(emptyIgnore);
+      linearIterationStep = dynamic_cast<LinearIterationStepType*>(iterativeSolver->iterationStep_);
+    } else
+      DUNE_THROW(Exception, "Linear solver has to be an IterativeSolver!");
+
+    if (linearIterationStep)
+      linearIterationStep->setProblem(A, constrainedCorrection_, r);
+    else
+      DUNE_THROW(Exception, "Linear solver does not accept matrices!");
+
+
     linearSolver_->preprocess();
     linearSolver_->solve();