From 1a3a36cf9056abf97e8839004c2ff973cd413165 Mon Sep 17 00:00:00 2001
From: Elias Pipping <elias.pipping@fu-berlin.de>
Date: Sat, 19 Jul 2014 20:19:50 +0200
Subject: [PATCH] [Cleanup] Modularise FixedPointIterator

---
 src/Makefile.am                |  1 +
 src/fixedpointiterator.cc      | 80 +++++++++++++++++++++++++++++++
 src/fixedpointiterator.hh      | 41 ++++++++++++++++
 src/fixedpointiterator_tmpl.cc | 26 ++++++++++
 src/sand-wedge.cc              | 87 +---------------------------------
 5 files changed, 149 insertions(+), 86 deletions(-)
 create mode 100644 src/fixedpointiterator.cc
 create mode 100644 src/fixedpointiterator.hh
 create mode 100644 src/fixedpointiterator_tmpl.cc

diff --git a/src/Makefile.am b/src/Makefile.am
index 5908465c..23d83a43 100644
--- a/src/Makefile.am
+++ b/src/Makefile.am
@@ -4,6 +4,7 @@ common_sources = \
 	assemblers.cc \
 	boundary_writer.cc \
 	enumparser.cc \
+	fixedpointiterator.cc \
 	friction_writer.cc \
 	solverfactory.cc \
 	state.cc \
