diff --git a/dune/solvers/iterationsteps/blockgssteps.hh b/dune/solvers/iterationsteps/blockgssteps.hh
index 01d17c64456b2362cf1e14a9b6e470bc3adde5d0..7ef8458553de6baa787a2eaea134fbbade336dc1 100644
--- a/dune/solvers/iterationsteps/blockgssteps.hh
+++ b/dune/solvers/iterationsteps/blockgssteps.hh
@@ -40,16 +40,24 @@ void linearStep(const M& m, V& x, const V& b, const BitVector* ignore,
   // Note: move-capture requires C++14
   auto blockStep = [&, localSolver = std::move(localSolver) ](size_t i) {
     const auto& row_i = m[i];
-    // Compute residual
-    auto ri = b[i];
-    for (auto cIt = row_i.begin(); cIt != row_i.end(); ++cIt)
-      cIt->mmv(x[cIt.index()], ri);
 
     std::bitset<V::block_type::dimension> ignore_i(0);
     if (ignore != nullptr)
       ignore_i = (*ignore)[i];
+
+    // Compute residual
+    auto ri = b[i];
+    using Block = typename M::block_type;
+    const Block* diag = nullptr;
+    for (auto cIt = row_i.begin(); cIt != row_i.end(); ++cIt) {
+      size_t j = cIt.index();
+      cIt->mmv(x[j], ri);
+      if (j == i)
+        diag = &*cIt;
+    }
+
     // Update iterate with correction
-    x[i] += localSolver(row_i[i], std::move(ri), ignore_i);
+    x[i] += localSolver(diag ? *diag : Block(0.0), std::move(ri), ignore_i);
   };
 
   if (direction != Direction::BACKWARD)
@@ -172,9 +180,10 @@ namespace LocalSolverFromLinearSolver {
  */
 template <class LinearSolver>
 auto truncateSymmetrically(LinearSolver&& linearSolver) {
-  return [linearSolver = std::move(linearSolver) ](
-      const auto& m, const auto& b, const auto& ignore) {
-    using Return = typename std::result_of<LinearSolver(decltype(m), decltype(b))>::type;
+  return [linearSolver = std::move(linearSolver)](const auto& m, const auto& b,
+                                                  const auto& ignore) {
+    using Return =
+        typename std::result_of<LinearSolver(decltype(m), decltype(b))>::type;
     if (ignore.all())
       return Return(0);
 
@@ -270,7 +279,8 @@ auto cg(size_t maxIter = LinearSolvers::defaultCgMaxIter,
     auto cgSolver = std::bind(
         LinearSolvers::cg<std::decay_t<decltype(m)>, std::decay_t<decltype(b)>>,
         _1, _2, maxIter, tol);
-    auto cgTruncatedSolver = LocalSolverFromLinearSolver::truncateSymmetrically(cgSolver);
+    auto cgTruncatedSolver =
+        LocalSolverFromLinearSolver::truncateSymmetrically(cgSolver);
     auto cgTruncatedRegularizedSolver =
         LocalSolverRegularizer::diagRegularize(r, cgTruncatedSolver);
     return cgTruncatedRegularizedSolver(m, b, ignore);