diff --git a/dune/solvers/iterationsteps/blockgssteps.hh b/dune/solvers/iterationsteps/blockgssteps.hh index 01d17c64456b2362cf1e14a9b6e470bc3adde5d0..7ef8458553de6baa787a2eaea134fbbade336dc1 100644 --- a/dune/solvers/iterationsteps/blockgssteps.hh +++ b/dune/solvers/iterationsteps/blockgssteps.hh @@ -40,16 +40,24 @@ void linearStep(const M& m, V& x, const V& b, const BitVector* ignore, // Note: move-capture requires C++14 auto blockStep = [&, localSolver = std::move(localSolver) ](size_t i) { const auto& row_i = m[i]; - // Compute residual - auto ri = b[i]; - for (auto cIt = row_i.begin(); cIt != row_i.end(); ++cIt) - cIt->mmv(x[cIt.index()], ri); std::bitset<V::block_type::dimension> ignore_i(0); if (ignore != nullptr) ignore_i = (*ignore)[i]; + + // Compute residual + auto ri = b[i]; + using Block = typename M::block_type; + const Block* diag = nullptr; + for (auto cIt = row_i.begin(); cIt != row_i.end(); ++cIt) { + size_t j = cIt.index(); + cIt->mmv(x[j], ri); + if (j == i) + diag = &*cIt; + } + // Update iterate with correction - x[i] += localSolver(row_i[i], std::move(ri), ignore_i); + x[i] += localSolver(diag ? *diag : Block(0.0), std::move(ri), ignore_i); }; if (direction != Direction::BACKWARD) @@ -172,9 +180,10 @@ namespace LocalSolverFromLinearSolver { */ template <class LinearSolver> auto truncateSymmetrically(LinearSolver&& linearSolver) { - return [linearSolver = std::move(linearSolver) ]( - const auto& m, const auto& b, const auto& ignore) { - using Return = typename std::result_of<LinearSolver(decltype(m), decltype(b))>::type; + return [linearSolver = std::move(linearSolver)](const auto& m, const auto& b, + const auto& ignore) { + using Return = + typename std::result_of<LinearSolver(decltype(m), decltype(b))>::type; if (ignore.all()) return Return(0); @@ -270,7 +279,8 @@ auto cg(size_t maxIter = LinearSolvers::defaultCgMaxIter, auto cgSolver = std::bind( LinearSolvers::cg<std::decay_t<decltype(m)>, std::decay_t<decltype(b)>>, _1, _2, maxIter, tol); - auto cgTruncatedSolver = LocalSolverFromLinearSolver::truncateSymmetrically(cgSolver); + auto cgTruncatedSolver = + LocalSolverFromLinearSolver::truncateSymmetrically(cgSolver); auto cgTruncatedRegularizedSolver = LocalSolverRegularizer::diagRegularize(r, cgTruncatedSolver); return cgTruncatedRegularizedSolver(m, b, ignore);