diff --git a/dune/solvers/iterationsteps/blockgssteps.hh b/dune/solvers/iterationsteps/blockgssteps.hh index 0c4b56fad2fed6444806c972217d35ee187cea9c..7c771b42e43a21fba05b076b00a377b9b6e41bed 100644 --- a/dune/solvers/iterationsteps/blockgssteps.hh +++ b/dune/solvers/iterationsteps/blockgssteps.hh @@ -40,16 +40,23 @@ void linearStep(const M& m, V& x, const V& b, const BitVector* ignore, // Note: move-capture requires C++14 auto blockStep = [&, localSolver = std::move(localSolver) ](size_t i) { const auto& row_i = m[i]; + // Compute residual auto ri = b[i]; - for (auto cIt = row_i.begin(); cIt != row_i.end(); ++cIt) - cIt->mmv(x[cIt.index()], ri); + using Block = typename M::block_type; + const Block* diag = nullptr; + for (auto cIt = row_i.begin(); cIt != row_i.end(); ++cIt) { + size_t j = cIt.index(); + cIt->mmv(x[j], ri); + if (j == i) + diag = &*cIt; + } + + using Ignore = std::bitset<V::block_type::dimension>; - std::bitset<V::block_type::dimension> ignore_i(0); - if (ignore != nullptr) - ignore_i = (*ignore)[i]; // Update iterate with correction - x[i] += localSolver(row_i[i], std::move(ri), ignore_i); + x[i] += localSolver(diag != nullptr ? *diag : Block(0.0), std::move(ri), + ignore != nullptr ? (*ignore)[i] : Ignore(0)); }; if (direction != Direction::BACKWARD)