diff --git a/src/fixedpointiterator.cc b/src/fixedpointiterator.cc
new file mode 100644
index 00000000..c013d991
--- /dev/null
+++ b/src/fixedpointiterator.cc
@@ -0,0 +1,80 @@
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include <dune/common/exceptions.hh>
+
+#include <dune/solvers/common/arithmetic.hh>
+#include <dune/solvers/solvers/loopsolver.hh>
+
+#include "enums.hh"
+#include "enumparser.hh"
+
+#include "fixedpointiterator.hh"
+
+template <class Factory, class StateUpdater, class VelocityUpdater>
+FixedPointIterator<Factory, StateUpdater, VelocityUpdater>::FixedPointIterator(
+    Factory &factory, Dune::ParameterTree const &parset,
+    std::shared_ptr<Nonlinearity> globalFriction)
+    : factory_(factory),
+      parset_(parset),
+      globalFriction_(globalFriction),
+      fixedPointMaxIterations_(parset.get<size_t>("v.fpi.maximumIterations")),
+      fixedPointTolerance_(parset.get<double>("v.fpi.tolerance")),
+      lambda_(parset.get<double>("v.fpi.lambda")),
+      velocityMaxIterations_(parset.get<size_t>("v.solver.maximumIterations")),
+      velocityTolerance_(parset.get<double>("v.solver.tolerance")),
+      verbosity_(parset.get<Solver::VerbosityMode>("v.solver.verbosity")) {}
+
+template <class Factory, class StateUpdater, class VelocityUpdater>
+int FixedPointIterator<Factory, StateUpdater, VelocityUpdater>::run(
+    std::shared_ptr<StateUpdater> stateUpdater,
+    std::shared_ptr<VelocityUpdater> velocityUpdater,
+    Matrix const &velocityMatrix, Norm<Vector> const &velocityMatrixNorm,
+    Vector const &velocityRHS, Vector &velocityIterate) {
+  auto multigridStep = factory_.getSolver();
+
+  LoopSolver<Vector> velocityProblemSolver(
+      multigridStep, velocityMaxIterations_, velocityTolerance_,
+      &velocityMatrixNorm, verbosity_, false); // absolute error
+
+  Vector previousVelocityIterate = velocityIterate;
+
+  size_t fixedPointIteration;
+  for (fixedPointIteration = 1; fixedPointIteration <= fixedPointMaxIterations_;
+       ++fixedPointIteration) {
+    Vector v_m;
+    velocityUpdater->extractOldVelocity(v_m);
+    v_m *= 1.0 - lambda_;
+    Arithmetic::addProduct(v_m, lambda_, velocityIterate);
+
+    // solve a state problem
+    stateUpdater->solve(v_m);
+    ScalarVector alpha;
+    stateUpdater->extractAlpha(alpha);
+
+    // solve a velocity problem
+    globalFriction_->updateAlpha(alpha);
+    ConvexProblem convexProblem(1.0, velocityMatrix, *globalFriction_,
+                                velocityRHS, velocityIterate);
+    BlockProblem velocityProblem(parset_, convexProblem);
+    multigridStep->setProblem(velocityIterate, velocityProblem);
+    velocityProblemSolver.preprocess();
+    velocityProblemSolver.solve();
+
+    if (velocityMatrixNorm.diff(previousVelocityIterate, velocityIterate) <
+        fixedPointTolerance_)
+      break;
+
+    previousVelocityIterate = velocityIterate;
+  }
+  if (fixedPointIteration == fixedPointMaxIterations_)
+    DUNE_THROW(Dune::Exception, "FPI failed to converge");
+
+  velocityUpdater->postProcess(velocityIterate);
+  velocityUpdater->postProcessRelativeQuantities();
+
+  return fixedPointIteration;
+}
+
+#include "fixedpointiterator_tmpl.cc"
diff --git a/src/fixedpointiterator.hh b/src/fixedpointiterator.hh
new file mode 100644
index 00000000..70a03c97
--- /dev/null
+++ b/src/fixedpointiterator.hh
@@ -0,0 +1,41 @@
+#ifndef SRC_FIXEDPOINTITERATOR_HH
+#define SRC_FIXEDPOINTITERATOR_HH
+
+#include <memory>
+
+#include <dune/common/parametertree.hh>
+
+#include <dune/solvers/norms/norm.hh>
+#include <dune/solvers/solvers/solver.hh>
+
+template <class Factory, class StateUpdater, class VelocityUpdater>
+class FixedPointIterator {
+  using ScalarVector = typename StateUpdater::ScalarVector;
+  using Vector = typename Factory::Vector;
+  using Matrix = typename Factory::Matrix;
+  using ConvexProblem = typename Factory::ConvexProblem;
+  using BlockProblem = typename Factory::BlockProblem;
+  using Nonlinearity = typename ConvexProblem::NonlinearityType;
+
+public:
+  FixedPointIterator(Factory &factory, Dune::ParameterTree const &parset,
+                     std::shared_ptr<Nonlinearity> globalFriction);
+
+  int run(std::shared_ptr<StateUpdater> stateUpdater,
+          std::shared_ptr<VelocityUpdater> velocityUpdater,
+          Matrix const &velocityMatrix, Norm<Vector> const &velocityMatrixNorm,
+          Vector const &velocityRHS, Vector &velocityIterate);
+
+private:
+  Factory &factory_;
+  Dune::ParameterTree const &parset_;
+  std::shared_ptr<Nonlinearity> globalFriction_;
+
+  size_t fixedPointMaxIterations_;
+  double fixedPointTolerance_;
+  double lambda_;
+  size_t velocityMaxIterations_;
+  double velocityTolerance_;
+  Solver::VerbosityMode verbosity_;
+};
+#endif
diff --git a/src/fixedpointiterator_tmpl.cc b/src/fixedpointiterator_tmpl.cc
new file mode 100644
index 00000000..22631e8e
--- /dev/null
+++ b/src/fixedpointiterator_tmpl.cc
@@ -0,0 +1,26 @@
+#ifndef DIM
+#error DIM unset
+#endif
+
+#include <dune/common/function.hh>
+
+#include <dune/tnnmg/problem-classes/convexproblem.hh>
+
+#include <dune/tectonic/globalfriction.hh>
+#include <dune/tectonic/myblockproblem.hh>
+
+#include "explicitgrid.hh"
+#include "explicitvectors.hh"
+
+#include "solverfactory.hh"
+#include "state/stateupdater.hh"
+#include "timestepping.hh"
+
+using Function = Dune::VirtualFunction<double, double>;
+using Factory = SolverFactory<
+    DIM, MyBlockProblem<ConvexProblem<GlobalFriction<Matrix, Vector>, Matrix>>,
+    Grid>;
+using MyStateUpdater = StateUpdater<ScalarVector, Vector>;
+using MyVelocityUpdater = TimeSteppingScheme<Vector, Matrix, Function, DIM>;
+
+template class FixedPointIterator<Factory, MyStateUpdater, MyVelocityUpdater>;
diff --git a/src/sand-wedge.cc b/src/sand-wedge.cc
index d94813e8..1f35a207 100644
--- a/src/sand-wedge.cc
+++ b/src/sand-wedge.cc
@@ -70,6 +70,7 @@
 #include "tobool.hh"
 #include "enumparser.hh"
 #include "enums.hh"
