diff --git a/dune/solvers/iterationsteps/blockgssteps.hh b/dune/solvers/iterationsteps/blockgssteps.hh index 01d17c64456b2362cf1e14a9b6e470bc3adde5d0..bd68b4cbf33ae0f78a44f7e5b7d2e1e408a0c5d7 100644 --- a/dune/solvers/iterationsteps/blockgssteps.hh +++ b/dune/solvers/iterationsteps/blockgssteps.hh @@ -40,16 +40,24 @@ void linearStep(const M& m, V& x, const V& b, const BitVector* ignore, // Note: move-capture requires C++14 auto blockStep = [&, localSolver = std::move(localSolver) ](size_t i) { const auto& row_i = m[i]; + // Compute residual auto ri = b[i]; - for (auto cIt = row_i.begin(); cIt != row_i.end(); ++cIt) - cIt->mmv(x[cIt.index()], ri); + using Block = typename M::block_type; + const Block* diag = nullptr; + for (auto cIt = row_i.begin(); cIt != row_i.end(); ++cIt) { + size_t j = cIt.index(); + cIt->mmv(x[j], ri); + if (j == i) + diag = &*cIt; + } std::bitset<V::block_type::dimension> ignore_i(0); if (ignore != nullptr) ignore_i = (*ignore)[i]; + // Update iterate with correction - x[i] += localSolver(row_i[i], std::move(ri), ignore_i); + x[i] += localSolver(diag ? *diag : Block(0.0), std::move(ri), ignore_i); }; if (direction != Direction::BACKWARD)