From e232946be35b83775fd57bb4e3c207d4b35982e8 Mon Sep 17 00:00:00 2001
From: Elias Pipping <elias.pipping@fu-berlin.de>
Date: Tue, 13 Mar 2012 17:30:42 +0100
Subject: [PATCH] New: Dirichlet conditions; Clean up Neumann

---
 src/one-body-sample.cc         | 17 +++++++++++------
 src/one-body-sample.py         | 30 ++++++++++++++++++++++++++++++
 src/one-body-sample_neumann.py | 13 -------------
 3 files changed, 41 insertions(+), 19 deletions(-)
 create mode 100644 src/one-body-sample.py
 delete mode 100644 src/one-body-sample_neumann.py

diff --git a/src/one-body-sample.cc b/src/one-body-sample.cc
index 621639a9..998d5c1c 100644
--- a/src/one-body-sample.cc
+++ b/src/one-body-sample.cc
@@ -109,9 +109,8 @@ int main(int argc, char *argv[]) {
       Python::run("import sys");
       Python::run("sys.path.append('" srcdir "')");
 
-      Python::import("one-body-sample_neumann")
-          .get("Functions")
-          .toC<FunctionMap::Base>(functions);
+      Python::import("one-body-sample").get("Functions").toC<FunctionMap::Base>(
+          functions);
     }
 
     Dune::ParameterTree parset;
@@ -232,8 +231,14 @@ int main(int argc, char *argv[]) {
       if (parset.get<bool>("solver.tnnmg.use")) {
         assemble_neumann<GridType, GridView, SmallVector, P1Basis>(
             leafView, p1Basis, neumannNodes, b4,
-            functions.get("sampleFunction"), h * run);
+            functions.get("neumannCondition"), h * run);
         stiffnessMatrix.mmv(u4, b4);
+        // Apply Dirichlet condition
+        for (int i = 0; i < finestSize; ++i)
+          if (ignoreNodes[i].count() == dim)
+            functions.get("dirichletCondition")
+                .evaluate(h * run, u4_diff[i][0]);
+
         for (int state_fpi = 0;
              state_fpi < parset.get<int>("solver.tnnmg.fixed_point_iterations");
              ++state_fpi) {
@@ -267,7 +272,7 @@ int main(int argc, char *argv[]) {
 
         if (parset.get<bool>("printEvolution"))
           print_evolution<SingletonVectorType, VectorType>(
-              frictionalNodes, *s4_new, u4, functions.get("sampleFunction"),
+              frictionalNodes, *s4_new, u4, functions.get("neumannCondition"),
               run, h * run, octave_writer);
       }
 
@@ -277,7 +282,7 @@ int main(int argc, char *argv[]) {
       if (parset.get<bool>("benchmarks.fpi.enable")) {
         assemble_neumann<GridType, GridView, SmallVector, P1Basis>(
             leafView, p1Basis, neumannNodes, b5,
-            functions.get("sampleFunction"), h * run);
+            functions.get("neumannCondition"), h * run);
         stiffnessMatrix.mmv(u5, b5);
         for (int state_fpi = 0;
              state_fpi < parset.get<int>("benchmarks.fpi.iterations");
diff --git a/src/one-body-sample.py b/src/one-body-sample.py
new file mode 100644
index 00000000..0640a256
--- /dev/null
+++ b/src/one-body-sample.py
@@ -0,0 +1,30 @@
+class neumannCondition:
+    def __call__(self, x):
+        # return 0
+        fst = 0.3
+        snd = 0.1
+        trd = 0.3
+        if x < 1.0/3:
+            return fst * x
+        elif x < 2.0/3:
+            return snd * (x - 1.0/3) + fst * 1.0/3
+        else:
+            return trd * (x - 2.0/3) + (fst + snd) * 1.0/3
+
+class dirichletCondition:
+    def __call__(self, x):
+        return 0
+        # fst = 0.3e-3
+        # snd = 0.1e-3
+        # trd = 0.3e-3
+        # if x < 1.0/3:
+        #     return fst * x
+        # elif x < 2.0/3:
+        #     return snd * (x - 1.0/3) + fst * 1.0/3
+        # else:
+        #     return trd * (x - 2.0/3) + (fst + snd) * 1.0/3
+
+Functions = {
+    'neumannCondition' : neumannCondition(),
+    'dirichletCondition' : dirichletCondition()
+}
diff --git a/src/one-body-sample_neumann.py b/src/one-body-sample_neumann.py
deleted file mode 100644
index a553be95..00000000
--- a/src/one-body-sample_neumann.py
+++ /dev/null
@@ -1,13 +0,0 @@
-class sampleFunction:
-    def __call__(self, x):
-        fst = 0.3
-        snd = 0.1
-        trd = 0.3
-        if x < 1.0/3:
-            return fst * x
-        elif x < 2.0/3:
-            return snd * (x - 1.0/3) + fst * 1.0/3
-        else:
-            return trd * (x - 2.0/3) + (fst + snd) * 1.0/3
-
-Functions = {'sampleFunction' : sampleFunction()}
-- 
GitLab