Skip to content
Snippets Groups Projects
Commit 4187810f authored by Elias Pipping's avatar Elias Pipping
Browse files

Add lower/upper triangular solve

parent 2e362f56
No related branches found
No related tags found
No related merge requests found
...@@ -16,4 +16,5 @@ install(FILES ...@@ -16,4 +16,5 @@ install(FILES
singlenonzerorowmatrix.hh singlenonzerorowmatrix.hh
tranpose.hh tranpose.hh
transformmatrix.hh transformmatrix.hh
triangularsolve.hh
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/dune/matrix-vector) DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/dune/matrix-vector)
\ No newline at end of file
#ifndef DUNE_MATRIX_VECTOR_TRIANGULARSOLVE_HH
#define DUNE_MATRIX_VECTOR_TRIANGULARSOLVE_HH
namespace Dune {
namespace MatrixVector {
// From http://stackoverflow.com/a/28139075/179927
// Note: free rbegin()/rend() still missing gcc 4.9.x version of libstdc++
namespace {
template <typename T>
struct reversion_wrapper {
T& iterable;
};
template <typename T>
auto begin(reversion_wrapper<T> w) {
return w.iterable.rbegin();
}
template <typename T>
auto end(reversion_wrapper<T> w) {
return w.iterable.rend();
}
template <typename T>
reversion_wrapper<T> reverse(T&& iterable) {
return {iterable};
}
}
template <class Matrix, class Vector, class BitVector>
static void lowerTriangularSolve(Matrix const& A, Vector const& b, Vector& x,
BitVector const* ignore,
bool transpose = false) {
x = 0;
Vector r = b;
if (transpose) {
// TODO: not yet handled
} else {
for (auto it = A.begin(); it != A.end(); ++it) {
size_t i = it.index();
if (ignore != nullptr and (*ignore)[i].all())
continue;
for (auto cIt = it->begin(); cIt != it->end(); ++cIt) {
size_t j = cIt.index();
if (i == j) {
x[i] = r[i] / *cIt;
break;
}
assert(j < i);
r[i] -= *cIt * x[j];
}
}
}
}
template <class Matrix, class Vector, class BitVector>
static void upperTriangularSolve(Matrix const& U, Vector const& b, Vector& x,
BitVector const* ignore,
bool transpose = false) {
x = 0;
Vector r = b;
if (transpose) {
// TODO: not yet handled
} else {
std::vector<typename Matrix::ConstIterator> rows(U.N());
for (auto it = U.begin(); it != U.end(); ++it)
rows[it.index()] = it;
for (auto it : reverse(rows)) {
size_t i = it.index();
if (ignore != nullptr and (*ignore)[i].all())
continue;
auto cIt = it->begin();
assert(cIt.index() == i);
auto diagonal = *cIt;
cIt++;
for (; cIt != it->end(); ++cIt) {
size_t j = cIt.index();
assert(j > i);
r[i] -= *cIt * x[j];
}
x[i] = r[i] / diagonal;
}
}
}
}
}
#endif
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment