From 0e1afee43e74c505335e068373f41dae922c40f7 Mon Sep 17 00:00:00 2001
From: Elias Pipping <elias.pipping@fu-berlin.de>
Date: Sat, 16 Jul 2016 20:56:11 +0200
Subject: [PATCH] Add (lower|upper)TriangularSolve

---
 dune/solvers/common/staticmatrixtools.hh | 84 +++++++++++++++++++++++-
 1 file changed, 83 insertions(+), 1 deletion(-)

diff --git a/dune/solvers/common/staticmatrixtools.hh b/dune/solvers/common/staticmatrixtools.hh
index 70bd60ca..f95c5791 100644
--- a/dune/solvers/common/staticmatrixtools.hh
+++ b/dune/solvers/common/staticmatrixtools.hh
@@ -3,6 +3,8 @@
 #ifndef STATIC_MATRIX_TOOL_HH
 #define STATIC_MATRIX_TOOL_HH
 
+#include <cassert>
+
 #include "dune/common/diagonalmatrix.hh"
 #include "dune/common/fmatrix.hh"
 #include "dune/istl/scaledidmatrix.hh"
@@ -458,7 +460,87 @@ class StaticMatrix
           }
         }
 
+        template <class Matrix, class Vector, class BitVector>
+        static void lowerTriangularSolve(Matrix const& A, Vector const& b,
+                                         Vector& x, BitVector const* ignore,
+                                         bool transpose = false) {
+          static_assert(
+              Matrix::block_type::rows == 1 and Matrix::block_type::cols == 1,
+              "Only implemented for scalar problems");
+          x = 0;
+          Vector r = b;
+          if (transpose) {
+            for (auto it = A.begin(); it != A.end(); ++it) {
+              const size_t i = it.index();
+              if (ignore != nullptr and (*ignore)[i].all())
+                continue;
+              auto cIt = it->begin();
+              assert(cIt.index() == it.index());
+              x[i] = r[i] / *cIt;
+              for (; cIt != it->end(); ++cIt) {
+                const size_t j = cIt.index();
+                r[j] -= x[i] * *cIt;
+              }
+            }
+          } else {
+            for (auto it = A.begin(); it != A.end(); ++it) {
+              const size_t i = it.index();
+              if (ignore != nullptr and (*ignore)[i].all())
+                continue;
+              for (auto cIt = it->begin(); cIt != it->end(); ++cIt) {
+                const size_t j = cIt.index();
+                if (i == j) {
+                  x[i] = r[i] / *cIt;
+                  break;
+                }
+                assert(j < i);
+                r[i] -= *cIt * x[j];
+              }
+            }
+          }
+        }
+
+        template <class Matrix, class Vector, class BitVector>
+        static void upperTriangularSolve(Matrix const& U, Vector const& b,
+                                         Vector& x, BitVector const* ignore,
+                                         bool transpose = false) {
+          static_assert(
+              Matrix::block_type::rows == 1 and Matrix::block_type::cols == 1,
+              "Only implemented for scalar problems");
+          x = 0;
+          Vector r = b;
+          if (transpose) {
+            for (auto it = U.beforeEnd(); it != U.beforeBegin(); --it) {
+              const size_t i = it.index();
+              if (ignore != nullptr and (*ignore)[i].all())
+                continue;
+              auto cIt = it->beforeEnd();
+              assert(cIt.index() == i);
+              x[i] = r[i] / *cIt;
+              cIt--;
+              for (; cIt != it->beforeBegin(); --cIt) {
+                const size_t j = cIt.index();
+                assert(j < i);
+                r[j] -= *cIt * x[i];
+              }
+            }
+          } else {
+            for (auto it = U.beforeEnd(); it != U.beforeBegin(); --it) {
+              const size_t i = it.index();
+              if (ignore != nullptr and (*ignore)[i].all())
+                continue;
+              auto cIt = it->begin();
+              assert(cIt.index() == i);
+              auto diagonal = *cIt;
+              cIt++;
+              for (; cIt != it->end(); ++cIt) {
+                const size_t j = cIt.index();
+                r[i] -= *cIt * x[j];
+              }
+              x[i] = r[i] / diagonal;
+            }
+          }
+        }
 };
 
 #endif
-
-- 
GitLab