Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
triangularsolve.hh 2.37 KiB
#ifndef DUNE_MATRIX_VECTOR_TRIANGULARSOLVE_HH
#define DUNE_MATRIX_VECTOR_TRIANGULARSOLVE_HH

#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

namespace Dune {
namespace MatrixVector {
  // Reverse-iteration adaptor for range-based for loops, adapted from
  // http://stackoverflow.com/a/28139075/179927
  // Note: free rbegin()/rend() still missing gcc 4.9.x version of libstdc++
  //
  // NOTE(review): an unnamed namespace in a header gives every translation
  // unit its own copy of these helpers. It is kept so that unqualified calls
  // to reverse() inside this namespace keep resolving.
  namespace {
    /// Holds the range to traverse backwards.
    /// For an lvalue argument of reverse(), T is an lvalue reference and the
    /// range is only referenced. For an rvalue argument, T is a value type
    /// and the temporary is moved into the wrapper, so it lives as long as
    /// the wrapper does (e.g. for the whole range-based for loop) instead of
    /// dangling.
    template <typename T>
    struct reversion_wrapper {
      T iterable;
    };

    /// Reverse begin: forwards to the wrapped range's rbegin().
    template <typename T>
    auto begin(reversion_wrapper<T>& w) {
      return w.iterable.rbegin();
    }

    template <typename T>
    auto begin(reversion_wrapper<T> const& w) {
      return w.iterable.rbegin();
    }

    /// Reverse end: forwards to the wrapped range's rend().
    template <typename T>
    auto end(reversion_wrapper<T>& w) {
      return w.iterable.rend();
    }

    template <typename T>
    auto end(reversion_wrapper<T> const& w) {
      return w.iterable.rend();
    }

    /// Returns an adaptor that makes a range-based for loop visit
    /// \p iterable from back to front. The range must provide rbegin()/rend().
    template <typename T>
    reversion_wrapper<T> reverse(T&& iterable) {
      return {std::forward<T>(iterable)};
    }
  }

  /**
   * \brief Solves A x = b, or A^T x = b if \p transpose is set, for a sparse
   *        lower triangular matrix A stored row-wise.
   *
   * \param A         lower triangular matrix. Within each row the entries are
   *                  expected in ascending column order, with strictly lower
   *                  entries before the diagonal entry, which must be stored.
   * \param b         right-hand side (not modified).
   * \param x         solution. It is reset to 0 first, so entries of ignored
   *                  rows end up as 0.
   * \param ignore    optional (may be nullptr). Row i is skipped when
   *                  (*ignore)[i].all() holds.
   * \param transpose if true, solve A^T x = b instead of A x = b.
   *
   * Matrix entries are treated like scalars (the diagonal entry is used as a
   * divisor), so the transposed solve does not transpose individual entries.
   * A non-ignored row without a stored diagonal entry triggers an assertion.
   * With NDEBUG, that row's x[i] is left at 0.
   */
  template <class Matrix, class Vector, class BitVector>
  static void lowerTriangularSolve(Matrix const& A, Vector const& b, Vector& x,
                                   BitVector const* ignore,
                                   bool transpose = false) {
    x = 0;
    Vector r = b;  // residual; unknowns are eliminated from it as they are found

    auto const isIgnored = [ignore](std::size_t i) {
      return ignore != nullptr and (*ignore)[i].all();
    };

    if (not transpose) {
      // Forward substitution, row by row:
      //   x_i = (b_i - sum_{j<i} A_ij x_j) / A_ii
      for (auto rowIt = A.begin(); rowIt != A.end(); ++rowIt) {
        std::size_t const i = rowIt.index();
        if (isIgnored(i))
          continue;
        auto cIt = rowIt->begin();
        auto const cEnd = rowIt->end();
        for (; cIt != cEnd and cIt.index() < i; ++cIt)
          r[i] -= *cIt * x[cIt.index()];
        bool const hasDiagonal = (cIt != cEnd) and (cIt.index() == i);
        assert(hasDiagonal);
        if (hasDiagonal)
          x[i] = r[i] / *cIt;
      }
      return;
    }

    // Transposed solve: A^T is upper triangular, so substitute backwards.
    // Row i of A is column i of A^T. Once x_i is known, it is eliminated from
    // the residuals of all rows j < i that it couples to (column-oriented
    // update). Row iterators are collected first so they can be visited in
    // reverse order.
    std::vector<decltype(A.begin())> rowIts;
    for (auto rowIt = A.begin(); rowIt != A.end(); ++rowIt)
      rowIts.push_back(rowIt);

    for (auto rIt = rowIts.rbegin(); rIt != rowIts.rend(); ++rIt) {
      auto const& rowIt = *rIt;
      std::size_t const i = rowIt.index();
      if (isIgnored(i))
        continue;
      auto diagIt = rowIt->begin();
      auto const cEnd = rowIt->end();
      while (diagIt != cEnd and diagIt.index() < i)
        ++diagIt;
      bool const hasDiagonal = (diagIt != cEnd) and (diagIt.index() == i);
      assert(hasDiagonal);
      if (not hasDiagonal)
        continue;
      x[i] = r[i] / *diagIt;
      for (auto cIt = rowIt->begin(); cIt != diagIt; ++cIt)
        r[cIt.index()] -= *cIt * x[i];
    }
  }

