Skip to content
Snippets Groups Projects
Commit 4ff9dbfb authored by Max Kahnt's avatar Max Kahnt
Browse files

Fix: Cope with omitted diagonal blocks in outer GS loop.

parent fa890457
No related branches found
No related tags found
No related merge requests found
Pipeline #
This commit is part of merge request !11. Comments created here will be created in the context of that merge request.
...@@ -40,16 +40,24 @@ void linearStep(const M& m, V& x, const V& b, const BitVector* ignore, ...@@ -40,16 +40,24 @@ void linearStep(const M& m, V& x, const V& b, const BitVector* ignore,
// Note: move-capture requires C++14 // Note: move-capture requires C++14
auto blockStep = [&, localSolver = std::move(localSolver) ](size_t i) { auto blockStep = [&, localSolver = std::move(localSolver) ](size_t i) {
const auto& row_i = m[i]; const auto& row_i = m[i];
// Compute residual
auto ri = b[i];
for (auto cIt = row_i.begin(); cIt != row_i.end(); ++cIt)
cIt->mmv(x[cIt.index()], ri);
std::bitset<V::block_type::dimension> ignore_i(0); std::bitset<V::block_type::dimension> ignore_i(0);
if (ignore != nullptr) if (ignore != nullptr)
ignore_i = (*ignore)[i]; ignore_i = (*ignore)[i];
// Compute residual
auto ri = b[i];
using Block = typename M::block_type;
const Block* diag = nullptr;
for (auto cIt = row_i.begin(); cIt != row_i.end(); ++cIt) {
size_t j = cIt.index();
cIt->mmv(x[j], ri);
if (j == i)
diag = &*cIt;
}
// Update iterate with correction // Update iterate with correction
x[i] += localSolver(row_i[i], std::move(ri), ignore_i); x[i] += localSolver(diag ? *diag : Block(0.0), std::move(ri), ignore_i);
}; };
if (direction != Direction::BACKWARD) if (direction != Direction::BACKWARD)
...@@ -172,9 +180,10 @@ namespace LocalSolverFromLinearSolver { ...@@ -172,9 +180,10 @@ namespace LocalSolverFromLinearSolver {
*/ */
template <class LinearSolver> template <class LinearSolver>
auto truncateSymmetrically(LinearSolver&& linearSolver) { auto truncateSymmetrically(LinearSolver&& linearSolver) {
return [linearSolver = std::move(linearSolver) ]( return [linearSolver = std::move(linearSolver)](const auto& m, const auto& b,
const auto& m, const auto& b, const auto& ignore) { const auto& ignore) {
using Return = typename std::result_of<LinearSolver(decltype(m), decltype(b))>::type; using Return =
typename std::result_of<LinearSolver(decltype(m), decltype(b))>::type;
if (ignore.all()) if (ignore.all())
return Return(0); return Return(0);
...@@ -270,7 +279,8 @@ auto cg(size_t maxIter = LinearSolvers::defaultCgMaxIter, ...@@ -270,7 +279,8 @@ auto cg(size_t maxIter = LinearSolvers::defaultCgMaxIter,
auto cgSolver = std::bind( auto cgSolver = std::bind(
LinearSolvers::cg<std::decay_t<decltype(m)>, std::decay_t<decltype(b)>>, LinearSolvers::cg<std::decay_t<decltype(m)>, std::decay_t<decltype(b)>>,
_1, _2, maxIter, tol); _1, _2, maxIter, tol);
auto cgTruncatedSolver = LocalSolverFromLinearSolver::truncateSymmetrically(cgSolver); auto cgTruncatedSolver =
LocalSolverFromLinearSolver::truncateSymmetrically(cgSolver);
auto cgTruncatedRegularizedSolver = auto cgTruncatedRegularizedSolver =
LocalSolverRegularizer::diagRegularize(r, cgTruncatedSolver); LocalSolverRegularizer::diagRegularize(r, cgTruncatedSolver);
return cgTruncatedRegularizedSolver(m, b, ignore); return cgTruncatedRegularizedSolver(m, b, ignore);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment