#ifndef DUNE_MATRIX_VECTOR_TRIANGULARSOLVE_HH
#define DUNE_MATRIX_VECTOR_TRIANGULARSOLVE_HH

#include <cassert>
#include <cstddef>
#include <utility>

namespace Dune {
namespace MatrixVector {

  /**
   * \brief Solves L*x=b where L is a lower triangular matrix.
   * \param L The system matrix. Every row (or, if transpose is set, every
   *        stored row, i.e. column of L) must contain its diagonal entry;
   *        this and the triangular structure are checked by assertions only.
   * \param b Right-hand side. Taken by value because it is modified during
   *        the substitution.
   * \param x Solution vector. Make sure the ignored nodes are set
   *        correctly because their respective values for x as passed to
   *        this method may affect other lines.
   * \param ignore Tags the lines of the system where x shall not be
   *        computed; may be nullptr. Note that ignored nodes might still
   *        influence the solution of other lines if not set to zero.
   * \param transpose Set true if L is passed as an upper triangular matrix,
   *        i.e. its transpose is stored.
   */
  // Note: The current asserts make sure we actually deal with triangular
  //       matrices. It might be favorable to also support matrices whose
  //       other entries are not required to be non-existent, e.g. if one
  //       has an in-place decomposition and does not want to reimplement
  //       the iterators.
  template <class Matrix, class Vector, class BitVector>
  static void lowerTriangularSolve(Matrix const& L, Vector b, Vector& x,
                                   BitVector const* ignore,
                                   bool transpose = false) {
    if (transpose) {
      // Column-oriented forward substitution: stored row i is column i of L,
      // with the diagonal as its first entry.
      for (auto it = L.begin(); it != L.end(); ++it) {
        const size_t i = it.index();
        auto cIt = it->begin();
        // Reject empty rows before cIt (== end) is dereferenced below.
        assert(cIt != it->end());
        assert(cIt.index() == i);
        if (ignore == nullptr or (*ignore)[cIt.index()].none())
          x[i] = b[i] / *cIt;
        ++cIt;
        for (; cIt != it->end(); ++cIt) {
          const size_t j = cIt.index();
          assert(j > i);
          // Note: We could drop the check for ignored nodes here because b[j]
          // will be ignored anyway due to the check above.
          if (ignore == nullptr or (*ignore)[j].none())
            b[j] -= x[i] * *cIt;
        }
      }
    } else {
      // Row-oriented forward substitution: entries left of the diagonal are
      // eliminated first, the diagonal (last entry of the row) yields x[i].
      for (auto it = L.begin(); it != L.end(); ++it) {
        const size_t i = it.index();
        if (ignore != nullptr and (*ignore)[i].all())
          continue;
        bool hasDiagonal = false;
        for (auto cIt = it->begin(); cIt != it->end(); ++cIt) {
          const size_t j = cIt.index();
          assert(j <= i);
          if (j < i) {
            b[i] -= *cIt * x[j];
            continue;
          }
          assert(j == i);
          x[i] = b[i] / *cIt;
          hasDiagonal = true;
        }
        // Without a diagonal entry x[i] would silently keep its old value.
        assert(hasDiagonal);
        (void)hasDiagonal; // only read by the assert; silences NDEBUG warnings
      }
    }
  }

  /**
   * \brief Convenience variant of lowerTriangularSolve that returns the
   *        solution instead of writing into a caller-provided vector.
   *
   * The returned vector starts out as zero, so ignored nodes keep the
   * value 0.
   */
  template <class Matrix, class Vector, class BitVector>
  static Vector lowerTriangularSolve(Matrix const& L, Vector b,
                                     BitVector const* ignore,
                                     bool transpose = false) {
    // Copy b only to get a vector of matching structure, then clear it.
    Vector solution = b;
    solution = 0;
    lowerTriangularSolve(L, std::move(b), solution, ignore, transpose);
    return solution;
  }

  /**
   * \brief Solves U*x=b where U is an upper triangular matrix.
   *
   * Counterpart of lowerTriangularSolve using backward substitution: the
   * rows are processed from last to first.
   * \param U The system matrix. Every row (or, if transpose is set, every
   *        stored row, i.e. column of U) must contain its diagonal entry;
   *        this and the triangular structure are checked by assertions only.
   * \param b Right-hand side. Taken by value because it is modified during
   *        the substitution.
   * \param x Solution vector. Make sure the ignored nodes are set
   *        correctly because their respective values for x as passed to
   *        this method may affect other lines.
   * \param ignore Tags the lines of the system where x shall not be
   *        computed; may be nullptr. Note that ignored nodes might still
   *        influence the solution of other lines if not set to zero.
   * \param transpose Set true if U is passed as a lower triangular matrix,
   *        i.e. its transpose is stored.
   */
  template <class Matrix, class Vector, class BitVector>
  static void upperTriangularSolve(Matrix const& U, Vector b, Vector& x,
                                   BitVector const* ignore,
                                   bool transpose = false) {
    if (transpose) {
      // Column-oriented backward substitution: stored row i is column i of
      // U, with the diagonal as its last entry.
      for (auto it = U.beforeEnd(); it != U.beforeBegin(); --it) {
        const size_t i = it.index();
        auto cIt = it->beforeEnd();
        // Reject empty rows before cIt (== beforeBegin) is dereferenced.
        assert(cIt != it->beforeBegin());
        assert(cIt.index() == i);
        if (ignore == nullptr or (*ignore)[cIt.index()].none())
          x[i] = b[i] / *cIt;
        --cIt;
        for (; cIt != it->beforeBegin(); --cIt) {
          const size_t j = cIt.index();
          assert(j < i);
          // Note: We could drop the check for ignored nodes here because b[j]
          // will be ignored anyway due to the check above.
          if (ignore == nullptr or (*ignore)[j].none())
            b[j] -= *cIt * x[i];
        }
      }
    } else {
      // Row-oriented backward substitution: entries right of the diagonal
      // are eliminated first, the diagonal (first entry of the row) yields x[i].
      for (auto it = U.beforeEnd(); it != U.beforeBegin(); --it) {
        const size_t i = it.index();
        if (ignore != nullptr and (*ignore)[i].all())
          continue;
        // An empty row would make the loop below step past it->begin().
        assert(it->begin() != it->end());
        auto cIt = it->beforeEnd();
        for (; cIt != it->begin(); --cIt) {
          const size_t j = cIt.index();
          assert(j > i);
          b[i] -= *cIt * x[j];
        }
        assert(cIt.index() == i);
        x[i] = b[i] / *cIt;
      }
    }
  }

  /**
   * \brief Convenience variant of upperTriangularSolve that returns the
   *        solution instead of writing into a caller-provided vector.
   *
   * The returned vector starts out as zero, so ignored nodes keep the
   * value 0.
   */
  template <class Matrix, class Vector, class BitVector>
  static Vector upperTriangularSolve(Matrix const& U, Vector b,
                                     BitVector const* ignore,
                                     bool transpose = false) {
    // Copy b only to get a vector of matching structure, then clear it.
    Vector solution = b;
    solution = 0;
    upperTriangularSolve(U, std::move(b), solution, ignore, transpose);
    return solution;
  }

} // end namespace MatrixVector
} // end namespace Dune

#endif // DUNE_MATRIX_VECTOR_TRIANGULARSOLVE_HH