  /**
   * \brief Solves U x = b, or U^T x = b if \p transpose is set, for a sparse
   *        upper triangular matrix U stored row-wise.
   *
   * \param U         upper triangular matrix. Within each row the entries are
   *                  expected in ascending column order, so the first stored
   *                  entry of row i must be the diagonal entry U_ii.
   * \param b         right-hand side (not modified).
   * \param x         solution. It is reset to 0 first, so entries of ignored
   *                  rows end up as 0.
   * \param ignore    optional (may be nullptr). Row i is skipped when
   *                  (*ignore)[i].all() holds.
   * \param transpose if true, solve U^T x = b instead of U x = b.
   *
   * Matrix entries are treated like scalars (the diagonal entry is used as a
   * divisor), so the transposed solve does not transpose individual entries.
   * A non-ignored row whose first entry is not the diagonal (or that is empty)
   * triggers an assertion. With NDEBUG, that row's x[i] is left at 0.
   */
  template <class Matrix, class Vector, class BitVector>
  static void upperTriangularSolve(Matrix const& U, Vector const& b, Vector& x,
                                   BitVector const* ignore,
                                   bool transpose = false) {
    x = 0;
    Vector r = b;  // residual; unknowns are eliminated from it as they are found

    auto const isIgnored = [ignore](std::size_t i) {
      return ignore != nullptr and (*ignore)[i].all();
    };

    if (transpose) {
      // U^T is lower triangular, so substitute forwards. Row i of U is
      // column i of U^T. Once x_i is known, it is eliminated from the
      // residuals of all rows j > i that it couples to (column-oriented
      // update).
      for (auto rowIt = U.begin(); rowIt != U.end(); ++rowIt) {
        std::size_t const i = rowIt.index();
        if (isIgnored(i))
          continue;
        auto cIt = rowIt->begin();
        auto const cEnd = rowIt->end();
        bool const hasDiagonal = (cIt != cEnd) and (cIt.index() == i);
        assert(hasDiagonal);
        if (not hasDiagonal)
          continue;
        x[i] = r[i] / *cIt;
        for (++cIt; cIt != cEnd; ++cIt) {
          assert(cIt.index() > i);
          r[cIt.index()] -= *cIt * x[i];
        }
      }
      return;
    }

    // Backward substitution, visiting rows from last to first:
    //   x_i = (b_i - sum_{j>i} U_ij x_j) / U_ii
    // Row iterators are collected first so they can be traversed in reverse.
    std::vector<decltype(U.begin())> rowIts;
    for (auto rowIt = U.begin(); rowIt != U.end(); ++rowIt)
      rowIts.push_back(rowIt);

    for (auto rIt = rowIts.rbegin(); rIt != rowIts.rend(); ++rIt) {
      auto const& rowIt = *rIt;
      std::size_t const i = rowIt.index();
      if (isIgnored(i))
        continue;
      auto cIt = rowIt->begin();
      auto const cEnd = rowIt->end();
      bool const hasDiagonal = (cIt != cEnd) and (cIt.index() == i);
      assert(hasDiagonal);
      if (not hasDiagonal)
        continue;
      auto const diagonal = *cIt;
      for (++cIt; cIt != cEnd; ++cIt) {
        assert(cIt.index() > i);
        r[i] -= *cIt * x[cIt.index()];
      }
      x[i] = r[i] / diagonal;
    }
  }
}
}

#endif