From eaf14ab54d6ae3cd8f138f1ae9ea0b3d1b4f0672 Mon Sep 17 00:00:00 2001
From: Elias Pipping <elias.pipping@fu-berlin.de>
Date: Wed, 21 Jan 2015 14:41:57 +0100
Subject: [PATCH] [Cleanup] Move time stepping logic into parent class

While we're at it, reset postProcessCalled in nextTimeStep()
---
 src/timestepping.cc                | 45 ++++++++++++++
 src/timestepping.hh                | 29 +++++++--
 src/timestepping/backward_euler.cc | 85 +++++++++-----------------
 src/timestepping/backward_euler.hh | 19 +-----
 src/timestepping/newmark.cc        | 97 ++++++++++--------------------
 src/timestepping/newmark.hh        | 21 -------
 src/timestepping_tmpl.cc           |  1 +
 7 files changed, 132 insertions(+), 165 deletions(-)

diff --git a/src/timestepping.cc b/src/timestepping.cc
index 56346038..f5cd258d 100644
--- a/src/timestepping.cc
+++ b/src/timestepping.cc
@@ -4,6 +4,51 @@
 
 #include "timestepping.hh"
 
+template <class Vector, class Matrix, class Function, size_t dim>
+TimeSteppingScheme<Vector, Matrix, Function, dim>::TimeSteppingScheme(
+    Matrices<Matrix> const &_matrices, Vector const &_u_initial,
+    Vector const &_v_initial, Vector const &_a_initial,
+    Dune::BitSetVector<dim> const &_dirichletNodes,
+    Function const &_dirichletFunction)
+    : matrices(_matrices),
+      u(_u_initial),
+      v(_v_initial),
+      a(_a_initial),
+      dirichletNodes(_dirichletNodes),
+      dirichletFunction(_dirichletFunction) {}
+
+template <class Vector, class Matrix, class Function, size_t dim>
+void TimeSteppingScheme<Vector, Matrix, Function, dim>::nextTimeStep() {
+  u_o = u;
+  v_o = v;
+  a_o = a;
+  postProcessCalled = false;
+}
+
+template <class Vector, class Matrix, class Function, size_t dim>
+void TimeSteppingScheme<Vector, Matrix, Function, dim>::extractDisplacement(
+    Vector &displacement) const {
+  if (!postProcessCalled)
+    DUNE_THROW(Dune::Exception, "It seems you forgot to call postProcess!");
+
+  displacement = u;
+}
+
+template <class Vector, class Matrix, class Function, size_t dim>
+void TimeSteppingScheme<Vector, Matrix, Function, dim>::extractVelocity(
+    Vector &velocity) const {
+  if (!postProcessCalled)
+    DUNE_THROW(Dune::Exception, "It seems you forgot to call postProcess!");
+
+  velocity = v;
+}
+
+template <class Vector, class Matrix, class Function, size_t dim>
+void TimeSteppingScheme<Vector, Matrix, Function, dim>::extractOldVelocity(
+    Vector &oldVelocity) const {
+  oldVelocity = v_o;
+}
+
 #include "timestepping/backward_euler.cc"
 #include "timestepping/newmark.cc"
 
diff --git a/src/timestepping.hh b/src/timestepping.hh
index 55470075..cd6fec85 100644
--- a/src/timestepping.hh
+++ b/src/timestepping.hh
@@ -10,18 +10,37 @@
 
 template <class Vector, class Matrix, class Function, size_t dim>
 class TimeSteppingScheme {
+protected:
+  TimeSteppingScheme(Matrices<Matrix> const &_matrices,
+                     Vector const &_u_initial, Vector const &_v_initial,
+                     Vector const &_a_initial,
+                     Dune::BitSetVector<dim> const &_dirichletNodes,
+                     Function const &_dirichletFunction);
+
 public:
-  void virtual nextTimeStep() = 0;
+  void nextTimeStep();
   void virtual setup(Vector const &ell, double _tau, double relativeTime,
                      Vector &rhs, Vector &iterate, Matrix &AB) = 0;
 
   void virtual postProcess(Vector const &iterate) = 0;
-  void virtual extractDisplacement(Vector &displacement) const = 0;
-  void virtual extractVelocity(Vector &velocity) const = 0;
-  void virtual extractOldVelocity(Vector &velocity) const = 0;
+  void extractDisplacement(Vector &displacement) const;
+  void extractVelocity(Vector &velocity) const;
+  void extractOldVelocity(Vector &velocity) const;
 
   std::shared_ptr<TimeSteppingScheme<Vector, Matrix, Function,
                                      dim>> virtual clone() const = 0;
+
+protected:
+  Matrices<Matrix> const &matrices;
+  Vector u, v, a;
+  Dune::BitSetVector<dim> const &dirichletNodes;
+  Function const &dirichletFunction;
+  double dirichletValue;
+
+  Vector u_o, v_o, a_o;
+  double tau;
+
+  bool postProcessCalled = true;
 };
 
 #include "timestepping/newmark.hh"
