diff --git a/dune/solvers/iterationsteps/blockgssteps.hh b/dune/solvers/iterationsteps/blockgssteps.hh
index 01d17c64456b2362cf1e14a9b6e470bc3adde5d0..bd68b4cbf33ae0f78a44f7e5b7d2e1e408a0c5d7 100644
--- a/dune/solvers/iterationsteps/blockgssteps.hh
+++ b/dune/solvers/iterationsteps/blockgssteps.hh
@@ -40,16 +40,24 @@ void linearStep(const M& m, V& x, const V& b, const BitVector* ignore,
   // Note: move-capture requires C++14
   auto blockStep = [&, localSolver = std::move(localSolver) ](size_t i) {
     const auto& row_i = m[i];
+
     // Compute residual
     auto ri = b[i];
-    for (auto cIt = row_i.begin(); cIt != row_i.end(); ++cIt)
-      cIt->mmv(x[cIt.index()], ri);
+    using Block = typename M::block_type;
+    const Block* diag = nullptr;
+    for (auto cIt = row_i.begin(); cIt != row_i.end(); ++cIt) {
+      size_t j = cIt.index();
+      cIt->mmv(x[j], ri);
+      if (j == i)
+        diag = &*cIt;
+    }
 
     std::bitset<V::block_type::dimension> ignore_i(0);
     if (ignore != nullptr)
       ignore_i = (*ignore)[i];
+
     // Update iterate with correction
-    x[i] += localSolver(row_i[i], std::move(ri), ignore_i);
+    x[i] += localSolver(diag ? *diag : Block(0.0), std::move(ri), ignore_i);
   };
 
   if (direction != Direction::BACKWARD)