+#include "fixedpointiterator.hh"
 #include "friction_writer.hh"
 #include "sand-wedge-data/mybody.hh"
 #include "sand-wedge-data/mygeometry.hh"
@@ -92,92 +93,6 @@ void initPython() {
   Python::run("sys.path.append('" datadir "')");
 }
 
-template <class Factory, class StateUpdater, class VelocityUpdater>
-class FixedPointIterator {
-  using ScalarVector = typename StateUpdater::ScalarVector;
-  using Vector = typename Factory::Vector;
-  using Matrix = typename Factory::Matrix;
-  using ConvexProblem = typename Factory::ConvexProblem;
-  using BlockProblem = typename Factory::BlockProblem;
-  using Nonlinearity = typename ConvexProblem::NonlinearityType;
-
-public:
-  FixedPointIterator(Factory &factory, Dune::ParameterTree const &parset,
-                     std::shared_ptr<Nonlinearity> globalFriction)
-      : factory_(factory),
-        parset_(parset),
-        globalFriction_(globalFriction),
-        fixedPointMaxIterations_(parset.get<size_t>("v.fpi.maximumIterations")),
-        fixedPointTolerance_(parset.get<double>("v.fpi.tolerance")),
-        lambda_(parset.get<double>("v.fpi.lambda")),
-        velocityMaxIterations_(
-            parset.get<size_t>("v.solver.maximumIterations")),
-        velocityTolerance_(parset.get<double>("v.solver.tolerance")),
-        verbosity_(parset.get<Solver::VerbosityMode>("v.solver.verbosity")) {}
-
-  int run(std::shared_ptr<StateUpdater> stateUpdater,
-          std::shared_ptr<VelocityUpdater> velocityUpdater,
-          Matrix const &velocityMatrix, Norm<Vector> const &velocityMatrixNorm,
-          Vector const &velocityRHS, Vector &velocityIterate) {
-    auto multigridStep = factory_.getSolver();
-
-    LoopSolver<Vector> velocityProblemSolver(
-        multigridStep, velocityMaxIterations_, velocityTolerance_,
-        &velocityMatrixNorm, verbosity_, false); // absolute error
-
-    Vector previousVelocityIterate = velocityIterate;
-
-    size_t fixedPointIteration;
-    for (fixedPointIteration = 1;
-         fixedPointIteration <= fixedPointMaxIterations_;
-         ++fixedPointIteration) {
-      Vector v_m;
-      velocityUpdater->extractOldVelocity(v_m);
-      v_m *= 1.0 - lambda_;
-      Arithmetic::addProduct(v_m, lambda_, velocityIterate);
-
-      // solve a state problem
-      stateUpdater->solve(v_m);
-      ScalarVector alpha;
-      stateUpdater->extractAlpha(alpha);
-
-      // solve a velocity problem
-      globalFriction_->updateAlpha(alpha);
-      ConvexProblem convexProblem(1.0, velocityMatrix, *globalFriction_,
-                                  velocityRHS, velocityIterate);
-      BlockProblem velocityProblem(parset_, convexProblem);
-      multigridStep->setProblem(velocityIterate, velocityProblem);
-      velocityProblemSolver.preprocess();
-      velocityProblemSolver.solve();
-
-      if (velocityMatrixNorm.diff(previousVelocityIterate, velocityIterate) <
-          fixedPointTolerance_)
-        break;
-
-      previousVelocityIterate = velocityIterate;
-    }
-    if (fixedPointIteration == fixedPointMaxIterations_)
-      DUNE_THROW(Dune::Exception, "FPI failed to converge");
-
-    velocityUpdater->postProcess(velocityIterate);
-    velocityUpdater->postProcessRelativeQuantities();
-
-    return fixedPointIteration;
-  }
-
-private:
-  Factory &factory_;
-  Dune::ParameterTree const &parset_;
-  std::shared_ptr<Nonlinearity> globalFriction_;
-
-  size_t fixedPointMaxIterations_;
-  double fixedPointTolerance_;
-  double lambda_;
-  size_t velocityMaxIterations_;
-  double velocityTolerance_;
-  Solver::VerbosityMode verbosity_;
-};
-
 template <class Factory, class StateUpdater, class VelocityUpdater>
 class CoupledTimeStepper {
   using Vector = typename Factory::Vector;
-- 
GitLab