From 431a56cfa53af933c4df3537e00017d4a470598b Mon Sep 17 00:00:00 2001
From: Elias Pipping <elias.pipping@fu-berlin.de>
Date: Tue, 20 Jan 2015 16:49:11 +0100
Subject: [PATCH] [Cleanup] Externalise setup of initial conditions

---
 src/adaptivetimestepper.hh |   5 +-
 src/assemblers.hh          |   2 +-
 src/program_state.hh       | 113 +++++++++++++++++++++++++++++++++++++
 src/sand-wedge.cc          | 113 +++++++++----------------------------
 4 files changed, 145 insertions(+), 88 deletions(-)
 create mode 100644 src/program_state.hh

diff --git a/src/adaptivetimestepper.hh b/src/adaptivetimestepper.hh
index a6864a7c..9720efa6 100644
--- a/src/adaptivetimestepper.hh
+++ b/src/adaptivetimestepper.hh
@@ -20,12 +20,13 @@ class AdaptiveTimeStepper {
   AdaptiveTimeStepper(
       Factory &factory, Dune::ParameterTree const &parset,
       std::shared_ptr<Nonlinearity> globalFriction, UpdaterPair &current,
+      double relativeTime, double relativeTau,
       std::function<void(double, Vector &)> externalForces,
       ErrorNorm const &errorNorm,
       std::function<bool(UpdaterPair &, UpdaterPair &)> mustRefine)
       : finalTime_(parset.get<double>("problem.finalTime")),
-        relativeTime_(0.0),
-        relativeTau_(1e-6), // FIXME (not really important, though)
+        relativeTime_(relativeTime),
+        relativeTau_(relativeTau),
         factory_(factory),
         parset_(parset),
         globalFriction_(globalFriction),
diff --git a/src/assemblers.hh b/src/assemblers.hh
index 14abad33..9fd600fb 100644
--- a/src/assemblers.hh
+++ b/src/assemblers.hh
@@ -31,6 +31,7 @@ template <class GridView, int dimension> class MyAssembler {
 
   CellBasis const cellBasis;
   VertexBasis const vertexBasis;
+  GridView const &gridView;
 
 private:
   using Grid = typename GridView::Grid;
@@ -40,7 +41,6 @@ template <class GridView, int dimension> class MyAssembler {
   using LocalCellBasis = typename CellBasis::LocalFiniteElement;
   using LocalVertexBasis = typename VertexBasis::LocalFiniteElement;
 
-  GridView const &gridView;
   Assembler<CellBasis, CellBasis> cellAssembler;
   Assembler<VertexBasis, VertexBasis> vertexAssembler;
 
diff --git a/src/program_state.hh b/src/program_state.hh
new file mode 100644
index 00000000..f05e193c
--- /dev/null
+++ b/src/program_state.hh
@@ -0,0 +1,113 @@
+#ifndef SRC_PROGRAM_STATE_HH
+#define SRC_PROGRAM_STATE_HH
+
+#include <dune/common/parametertree.hh>
+
+#include <dune/fufem/boundarypatch.hh>
+
+#include <dune/tectonic/body.hh>
+
+#include "assemblers.hh"
+#include "matrices.hh"
+#include "solverfactory.hh"
+
+template <class Vector, class ScalarVector> class ProgramState {
+public:
+  ProgramState(int leafVertexCount)
+      : u(leafVertexCount),
+        v(leafVertexCount),
+        a(leafVertexCount),
+        alpha(leafVertexCount),
+        normalStress(leafVertexCount) {}
+
+  // Set up initial conditions
+  template <class Matrix, class GridView>
+  void setupInitialConditions(
+      Dune::ParameterTree const &parset,
+      std::function<void(double, Vector &)> externalForces,
+      Matrices<Matrix> const matrices,
+      MyAssembler<GridView, Vector::block_type::dimension> const &myAssembler,
+      Dune::BitSetVector<Vector::block_type::dimension> const &dirichletNodes,
+      Dune::BitSetVector<Vector::block_type::dimension> const &noNodes,
+      BoundaryPatch<GridView> const &frictionalBoundary,
+      Body<Vector::block_type::dimension> const &body) {
+
+    using LocalVector = typename Vector::block_type;
+    using LocalMatrix = typename Matrix::block_type;
+    auto const dims = LocalVector::dimension;
+
+    // Solving a linear problem with a multigrid solver
+    auto const solveLinearProblem =
+        [&](Dune::BitSetVector<dims> const &_dirichletNodes,
+            Matrix const &_matrix, Vector const &_rhs, Vector &_x,
+            Dune::ParameterTree const &_localParset) {
+
+      using LinearFactory = SolverFactory<
+          dims, BlockNonlinearTNNMGProblem<ConvexProblem<
+                    ZeroNonlinearity<LocalVector, LocalMatrix>, Matrix>>,
+          typename GridView::Grid>;
+      ZeroNonlinearity<LocalVector, LocalMatrix> zeroNonlinearity;
+
+      LinearFactory factory(parset.sub("solver.tnnmg"), // FIXME
+                            myAssembler.gridView.grid(), _dirichletNodes);
+
+      typename LinearFactory::ConvexProblem convexProblem(
+          1.0, _matrix, zeroNonlinearity, _rhs, _x);
+      typename LinearFactory::BlockProblem problem(parset, convexProblem);
+
+      auto multigridStep = factory.getStep();
+      multigridStep->setProblem(_x, problem);
+      EnergyNorm<Matrix, Vector> const norm(_matrix);
+      LoopSolver<Vector> solver(
+          multigridStep.get(), _localParset.get<size_t>("maximumIterations"),
+          _localParset.get<double>("tolerance"), &norm,
+          _localParset.get<Solver::VerbosityMode>("verbosity"),
+          false); // absolute error
+
+      solver.preprocess();
+      solver.solve();
+    };
+
+    timeStep = 1;
+    relativeTime = 0.0;
+    relativeTau = 1e-6;
+
+    Vector ell0(u.size());
+    externalForces(relativeTime, ell0);
+
+    // Initial velocity
+    v = 0.0;
+
+    // Initial displacement: Start from a situation of minimal stress,
+    // which is automatically attained in the case [v = 0 = a].
+    // Assuming dPhi(v = 0) = 0, we thus only have to solve Au = ell0
+    solveLinearProblem(dirichletNodes, matrices.elasticity, ell0, u,
+                       parset.sub("u0.solver"));
+
+    // Initial acceleration: Computed in agreement with Ma = ell0 - Au
+    // (without Dirichlet constraints), again assuming dPhi(v = 0) = 0
+    Vector accelerationRHS = ell0;
+    Arithmetic::subtractProduct(accelerationRHS, matrices.elasticity, u);
+    solveLinearProblem(noNodes, matrices.mass, accelerationRHS, a,
+                       parset.sub("a0.solver"));
+
+    // Initial state
+    alpha = parset.get<double>("boundary.friction.initialAlpha");
+
+    // Initial normal stress
+    myAssembler.assembleNormalStress(frictionalBoundary, normalStress,
+                                     body.getYoungModulus(),
+                                     body.getPoissonRatio(), u);
+  }
+
+public:
+  Vector u;
+  Vector v;
+  Vector a;
+  ScalarVector alpha;
+  ScalarVector normalStress;
+  double relativeTime;
+  double relativeTau;
+  size_t timeStep;
+};
+#endif
diff --git a/src/sand-wedge.cc b/src/sand-wedge.cc
index dab96e4f..3942ef1f 100644
--- a/src/sand-wedge.cc
+++ b/src/sand-wedge.cc
@@ -60,6 +60,7 @@
 #include "friction_writer.hh"
 #include "gridselector.hh"
 #include "matrices.hh"
+#include "program_state.hh"
 #include "rate.hh"
 #include "sand-wedge-data/mybody.hh"
 #include "sand-wedge-data/mygeometry.hh"
@@ -213,76 +214,17 @@ int main(int argc, char *argv[]) {
           _ell += gravityFunctional;
         };
 
-    // helper
-    auto const solveLinearProblem = [&](
-        Dune::BitSetVector<dims> const &_dirichletNodes, Matrix const &_matrix,
-        Vector const &_rhs, Vector &_x,
-        Dune::ParameterTree const &_localParset) {
-
-      using LinearFactory = SolverFactory<
-          dims, BlockNonlinearTNNMGProblem<ConvexProblem<
-                    ZeroNonlinearity<LocalVector, LocalMatrix>, Matrix>>,
-          Grid>;
-      ZeroNonlinearity<LocalVector, LocalMatrix> zeroNonlinearity;
-      LinearFactory factory(parset.sub("solver.tnnmg"), // FIXME
-                            *grid, _dirichletNodes);
-
-      typename LinearFactory::ConvexProblem convexProblem(
-          1.0, _matrix, zeroNonlinearity, _rhs, _x);
-      typename LinearFactory::BlockProblem problem(parset, convexProblem);
-
-      auto multigridStep = factory.getStep();
-      multigridStep->setProblem(_x, problem);
-      EnergyNorm<Matrix, Vector> const norm(_matrix);
-      LoopSolver<Vector> solver(
-          multigridStep.get(), _localParset.get<size_t>("maximumIterations"),
-          _localParset.get<double>("tolerance"), &norm,
-          _localParset.get<Solver::VerbosityMode>("verbosity"),
-          false); // absolute error
-
-      solver.preprocess();
-      solver.solve();
-    };
-
-    // {{{ Initial conditions
-    Vector ell0(leafVertexCount);
-    computeExternalForces(0.0, ell0);
-
-    // Start from a situation of minimal stress, which is
-    // automatically attained in the case [v = 0 = a]. Assuming
-    // dPhi(v = 0) = 0, we thus only have to solve Au = ell0
-    Vector u_initial(leafVertexCount);
-    solveLinearProblem(dirichletNodes, matrices.elasticity, ell0, u_initial,
-                       parset.sub("u0.solver"));
-
-    ScalarVector normalStress(leafVertexCount);
-    myAssembler.assembleNormalStress(frictionalBoundary, normalStress,
-                                     body.getYoungModulus(),
-                                     body.getPoissonRatio(), u_initial);
-
-    ScalarVector alpha_initial(leafVertexCount);
-    alpha_initial = parset.get<double>("boundary.friction.initialAlpha");
-
-    // Start from zero velocity
-    Vector v_initial(leafVertexCount);
-    v_initial = 0.0;
-
-    // Compute the acceleration from Ma = ell0 - Au (without Dirichlet
-    // constraints), again assuming dPhi(v = 0) = 0
-    Vector a_initial(leafVertexCount);
-    Vector accelerationRHS = ell0;
-    Arithmetic::subtractProduct(accelerationRHS, matrices.elasticity,
-                                u_initial);
-    solveLinearProblem(noNodes, matrices.mass, accelerationRHS, a_initial,
-                       parset.sub("a0.solver"));
-    // }}}
+    ProgramState<Vector, ScalarVector> programState(leafVertexCount);
+    programState.setupInitialConditions(parset, computeExternalForces, matrices,
+                                        myAssembler, dirichletNodes, noNodes,
+                                        frictionalBoundary, body);
 
     MyGlobalFrictionData<LocalVector> frictionInfo(
         parset.sub("boundary.friction"), weakPatch);
     auto myGlobalFriction = myAssembler.assembleFrictionNonlinearity(
         parset.get<Config::FrictionModel>("boundary.friction.frictionModel"),
-        frictionalBoundary, frictionInfo, normalStress);
-    myGlobalFriction->updateAlpha(alpha_initial);
+        frictionalBoundary, frictionInfo, programState.normalStress);
+    myGlobalFriction->updateAlpha(programState.alpha);
 
     Vector vertexCoordinates(leafVertexCount);
     {
@@ -326,7 +268,7 @@ int main(int argc, char *argv[]) {
       specialVelocityWriter.write(velocity);
       specialDisplacementWriter.write(displacement);
     };
-    report(u_initial, v_initial, alpha_initial);
+    report(programState.u, programState.v, programState.alpha);
 
     MyVTKWriter<typename MyAssembler::VertexBasis,
                 typename MyAssembler::CellBasis> const
@@ -334,9 +276,11 @@ int main(int argc, char *argv[]) {
 
     if (parset.get<bool>("io.writeVTK")) {
       ScalarVector stress;
-      myAssembler.assembleVonMisesStress(
-          body.getYoungModulus(), body.getPoissonRatio(), u_initial, stress);
-      vtkWriter.write(0, u_initial, v_initial, alpha_initial, stress);
+      myAssembler.assembleVonMisesStress(body.getYoungModulus(),
+                                         body.getPoissonRatio(), programState.u,
+                                         stress);
+      vtkWriter.write(0, programState.u, programState.v, programState.alpha,
+                      stress);
     }
 
     // Set up TNNMG solver
@@ -352,12 +296,12 @@ int main(int argc, char *argv[]) {
     UpdaterPair current(
         initStateUpdater<ScalarVector, Vector>(
             parset.get<Config::stateModel>("boundary.friction.stateModel"),
-            alpha_initial, frictionalNodes,
+            programState.alpha, frictionalNodes,
             parset.get<double>("boundary.friction.L"),
             parset.get<double>("boundary.friction.V0")),
         initRateUpdater(parset.get<Config::scheme>("timeSteps.scheme"),
                         velocityDirichletFunction, dirichletNodes, matrices,
-                        u_initial, v_initial, a_initial));
+                        programState.u, programState.v, programState.a));
 
     auto const refinementTolerance =
         parset.get<double>("timeSteps.refinementTolerance");
@@ -372,32 +316,31 @@ int main(int argc, char *argv[]) {
       return stateEnergyNorm.diff(fineAlpha, coarseAlpha) > refinementTolerance;
     };
 
-    size_t timeStep = 1;
-
     AdaptiveTimeStepper<NonlinearFactory, UpdaterPair,
                         EnergyNorm<ScalarMatrix, ScalarVector>>
         adaptiveTimeStepper(factory, parset, myGlobalFriction, current,
+                            programState.relativeTime, programState.relativeTau,
                             computeExternalForces, stateEnergyNorm, mustRefine);
     while (!adaptiveTimeStepper.reachedEnd()) {
       adaptiveTimeStepper.advance();
-      reportTimeStep(adaptiveTimeStepper.getRelativeTime(),
-                     adaptiveTimeStepper.getRelativeTau());
-
-      Vector u, v;
-      ScalarVector alpha;
-      current.second->extractDisplacement(u);
-      current.second->extractVelocity(v);
-      current.first->extractAlpha(alpha);
+      programState.relativeTime = adaptiveTimeStepper.getRelativeTime();
+      programState.relativeTau = adaptiveTimeStepper.getRelativeTau();
+      reportTimeStep(programState.relativeTime, programState.relativeTau);
 
-      report(u, v, alpha);
+      current.second->extractDisplacement(programState.u);
+      current.second->extractVelocity(programState.v);
+      current.first->extractAlpha(programState.alpha);
 
+      report(programState.u, programState.v, programState.alpha);
       if (parset.get<bool>("io.writeVTK")) {
         ScalarVector stress;
         myAssembler.assembleVonMisesStress(body.getYoungModulus(),
-                                           body.getPoissonRatio(), u, stress);
-        vtkWriter.write(timeStep, u, v, alpha, stress);
+                                           body.getPoissonRatio(),
+                                           programState.u, stress);
+        vtkWriter.write(programState.timeStep, programState.u, programState.v,
+                        programState.alpha, stress);
       }
-      timeStep++;
+      programState.timeStep++;
     }
     timeStepWriter.close();
     Python::stop();
-- 
GitLab