@@ -42,7 +61,7 @@ initTimeStepper(Config::scheme scheme,
     case Config::BackwardEuler:
       return std::make_shared<
           BackwardEuler<Vector, Matrix, Function, dimension>>(
-          matrices, u_initial, v_initial, velocityDirichletNodes,
+          matrices, u_initial, v_initial, a_initial, velocityDirichletNodes,
           velocityDirichletFunction);
     default:
       assert(false);
diff --git a/src/timestepping/backward_euler.cc b/src/timestepping/backward_euler.cc
index f3e51588..e9d67219 100644
--- a/src/timestepping/backward_euler.cc
+++ b/src/timestepping/backward_euler.cc
@@ -3,27 +3,19 @@
 template <class Vector, class Matrix, class Function, size_t dim>
 BackwardEuler<Vector, Matrix, Function, dim>::BackwardEuler(
     Matrices<Matrix> const &_matrices, Vector const &_u_initial,
-    Vector const &_v_initial, Dune::BitSetVector<dim> const &_dirichletNodes,
+    Vector const &_v_initial, Vector const &_a_initial,
+    Dune::BitSetVector<dim> const &_dirichletNodes,
     Function const &_dirichletFunction)
-    : matrices(_matrices),
-      u(_u_initial),
-      v(_v_initial),
-      dirichletNodes(_dirichletNodes),
-      dirichletFunction(_dirichletFunction) {}
-
-template <class Vector, class Matrix, class Function, size_t dim>
-void BackwardEuler<Vector, Matrix, Function, dim>::nextTimeStep() {
-  v_o = v;
-  u_o = u;
-}
+    : TimeSteppingScheme<Vector, Matrix, Function, dim>(
+          _matrices, _u_initial, _v_initial, _a_initial, _dirichletNodes,
+          _dirichletFunction) {}
 
 template <class Vector, class Matrix, class Function, size_t dim>
 void BackwardEuler<Vector, Matrix, Function, dim>::setup(
     Vector const &ell, double _tau, double relativeTime, Vector &rhs,
     Vector &iterate, Matrix &AM) {
-  postProcessCalled = false;
-  dirichletFunction.evaluate(relativeTime, dirichletValue);
-  tau = _tau;
+  this->dirichletFunction.evaluate(relativeTime, this->dirichletValue);
+  this->tau = _tau;
 
   /* We start out with the formulation
 
@@ -48,66 +40,47 @@ void BackwardEuler<Vector, Matrix, Function, dim>::setup(
 
   // set up LHS (for fixed tau, we'd only really have to do this once)
   {
-    Dune::MatrixIndexSet indices(matrices.elasticity.N(),
-                                 matrices.elasticity.M());
-    indices.import(matrices.elasticity);
-    indices.import(matrices.mass);
-    indices.import(matrices.damping);
+    Dune::MatrixIndexSet indices(this->matrices.elasticity.N(),
+                                 this->matrices.elasticity.M());
+    indices.import(this->matrices.elasticity);
+    indices.import(this->matrices.mass);
+    indices.import(this->matrices.damping);
     indices.exportIdx(AM);
   }
   AM = 0.0;
-  Arithmetic::addProduct(AM, 1.0 / tau, matrices.mass);
-  Arithmetic::addProduct(AM, 1.0, matrices.damping);
-  Arithmetic::addProduct(AM, tau, matrices.elasticity);
+  Arithmetic::addProduct(AM, 1.0 / this->tau, this->matrices.mass);
+  Arithmetic::addProduct(AM, 1.0, this->matrices.damping);
+  Arithmetic::addProduct(AM, this->tau, this->matrices.elasticity);
 
   // set up RHS
   {
     rhs = ell;
-    Arithmetic::addProduct(rhs, 1.0 / tau, matrices.mass, v_o);
-    Arithmetic::subtractProduct(rhs, matrices.elasticity, u_o);
+    Arithmetic::addProduct(rhs, 1.0 / this->tau, this->matrices.mass,
+                           this->v_o);
+    Arithmetic::subtractProduct(rhs, this->matrices.elasticity, this->u_o);
   }
 
-  iterate = v_o;
+  iterate = this->v_o;
 
-  for (size_t i = 0; i < dirichletNodes.size(); ++i)
+  for (size_t i = 0; i < this->dirichletNodes.size(); ++i)
     for (size_t j = 0; j < dim; ++j)
-      if (dirichletNodes[i][j])
-        iterate[i][j] = (j == 0) ? dirichletValue : 0;
+      if (this->dirichletNodes[i][j])
+        iterate[i][j] = (j == 0) ? this->dirichletValue : 0;
 }
 
 template <class Vector, class Matrix, class Function, size_t dim>
 void BackwardEuler<Vector, Matrix, Function, dim>::postProcess(
     Vector const &iterate) {
-  postProcessCalled = true;
+  this->postProcessCalled = true;
 
-  v = iterate;
+  this->v = iterate;
 
-  u = u_o;
-  Arithmetic::addProduct(u, tau, v);
-}
+  this->u = this->u_o;
+  Arithmetic::addProduct(this->u, this->tau, this->v);
 
-template <class Vector, class Matrix, class Function, size_t dim>
-void BackwardEuler<Vector, Matrix, Function, dim>::extractDisplacement(
-    Vector &displacement) const {
-  if (!postProcessCalled)
-    DUNE_THROW(Dune::Exception, "It seems you forgot to call postProcess!");
-
-  displacement = u;
-}
-
-template <class Vector, class Matrix, class Function, size_t dim>
-void BackwardEuler<Vector, Matrix, Function, dim>::extractVelocity(
-    Vector &velocity) const {
-  if (!postProcessCalled)
-    DUNE_THROW(Dune::Exception, "It seems you forgot to call postProcess!");
-
-  velocity = v;
-}
-
-template <class Vector, class Matrix, class Function, size_t dim>
-void BackwardEuler<Vector, Matrix, Function, dim>::extractOldVelocity(
-    Vector &velocity) const {
-  velocity = v_o;
+  this->a = this->v;
+  this->a -= this->v_o;
+  this->a /= this->tau;
 }
 
 template <class Vector, class Matrix, class Function, size_t dim>
diff --git a/src/timestepping/backward_euler.hh b/src/timestepping/backward_euler.hh
index db1a8402..09ec584d 100644
--- a/src/timestepping/backward_euler.hh
+++ b/src/timestepping/backward_euler.hh
@@ -5,34 +5,17 @@ template <class Vector, class Matrix, class Function, size_t dim>
 class BackwardEuler : public TimeSteppingScheme<Vector, Matrix, Function, dim> {
 public:
   BackwardEuler(Matrices<Matrix> const &_matrices, Vector const &_u_initial,
-                Vector const &_v_initial,
+                Vector const &_v_initial, Vector const &_a_initial,
                 Dune::BitSetVector<dim> const &_dirichletNodes,
                 Function const &_dirichletFunction);
 
-  void nextTimeStep() override;
   void setup(Vector const &, double, double, Vector &, Vector &,
              Matrix &) override;
   void postProcess(Vector const &) override;
-  void extractDisplacement(Vector &) const override;
-  void extractVelocity(Vector &) const override;
-  void extractOldVelocity(Vector &) const override;
 
   std::shared_ptr<TimeSteppingScheme<Vector, Matrix, Function, dim>> clone()
       const;
 
 private:
-  Matrices<Matrix> const &matrices;
-  Vector u;
-  Vector v;
-  Dune::BitSetVector<dim> const &dirichletNodes;
-  Function const &dirichletFunction;
-  double dirichletValue;
-
-  Vector u_o;
-  Vector v_o;
-
-  double tau;
-
-  bool postProcessCalled = true;
 };
 #endif
diff --git a/src/timestepping/newmark.cc b/src/timestepping/newmark.cc
index ee81b310..227f8ba1 100644
--- a/src/timestepping/newmark.cc
+++ b/src/timestepping/newmark.cc
@@ -6,19 +6,9 @@ Newmark<Vector, Matrix, Function, dim>::Newmark(
     Vector const &_v_initial, Vector const &_a_initial,
     Dune::BitSetVector<dim> const &_dirichletNodes,
     Function const &_dirichletFunction)
-    : matrices(_matrices),
-      u(_u_initial),
-      v(_v_initial),
-      a(_a_initial),
-      dirichletNodes(_dirichletNodes),
-      dirichletFunction(_dirichletFunction) {}
-
-template <class Vector, class Matrix, class Function, size_t dim>
-void Newmark<Vector, Matrix, Function, dim>::nextTimeStep() {
-  a_o = a;
-  v_o = v;
-  u_o = u;
-}
+    : TimeSteppingScheme<Vector, Matrix, Function, dim>(
+          _matrices, _u_initial, _v_initial, _a_initial, _dirichletNodes,
+          _dirichletFunction) {}
 
 template <class Vector, class Matrix, class Function, size_t dim>
 void Newmark<Vector, Matrix, Function, dim>::setup(Vector const &ell,
@@ -26,9 +16,8 @@ void Newmark<Vector, Matrix, Function, dim>::setup(Vector const &ell,
                                                    double relativeTime,
                                                    Vector &rhs, Vector &iterate,
                                                    Matrix &AM) {
-  postProcessCalled = false;
-  dirichletFunction.evaluate(relativeTime, dirichletValue);
-  tau = _tau;
+  this->dirichletFunction.evaluate(relativeTime, this->dirichletValue);
+  this->tau = _tau;
 
   /* We start out with the formulation
 
@@ -54,76 +43,54 @@ void Newmark<Vector, Matrix, Function, dim>::setup(Vector const &ell,
 
   // set up LHS (for fixed tau, we'd only really have to do this once)
   {
-    Dune::MatrixIndexSet indices(matrices.elasticity.N(),
-                                 matrices.elasticity.M());
-    indices.import(matrices.elasticity);
-    indices.import(matrices.mass);
-    indices.import(matrices.damping);
+    Dune::MatrixIndexSet indices(this->matrices.elasticity.N(),
+                                 this->matrices.elasticity.M());
+    indices.import(this->matrices.elasticity);
+    indices.import(this->matrices.mass);
+    indices.import(this->matrices.damping);
     indices.exportIdx(AM);
   }
   AM = 0.0;
-  Arithmetic::addProduct(AM, 2.0 / tau, matrices.mass);
-  Arithmetic::addProduct(AM, 1.0, matrices.damping);
-  Arithmetic::addProduct(AM, tau / 2.0, matrices.elasticity);
+  Arithmetic::addProduct(AM, 2.0 / this->tau, this->matrices.mass);
+  Arithmetic::addProduct(AM, 1.0, this->matrices.damping);
+  Arithmetic::addProduct(AM, this->tau / 2.0, this->matrices.elasticity);
 
   // set up RHS
   {
     rhs = ell;
-    Arithmetic::addProduct(rhs, 2.0 / tau, matrices.mass, v_o);
-    Arithmetic::addProduct(rhs, matrices.mass, a_o);
-    Arithmetic::subtractProduct(rhs, tau / 2.0, matrices.elasticity, v_o);
-    Arithmetic::subtractProduct(rhs, matrices.elasticity, u_o);
+    Arithmetic::addProduct(rhs, 2.0 / this->tau, this->matrices.mass,
+                           this->v_o);
+    Arithmetic::addProduct(rhs, this->matrices.mass, this->a_o);
+    Arithmetic::subtractProduct(rhs, this->tau / 2.0, this->matrices.elasticity,
+                                this->v_o);
+    Arithmetic::subtractProduct(rhs, this->matrices.elasticity, this->u_o);
   }
 
-  iterate = v_o;
+  iterate = this->v_o;
 
-  for (size_t i = 0; i < dirichletNodes.size(); ++i)
+  for (size_t i = 0; i < this->dirichletNodes.size(); ++i)
     for (size_t j = 0; j < dim; ++j)
-      if (dirichletNodes[i][j])
-        iterate[i][j] = (j == 0) ? dirichletValue : 0;
+      if (this->dirichletNodes[i][j])
+        iterate[i][j] = (j == 0) ? this->dirichletValue : 0;
 }
 
 template <class Vector, class Matrix, class Function, size_t dim>
 void Newmark<Vector, Matrix, Function, dim>::postProcess(
     Vector const &iterate) {
-  postProcessCalled = true;
+  this->postProcessCalled = true;
 
-  v = iterate;
+  this->v = iterate;
 
   // u1 = tau/2 ( v1 + v0 ) + u0
-  u = u_o;
-  Arithmetic::addProduct(u, tau / 2.0, v);
-  Arithmetic::addProduct(u, tau / 2.0, v_o);
+  this->u = this->u_o;
+  Arithmetic::addProduct(this->u, this->tau / 2.0, this->v);
+  Arithmetic::addProduct(this->u, this->tau / 2.0, this->v_o);
 
   // a1 = 2/tau ( v1 - v0 ) - a0
-  a = 0;
-  Arithmetic::addProduct(a, 2.0 / tau, v);
-  Arithmetic::subtractProduct(a, 2.0 / tau, v_o);
-  Arithmetic::subtractProduct(a, 1.0, a_o);
-}
-
-template <class Vector, class Matrix, class Function, size_t dim>
-void Newmark<Vector, Matrix, Function, dim>::extractDisplacement(
-    Vector &displacement) const {
-  if (!postProcessCalled)
-    DUNE_THROW(Dune::Exception, "It seems you forgot to call postProcess!");
-
-  displacement = u;
-}
-
-template <class Vector, class Matrix, class Function, size_t dim>
-void Newmark<Vector, Matrix, Function, dim>::extractVelocity(
-    Vector &velocity) const {
-  if (!postProcessCalled)
-    DUNE_THROW(Dune::Exception, "It seems you forgot to call postProcess!");
-
-  velocity = v;
-}
-
-template <class Vector, class Matrix, class Function, size_t dim>
-void Newmark<Vector, Matrix, Function, dim>::extractOldVelocity(
-    Vector &velocity) const {
-  velocity = v_o;
+  this->a = 0.0;
+  Arithmetic::addProduct(this->a, 2.0 / this->tau, this->v);
+  Arithmetic::subtractProduct(this->a, 2.0 / this->tau, this->v_o);
+  Arithmetic::subtractProduct(this->a, 1.0, this->a_o);
 }
 
 template <class Vector, class Matrix, class Function, size_t dim>
diff --git a/src/timestepping/newmark.hh b/src/timestepping/newmark.hh
index 72ae8543..f2a5d566 100644
--- a/src/timestepping/newmark.hh
+++ b/src/timestepping/newmark.hh
@@ -9,32 +9,11 @@ class Newmark : public TimeSteppingScheme<Vector, Matrix, Function, dim> {
           Dune::BitSetVector<dim> const &_dirichletNodes,
           Function const &_dirichletFunction);
 
-  void nextTimeStep() override;
   void setup(Vector const &, double, double, Vector &, Vector &,
              Matrix &) override;
   void postProcess(Vector const &) override;
-  void extractDisplacement(Vector &) const override;
-  void extractVelocity(Vector &) const override;
-  void extractOldVelocity(Vector &) const override;
 
   std::shared_ptr<TimeSteppingScheme<Vector, Matrix, Function, dim>> clone()
       const;
-
-private:
-  Matrices<Matrix> const &matrices;
-  Vector u;
-  Vector v;
-  Vector a;
-  Dune::BitSetVector<dim> const &dirichletNodes;
-  Function const &dirichletFunction;
-  double dirichletValue;
-
-  Vector u_o;
-  Vector v_o;
-  Vector a_o;
-
-  double tau;
-
-  bool postProcessCalled = true;
 };
 #endif
diff --git a/src/timestepping_tmpl.cc b/src/timestepping_tmpl.cc
index a61958b8..b9e86274 100644
--- a/src/timestepping_tmpl.cc
+++ b/src/timestepping_tmpl.cc
@@ -8,5 +8,6 @@
 
 using Function = Dune::VirtualFunction<double, double>;
 
+template class TimeSteppingScheme<Vector, Matrix, Function, MY_DIM>;
 template class Newmark<Vector, Matrix, Function, MY_DIM>;
 template class BackwardEuler<Vector, Matrix, Function, MY_DIM>;
-- 
GitLab