Skip to content
Snippets Groups Projects
Commit 0e1afee4 authored by Elias Pipping's avatar Elias Pipping
Browse files

Add (lower|upper)TriangularSolve

parent 3dee0543
No related branches found
No related tags found
No related merge requests found
Pipeline #
......@@ -3,6 +3,8 @@
#ifndef STATIC_MATRIX_TOOL_HH
#define STATIC_MATRIX_TOOL_HH
#include <cassert>
#include "dune/common/diagonalmatrix.hh"
#include "dune/common/fmatrix.hh"
#include "dune/istl/scaledidmatrix.hh"
......@@ -458,7 +460,87 @@ class StaticMatrix
}
}
template <class Matrix, class Vector, class BitVector>
static void lowerTriangularSolve(Matrix const& A, Vector const& b,
Vector& x, BitVector const* ignore,
bool transpose = false) {
static_assert(
Matrix::block_type::rows == 1 and Matrix::block_type::cols == 1,
"Only implemented for scalar problems");
x = 0;
Vector r = b;
if (transpose) {
for (auto it = A.begin(); it != A.end(); ++it) {
const size_t i = it.index();
if (ignore != nullptr and (*ignore)[i].all())
continue;
auto cIt = it->begin();
assert(cIt.index() == it.index());
x[i] = r[i] / *cIt;
for (; cIt != it->end(); ++cIt) {
const size_t j = cIt.index();
r[j] -= x[i] * *cIt;
}
}
} else {
for (auto it = A.begin(); it != A.end(); ++it) {
const size_t i = it.index();
if (ignore != nullptr and (*ignore)[i].all())
continue;
for (auto cIt = it->begin(); cIt != it->end(); ++cIt) {
const size_t j = cIt.index();
if (i == j) {
x[i] = r[i] / *cIt;
break;
}
assert(j < i);
r[i] -= *cIt * x[j];
}
}
}
}
template <class Matrix, class Vector, class BitVector>
static void upperTriangularSolve(Matrix const& U, Vector const& b,
Vector& x, BitVector const* ignore,
bool transpose = false) {
static_assert(
Matrix::block_type::rows == 1 and Matrix::block_type::cols == 1,
"Only implemented for scalar problems");
x = 0;
Vector r = b;
if (transpose) {
for (auto it = U.beforeEnd(); it != U.beforeBegin(); --it) {
const size_t i = it.index();
if (ignore != nullptr and (*ignore)[i].all())
continue;
auto cIt = it->beforeEnd();
assert(cIt.index() == i);
x[i] = r[i] / *cIt;
cIt--;
for (; cIt != it->beforeBegin(); --cIt) {
const size_t j = cIt.index();
assert(j < i);
r[j] -= *cIt * x[i];
}
}
} else {
for (auto it = U.beforeEnd(); it != U.beforeBegin(); --it) {
const size_t i = it.index();
if (ignore != nullptr and (*ignore)[i].all())
continue;
auto cIt = it->begin();
assert(cIt.index() == i);
auto diagonal = *cIt;
cIt++;
for (; cIt != it->end(); ++cIt) {
const size_t j = cIt.index();
r[i] -= *cIt * x[j];
}
x[i] = r[i] / diagonal;
}
}
}
};
#endif
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment