diff --git a/dune/matrix-vector/CMakeLists.txt b/dune/matrix-vector/CMakeLists.txt index 07a1c2cbffa0a361f3f64fcac212a81a3a0d2b2e..9acfb4a5f0afce3f98809d63523037308223a809 100644 --- a/dune/matrix-vector/CMakeLists.txt +++ b/dune/matrix-vector/CMakeLists.txt @@ -16,4 +16,5 @@ install(FILES singlenonzerorowmatrix.hh tranpose.hh transformmatrix.hh + triangularsolve.hh DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/dune/matrix-vector) \ No newline at end of file diff --git a/dune/matrix-vector/triangularsolve.hh b/dune/matrix-vector/triangularsolve.hh new file mode 100644 index 0000000000000000000000000000000000000000..46012ae249e8408757327c4d0763167273917181 --- /dev/null +++ b/dune/matrix-vector/triangularsolve.hh @@ -0,0 +1,89 @@ +#ifndef DUNE_MATRIX_VECTOR_TRIANGULARSOLVE_HH +#define DUNE_MATRIX_VECTOR_TRIANGULARSOLVE_HH + +namespace Dune { +namespace MatrixVector { + // From http://stackoverflow.com/a/28139075/179927 + // Note: free rbegin()/rend() still missing gcc 4.9.x version of libstdc++ + namespace { + template <typename T> + struct reversion_wrapper { + T& iterable; + }; + + template <typename T> + auto begin(reversion_wrapper<T> w) { + return w.iterable.rbegin(); + } + + template <typename T> + auto end(reversion_wrapper<T> w) { + return w.iterable.rend(); + } + + template <typename T> + reversion_wrapper<T> reverse(T&& iterable) { + return {iterable}; + } + } + + template <class Matrix, class Vector, class BitVector> + static void lowerTriangularSolve(Matrix const& A, Vector const& b, Vector& x, + BitVector const* ignore, + bool transpose = false) { + x = 0; + Vector r = b; + if (transpose) { + // TODO: not yet handled + } else { + for (auto it = A.begin(); it != A.end(); ++it) { + size_t i = it.index(); + if (ignore != nullptr and (*ignore)[i].all()) + continue; + for (auto cIt = it->begin(); cIt != it->end(); ++cIt) { + size_t j = cIt.index(); + if (i == j) { + x[i] = r[i] / *cIt; + break; + } + assert(j < i); + r[i] -= *cIt * x[j]; + } + } + } + } + + template <class Matrix, class Vector, class BitVector> + static void upperTriangularSolve(Matrix const& U, Vector const& b, Vector& x, + BitVector const* ignore, + bool transpose = false) { + x = 0; + Vector r = b; + if (transpose) { + // TODO: not yet handled + } else { + std::vector<typename Matrix::ConstIterator> rows(U.N()); + for (auto it = U.begin(); it != U.end(); ++it) + rows[it.index()] = it; + + for (auto it : reverse(rows)) { + size_t i = it.index(); + if (ignore != nullptr and (*ignore)[i].all()) + continue; + auto cIt = it->begin(); + assert(cIt.index() == i); + auto diagonal = *cIt; + cIt++; + for (; cIt != it->end(); ++cIt) { + size_t j = cIt.index(); + assert(j > i); + r[i] -= *cIt * x[j]; + } + x[i] = r[i] / diagonal; + } + } + } +} +} + +#endif