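// NOTE: the following lines are residue from the GitLab web view this file was
// copied from (a fork of agnumpde/dune-tectonic); they are kept as a comment so
// that the translation unit compiles.
//   Skip to content
//   Snippets Groups Projects
//   Forked from agnumpde / dune-tectonic
//   59 commits ahead of the upstream repository.
//   Code owners
//   Assign users and groups as approvers for specific file changes. Learn more.
//   solverfactorytest.cc 30.34 KiB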
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#define MY_DIM 2

#include <iostream>
#include <fstream>
#include <vector>
#include <exception>

#include <dune/common/exceptions.hh>
#include <dune/common/parallel/mpihelper.hh>
#include <dune/common/stdstreams.hh>
#include <dune/common/fvector.hh>
#include <dune/common/function.hh>
#include <dune/common/bitsetvector.hh>
#include <dune/common/parametertree.hh>
#include <dune/common/parametertreeparser.hh>

#include <dune/fufem/formatstring.hh>

#include <dune/tnnmg/functionals/boxconstrainedquadraticfunctional.hh>
#include <dune/tnnmg/functionals/bcqfconstrainedlinearization.hh>
#include <dune/tnnmg/localsolvers/scalarobstaclesolver.hh>
#include <dune/tnnmg/iterationsteps/tnnmgstep.hh>

#include <dune/tectonic/geocoordinate.hh>

#include "assemblers.hh"
#include "gridselector.hh"
#include "explicitgrid.hh"
#include "explicitvectors.hh"

#include "data-structures/enumparser.hh"
#include "data-structures/enums.hh"
#include "data-structures/contactnetwork.hh"
#include "data-structures/matrices.hh"
#include "data-structures/program_state.hh"

#include "io/vtk.hh"

#include "spatial-solving/tnnmg/tnnmgstep.hh"
#include "spatial-solving/tnnmg/linesearchsolver.hh"
#include "spatial-solving/preconditioners/nbodycontacttransfer.hh"
#include "obstaclesolver.hh"

#include "factories/stackedblocksfactory.hh"

#include "time-stepping/rate.hh"
#include "time-stepping/state.hh"
#include "time-stepping/updaters.hh"

#include "utils/tobool.hh"
#include "utils/debugutils.hh"


// Spatial dimension of the problem (taken from the MY_DIM macro).
const int dim = MY_DIM;

// Run parameters; filled by getParameters() from the .cfg files and the command line.
Dune::ParameterTree parset;
// Number of bodies in the contact network; set in main() from contactNetwork.nBodies().
size_t bodyCount;

// Per-body BodyState objects. They are allocated with new in setupInitialConditions()
// and point into the per-body vectors below; they are not deleted anywhere in this file.
std::vector<BodyState<Vector, ScalarVector>* > bodyStates;
// Per-body displacement, velocity and acceleration (one vector per body).
std::vector<Vector> u;
std::vector<Vector> v;
std::vector<Vector> a;
// Per-body friction state; initialised from "boundary.friction.initialAlpha".
std::vector<ScalarVector> alpha;
// Per-body weighted normal stress on the friction boundaries (passed to assembleFriction()).
std::vector<ScalarVector> weightedNormalStress;
// Current time and time-step size relative to "problem.finalTime"
// (step() uses tau = relativeTau * problem.finalTime).
double relativeTime;
double relativeTau;
// Index of the current time step and total number of steps ("timeSteps.timeSteps").
size_t timeStep = 0;
size_t timeSteps = 0;

// main() redirects std::cout into path + outputFile.
const std::string path = "";
const std::string outputFile = "solverfactorytest.log";

// Reduction factors recorded by reductionFactorCriterion, indexed by time step
// (resized to timeSteps+1 in main()).
std::vector<std::vector<double>> allReductionFactors;

/**
 * \brief Build a solver criterion that logs the per-iteration reduction factor.
 *
 * Each time the criterion is evaluated it measures the correction between the
 * iterate seen at the previous evaluation and the current iterate in \p norm,
 * divides it by the previous correction norm, appends that ratio to
 * \p reductionFactors and prints it. The criterion never requests a stop (it
 * reports "rate < 0"), but it throws a Dune::Exception once a ratio exceeds 1.
 *
 * NOTE(review): the previous correction norm starts at 1, so the first value
 * recorded is the norm of the first correction itself (and throws if that is > 1).
 *
 * \param iterationStep    step whose iterate is observed; must outlive the criterion
 * \param norm             norm used to measure corrections; must outlive the criterion
 * \param reductionFactors container receiving the ratios; must outlive the criterion
 */
template<class IterationStepType, class NormType, class ReductionFactorContainer>
Dune::Solvers::Criterion reductionFactorCriterion(
      const IterationStepType& iterationStep,
      const NormType& norm,
      ReductionFactorContainer& reductionFactors)
{
  using IterateVector = typename IterationStepType::Vector;

  // Snapshot of the iterate from the previous evaluation, shared so every copy of
  // the closure refers to the same snapshot.
  auto previousIterate = std::make_shared<IterateVector>(*iterationStep.getIterate());

  return Dune::Solvers::Criterion(
      [&iterationStep, &norm, &reductionFactors, previousIterate,
       previousCorrectionNorm = 1.0] () mutable {
        const auto& currentIterate = *iterationStep.getIterate();
        const double correctionNorm = norm.diff(*previousIterate, currentIterate);

        double rate = 0.0;
        if (previousCorrectionNorm > 0)
          rate = correctionNorm / previousCorrectionNorm;

        if (rate > 1.0)
          DUNE_THROW(Dune::Exception, "Solver convergence rate of " + std::to_string(rate));

        previousCorrectionNorm = correctionNorm;
        *previousIterate = currentIterate;
        reductionFactors.push_back(rate);

        return std::make_tuple(rate < 0, Dune::formatString(" % '.5f", rate));
      },
      " reductionFactor");
}

template <class ContactNetwork>
void solveProblem(const ContactNetwork& contactNetwork,
                  const Matrix& mat, const Vector& rhs, Vector& x,
                  const BitVector& ignore, const Vector& lower, const Vector& upper, bool initial = false) {

    using Solver = typename Dune::Solvers::LoopSolver<Vector, BitVector>;
    using Norm =  EnergyNorm<Matrix, Vector>;

    //using LocalSolver = LocalBisectionSolver;
    //using Linearization = Linearization<Functional, BitVector>;

    /*print(ignore, "ignore:");
    for (size_t i=0; i<x.size(); i++) {
        std::cout << x[i] << std::endl;
    }*/

    // set up reference solver
    Vector refX = x;
    using ContactFunctional = Dune::TNNMG::BoxConstrainedQuadraticFunctional<Matrix&, Vector&, Vector&, Vector&, double>;
    auto I = ContactFunctional(mat, rhs, lower, upper);

    std::cout << "Energy start iterate: " << I(x) << std::endl;

    using LocalSolver = Dune::TNNMG::ScalarObstacleSolver;
    auto localSolver = Dune::TNNMG::gaussSeidelLocalSolver(LocalSolver());

    using NonlinearSmoother = Dune::TNNMG::NonlinearGSStep<ContactFunctional, decltype(localSolver), BitVector>;
    auto nonlinearSmoother = std::make_shared<NonlinearSmoother>(I, refX, localSolver);

    using Linearization = Dune::TNNMG::BoxConstrainedQuadraticFunctionalConstrainedLinearization<ContactFunctional, BitVector>;
    using DefectProjection = Dune::TNNMG::ObstacleDefectProjection;
    using Step = Dune::TNNMG::TNNMGStep<ContactFunctional, BitVector, Linearization, DefectProjection, LocalSolver>;

    // set multigrid solver
    auto smoother = TruncatedBlockGSStep<Matrix, Vector>{};
    using TransferOperator = NBodyContactTransfer<ContactNetwork, Vector>;
    using TransferOperators = std::vector<std::shared_ptr<TransferOperator>>;

    TransferOperators transfer(contactNetwork.nLevels()-1);
    for (size_t i=0; i<transfer.size(); i++) {
        transfer[i] = std::make_shared<TransferOperator>();
        transfer[i]->setup(contactNetwork, i, i+1);
    }

    // Remove any recompute filed so that initially the full transferoperator is assembled
    for (size_t i=0; i<transfer.size(); i++)
        std::dynamic_pointer_cast<TruncatedMGTransfer<Vector> >(transfer[i])->setRecomputeBitField(nullptr);

    auto linearMultigridStep = std::make_shared<Dune::Solvers::MultigridStep<Matrix, Vector> >();
    linearMultigridStep->setMGType(1, 3, 3);
    linearMultigridStep->setSmoother(&smoother);
    linearMultigridStep->setTransferOperators(transfer);

    int mu = parset.get<int>("solver.tnnmg.main.multi"); // #multigrid steps in Newton step
    auto step = Step(I, refX, nonlinearSmoother, linearMultigridStep, mu, DefectProjection(), LocalSolver());

    // compute reference solution with generic functional and solver
    auto norm = Norm(mat);
    auto refSolver = Solver(step, parset.get<size_t>("u0.solver.maximumIterations"),
                            parset.get<double>("u0.solver.tolerance"), norm, Solver::FULL);

    step.setIgnore(ignore);
    step.setPreSmoothingSteps(parset.get<int>("solver.tnnmg.main.pre"));

    refSolver.addCriterion(
            [&](){
            return Dune::formatString("   % 12.5e", I(refX));
            },
            "   energy      ");

    double initialEnergy = I(refX);
    refSolver.addCriterion(
            [&](){
            static double oldEnergy=initialEnergy;
            double currentEnergy = I(refX);
            double decrease = currentEnergy - oldEnergy;
            oldEnergy = currentEnergy;
            return Dune::formatString("   % 12.5e", decrease);
            },
            "   decrease    ");

    refSolver.addCriterion(
            [&](){
            return Dune::formatString("   % 12.5e", step.lastDampingFactor());
            },
            "   damping     ");


    refSolver.addCriterion(
            [&](){
            return Dune::formatString("   % 12d", step.linearization().truncated().count());
            },
            "   truncated   ");

    if (timeStep>0 and initial) {
        allReductionFactors[timeStep].resize(0);
        refSolver.addCriterion(reductionFactorCriterion(step, norm, allReductionFactors[timeStep]));
    }

    refSolver.preprocess();
    refSolver.solve();

    //print(refX, "refX: ");

    if (initial) {
        x = refX;
        return;
    }
    // set up solver factory solver

    // set up functional
    auto& globalFriction = contactNetwork.globalFriction();

  /*  std::vector<const Dune::BitSetVector<1>*> nodes;
     contactNetwork.frictionNodes(nodes);
     print(*nodes[0], "frictionNodes: ");
     print(*nodes[1], "frictionNodes: ");

     print(ignore, "ignore: ");*/

    using MyFunctional = Functional<Matrix&, Vector&, std::decay_t<decltype(globalFriction)>&, Vector&, Vector&, typename Matrix::field_type>;
    MyFunctional J(mat, rhs, globalFriction, lower, upper);
    //using MyFunctional = Functional<Matrix&, Vector&, ZeroNonlinearity, Vector&, Vector&, typename Matrix::field_type>;
    //MyFunctional J(mat, rhs, ZeroNonlinearity(), lower, upper);

    //std::cout << "ref energy: " << J(refX) << std::endl;

    // set up TNMMG solver
    // dummy solver, uses direct solver for linear correction no matter what is set here
    Norm mgNorm(*linearMultigridStep);
    auto mgSolver = std::make_shared<Solver>(linearMultigridStep, parset.get<size_t>("solver.tnnmg.linear.maximumIterations"), parset.get<double>("solver.tnnmg.linear.tolerance"), mgNorm, Solver::QUIET);

    using Factory = SolverFactory<MyFunctional, BitVector>;
    Factory factory(parset.sub("solver.tnnmg"), J, *mgSolver, ignore);

   /* std::vector<BitVector> bodyDirichletNodes;
    nBodyAssembler.postprocess(_dirichletNodes, bodyDirichletNodes);
    for (size_t i=0; i<bodyDirichletNodes.size(); i++) {
      print(bodyDirichletNodes[i], "bodyDirichletNodes_" + std::to_string(i) + ": ");
    }*/

   /* print(bilinearForm, "matrix: ");
    print(totalX, "totalX: ");
    print(totalRhs, "totalRhs: ");*/

    auto tnnmgStep = factory.step();
    factory.setProblem(x);

    LoopSolver<Vector> solver(
        tnnmgStep.get(), parset.get<size_t>("u0.solver.maximumIterations"),
                parset.get<double>("u0.solver.tolerance"), &norm,
        Solver::FULL); //, true, &refX); // absolute error

    solver.addCriterion(
            [&](){
            return Dune::formatString("   % 12.5e", J(x));
            },
            "   energy      ");

    initialEnergy = J(x);
    solver.addCriterion(
            [&](){
            static double oldEnergy=initialEnergy;
            double currentEnergy = J(x);
            double decrease = currentEnergy - oldEnergy;
            oldEnergy = currentEnergy;
            return Dune::formatString("   % 12.5e", decrease);
            },
            "   decrease    ");

    solver.addCriterion(
            [&](){
            return Dune::formatString("   % 12.5e", tnnmgStep->lastDampingFactor());
            },
            "   damping     ");

    solver.addCriterion(
            [&](){
            return Dune::formatString("   % 12d", tnnmgStep->linearization().truncated().count());
            },
            "   truncated   ");


    std::vector<double> factors;
    solver.addCriterion(reductionFactorCriterion(*tnnmgStep, norm, factors));

    solver.preprocess();
    solver.solve();

    auto diff = x;
    diff -= refX;
    std::cout << "Solver error in energy norm: " << norm(diff) << std::endl;

    std::cout << "Energy end iterate: " << J(x) << std::endl;
}


/**
 * \brief Compute initial displacement, velocity, acceleration, state and normal stress.
 *
 * Reads the initial time data from parset ("initialTime.*"), sizes and zeroes the
 * global per-body vectors u, v, a, alpha and weightedNormalStress, and creates one
 * BodyState per body.
 * - u: solves the contact problem A u = ell0 with the elasticity matrices and
 *   Dirichlet nodes ignored.
 * - alpha: set to "boundary.friction.initialAlpha".
 * - weightedNormalStress: sums the stresses assembled on all "friction" boundary
 *   conditions.
 * - a: solves M a = ell0 - A u with the mass matrices and no ignored nodes.
 *
 * \param contactNetwork network providing bodies, matrices and the n-body assembler
 */
template <class ContactNetwork>
void setupInitialConditions(const ContactNetwork& contactNetwork) {
  using Matrix = typename ContactNetwork::Matrix;
  const auto& nBodyAssembler = contactNetwork.nBodyAssembler();

  std::cout << std::endl << "-- setupInitialConditions --" << std::endl;
  std::cout << "----------------------------" << std::endl;

  // Solve a linear (contact) problem for all bodies at once: assemble the global
  // system, solve it with solveProblem(..., initial = true) and write the result
  // back into the per-body vectors _x (which also provide the start iterate).
  auto const solveLinearProblem = [&](
      const BitVector& _dirichletNodes, const std::vector<std::shared_ptr<Matrix>>& _matrices,
      const std::vector<Vector>& _rhs, std::vector<Vector>& _x) {

    std::vector<const Matrix*> matrices_ptr(_matrices.size());
    for (size_t i=0; i<matrices_ptr.size(); i++) {
          matrices_ptr[i] = _matrices[i].get();
    }

    // assemble full global contact problem
    Matrix bilinearForm;

    nBodyAssembler.assembleJacobian(matrices_ptr, bilinearForm);

    Vector totalRhs;
    nBodyAssembler.assembleRightHandSide(_rhs, totalRhs);

    Vector totalX;
    nBodyAssembler.nodalToTransformed(_x, totalX);

    // get lower and upper obstacles
    const auto& totalObstacles = nBodyAssembler.totalObstacles_;
    Vector lower(totalObstacles.size());
    Vector upper(totalObstacles.size());

    for (size_t j=0; j<totalObstacles.size(); ++j) {
        const auto& totalObstaclesj = totalObstacles[j];
        auto& lowerj = lower[j];
        auto& upperj = upper[j];
      for (size_t d=0; d<dim; ++d) {
          lowerj[d] = totalObstaclesj[d][0];
          upperj[d] = totalObstaclesj[d][1];
      }
    }

    // print problem
   /* print(bilinearForm, "bilinearForm");
    print(totalRhs, "totalRhs");
    print(_dirichletNodes, "ignore");
    print(totalObstacles, "totalObstacles");
    print(lower, "lower");
    print(upper, "upper");*/

    solveProblem(contactNetwork, bilinearForm, totalRhs, totalX, _dirichletNodes, lower, upper, true);

    nBodyAssembler.postprocess(totalX, _x);
  };

  timeStep = parset.get<size_t>("initialTime.timeStep");
  relativeTime = parset.get<double>("initialTime.relativeTime");
  relativeTau = parset.get<double>("initialTime.relativeTau");

  bodyStates.resize(bodyCount);
  u.resize(bodyCount);
  v.resize(bodyCount);
  a.resize(bodyCount);
  alpha.resize(bodyCount);
  weightedNormalStress.resize(bodyCount);

  for (size_t i=0; i<bodyCount; i++) {
    size_t leafVertexCount =  contactNetwork.body(i)->nVertices();

    u[i].resize(leafVertexCount);
    v[i].resize(leafVertexCount);
    a[i].resize(leafVertexCount);
    alpha[i].resize(leafVertexCount);
    weightedNormalStress[i].resize(leafVertexCount);

    // NOTE(review): these BodyState objects are never deleted in this file.
    bodyStates[i] = new BodyState<Vector, ScalarVector>(&u[i], &v[i], &a[i], &alpha[i], &weightedNormalStress[i]);
  }

  std::vector<Vector> ell0(bodyCount);
  for (size_t i=0; i<bodyCount; i++) {
    // Initial velocity
    u[i] = 0.0;
    v[i] = 0.0;
    // a[i] is the start iterate of the acceleration solve below, and the normal
    // stress is accumulated with +=, so both must start from a defined zero.
    a[i] = 0.0;
    weightedNormalStress[i] = 0.0;

    ell0[i].resize(u[i].size());
    ell0[i] = 0.0;

    contactNetwork.body(i)->externalForce()(relativeTime, ell0[i]);
  }

  // Initial displacement: Start from a situation of minimal stress,
  // which is automatically attained in the case [v = 0 = a].
  // Assuming dPhi(v = 0) = 0, we thus only have to solve Au = ell0
  BitVector dirichletNodes;
  contactNetwork.totalNodes("dirichlet", dirichletNodes);
  /*for (size_t i=0; i<dirichletNodes.size(); i++) {
      bool val = false;
      for (size_t d=0; d<dims; d++) {
          val = val || dirichletNodes[i][d];
      }

      dirichletNodes[i] = val;
      for (size_t d=0; d<dims; d++) {
          dirichletNodes[i][d] = val;
      }
  }*/

  std::cout << "solving linear problem for u..." << std::endl;

  solveLinearProblem(dirichletNodes, contactNetwork.matrices().elasticity, ell0, u);

  //print(u, "initial u:");

  // Initial acceleration: Computed in agreement with Ma = ell0 - Au
  // (without Dirichlet constraints), again assuming dPhi(v = 0) = 0
  std::vector<Vector> accelerationRHS = ell0;
  for (size_t i=0; i<bodyCount; i++) {
    // Initial state
    alpha[i] = parset.get<double>("boundary.friction.initialAlpha");

    // Initial normal stress

    const auto& body = contactNetwork.body(i);
    std::vector<std::shared_ptr<typename ContactNetwork::LeafBody::BoundaryCondition>> frictionBoundaryConditions;
    body->boundaryConditions("friction", frictionBoundaryConditions);
    for (size_t j=0; j<frictionBoundaryConditions.size(); j++) {
        ScalarVector frictionBoundaryStress(weightedNormalStress[i].size());

        body->assembler()->assembleWeightedNormalStress(
          *frictionBoundaryConditions[j]->boundaryPatch(), frictionBoundaryStress, body->data()->getYoungModulus(),
          body->data()->getPoissonRatio(), u[i]);

        weightedNormalStress[i] += frictionBoundaryStress;
    }

    Dune::MatrixVector::subtractProduct(accelerationRHS[i], *body->matrices().elasticity, u[i]);
  }

  std::cout << "solving linear problem for a..." << std::endl;

  BitVector noNodes(dirichletNodes.size(), false);
  solveLinearProblem(noNodes, contactNetwork.matrices().mass, accelerationRHS, a);

  //print(a, "initial a:");
}

template <class ContactNetwork>
void relativeVelocities(const ContactNetwork& contactNetwork, const Vector& v, std::vector<Vector>& v_rel) {

    const auto& nBodyAssembler = contactNetwork.nBodyAssembler();
    const size_t nBodies = nBodyAssembler.nGrids();
   // const auto& contactCouplings = nBodyAssembler.getContactCouplings();

    std::vector<const Dune::BitSetVector<1>*> bodywiseNonmortarBoundaries;
    contactNetwork.frictionNodes(bodywiseNonmortarBoundaries);

    size_t globalIdx = 0;
    v_rel.resize(nBodies);
    for (size_t bodyIdx=0; bodyIdx<nBodies; bodyIdx++) {
        const auto& nonmortarBoundary = *bodywiseNonmortarBoundaries[bodyIdx];
        auto& v_rel_ = v_rel[bodyIdx];

        v_rel_.resize(nonmortarBoundary.size());

        for (size_t i=0; i<v_rel_.size(); i++) {
            if (toBool(nonmortarBoundary[i])) {
                v_rel_[i] = v[globalIdx];
            }
            globalIdx++;
        }
    }

}

template <class Updaters, class ContactNetwork>
void run(Updaters& updaters, ContactNetwork& contactNetwork,
    const std::vector<Matrix>& velocityMatrices, const std::vector<Vector>& velocityRHSs,
    std::vector<Vector>& velocityIterates) {

  const auto& nBodyAssembler = contactNetwork.nBodyAssembler();
  const auto nBodies = nBodyAssembler.nGrids();

  std::vector<const Matrix*> matrices_ptr(nBodies);
  for (int i=0; i<nBodies; i++) {
      matrices_ptr[i] = &velocityMatrices[i];
  }

  // assemble full global contact problem
  Matrix bilinearForm;
  nBodyAssembler.assembleJacobian(matrices_ptr, bilinearForm);

  Vector totalRhs;
  nBodyAssembler.assembleRightHandSide(velocityRHSs, totalRhs);

  // get lower and upper obstacles
  const auto& totalObstacles = nBodyAssembler.totalObstacles_;
  Vector lower(totalObstacles.size());
  Vector upper(totalObstacles.size());

  for (size_t j=0; j<totalObstacles.size(); ++j) {
      const auto& totalObstaclesj = totalObstacles[j];
      auto& lowerj = lower[j];
      auto& upperj = upper[j];
    for (size_t d=0; d<dim; ++d) {
        lowerj[d] = totalObstaclesj[d][0];
        upperj[d] = totalObstaclesj[d][1];
    }
  }

  // compute velocity obstacles
  /*Vector vLower, vUpper;
  std::vector<Vector> u0, v0;
  updaters.rate_->extractOldVelocity(v0);
  updaters.rate_->extractOldDisplacement(u0);

  Vector totalu0, totalv0;
  nBodyAssembler.nodalToTransformed(u0, totalu0);
  nBodyAssembler.nodalToTransformed(v0, totalv0);

  updaters.rate_->velocityObstacles(totalu0, lower, totalv0, vLower);
  updaters.rate_->velocityObstacles(totalu0, upper, totalv0, vUpper);*/

  const auto& errorNorms = contactNetwork.stateEnergyNorms();

  auto& globalFriction = contactNetwork.globalFriction();

  BitVector totalDirichletNodes;
  contactNetwork.totalNodes("dirichlet", totalDirichletNodes);

  size_t fixedPointMaxIterations = parset.get<size_t>("v.fpi.maximumIterations");
  double fixedPointTolerance = parset.get<double>("v.fpi.tolerance");
  double lambda = parset.get<double>("v.fpi.lambda");

  size_t fixedPointIteration;
  size_t multigridIterations = 0;

  std::vector<ScalarVector> alpha(nBodies);
  updaters.state_->extractAlpha(alpha);

  Vector totalVelocityIterate;
  nBodyAssembler.nodalToTransformed(velocityIterates, totalVelocityIterate);

  // project in onto admissible set
  const size_t blocksize = Vector::block_type::dimension;
  for (size_t i=0; i<totalVelocityIterate.size(); i++) {
      for (size_t j=0; j<blocksize; j++) {
          if (totalVelocityIterate[i][j] < lower[i][j]) {
              totalVelocityIterate[i][j] = lower[i][j];
          }

          if (totalVelocityIterate[i][j] > upper[i][j]) {
              totalVelocityIterate[i][j] = upper[i][j];
          }
      }
  }

  for (fixedPointIteration = 0; fixedPointIteration < fixedPointMaxIterations;
       ++fixedPointIteration) {

    // contribution from nonlinearity
    //print(alpha, "alpha: ");

    globalFriction.updateAlpha(alpha);

    solveProblem(contactNetwork, bilinearForm, totalRhs, totalVelocityIterate, totalDirichletNodes, lower, upper, false);

    nBodyAssembler.postprocess(totalVelocityIterate, velocityIterates);

    std::vector<Vector> v_m; //TODO : wrong, isnt used atm;
    updaters.rate_->extractOldVelocity(v_m);

    for (size_t i=0; i<v_m.size(); i++) {
      v_m[i] *= 1.0 - lambda;
      Dune::MatrixVector::addProduct(v_m[i], lambda, velocityIterates[i]);
    }

    // extract relative velocities in mortar basis
    std::vector<Vector> v_rel;
    relativeVelocities(contactNetwork, totalVelocityIterate, v_rel);

    //print(v_rel, "v_rel");

    std::cout << "- State problem set" << std::endl;

    // solve a state problem
    updaters.state_->solve(v_rel);

    std::cout << "-- Solved" << std::endl;

    std::vector<ScalarVector> newAlpha(nBodies);
    updaters.state_->extractAlpha(newAlpha);

    bool breakCriterion = true;
    for (int i=0; i<nBodies; i++) {
        if (alpha[i].size()==0 || newAlpha[i].size()==0)
            continue;

        //print(alpha[i], "alpha i:");
        //print(newAlpha[i], "new alpha i:");
        if (errorNorms[i]->diff(alpha[i], newAlpha[i]) >= fixedPointTolerance) {
            breakCriterion = false;
            std::cout << "fixedPoint error: " << errorNorms[i]->diff(alpha[i], newAlpha[i]) << std::endl;
            break;
        }
    }

    if (lambda < 1e-12 or breakCriterion) {
      std::cout << "-FixedPointIteration finished! " << (lambda < 1e-12 ? "lambda" : "breakCriterion") << std::endl;
      fixedPointIteration++;
      break;
    }
    alpha = newAlpha;
  }

  std::cout << "-FixedPointIteration finished! " << std::endl;

  if (fixedPointIteration == fixedPointMaxIterations)
    DUNE_THROW(Dune::Exception, "FPI failed to converge");

  updaters.rate_->postProcess(velocityIterates);
}


template <class Updaters, class ContactNetwork>
void step(Updaters& updaters, ContactNetwork& contactNetwork)  {
  updaters.state_->nextTimeStep();
  updaters.rate_->nextTimeStep();

  auto const newRelativeTime = relativeTime + relativeTau;
  typename ContactNetwork::ExternalForces externalForces;
  contactNetwork.externalForces(externalForces);
  std::vector<Vector> ell(externalForces.size());
  for (size_t i=0; i<externalForces.size(); i++) {
    (*externalForces[i])(newRelativeTime, ell[i]);
  }

  std::vector<Matrix> velocityMatrix;
  std::vector<Vector> velocityRHS;
  std::vector<Vector> velocityIterate;

  double finalTime = parset.get<double>("problem.finalTime");
  auto const tau = relativeTau * finalTime;
  updaters.state_->setup(tau);
  updaters.rate_->setup(ell, tau, newRelativeTime, velocityRHS, velocityIterate, velocityMatrix);

  run(updaters, contactNetwork, velocityMatrix, velocityRHS, velocityIterate);
}

/**
 * \brief Solve the next time step, then move the global relativeTime forward by relativeTau.
 */
template <class Updaters, class ContactNetwork>
void advanceStep(Updaters& current, ContactNetwork& contactNetwork) {
    // step() evaluates the loads at relativeTime + relativeTau, so the clock moves only afterwards.
    step(current, contactNetwork);
    relativeTime = relativeTime + relativeTau;
}

/**
 * \brief Read the run parameters into the global parset.
 *
 * Reads the following sources into parset, in this order:
 * 1. \p configDir + "multi-body-problem.cfg"
 * 2. \p configDir + "multi-body-problem-<dim>D.cfg"
 * 3. the command-line options
 *
 * \param argc       argument count, passed on to ParameterTreeParser::readOptions
 * \param argv       argument vector, passed on to ParameterTreeParser::readOptions
 * \param configDir  directory containing the .cfg files, including the trailing '/'.
 *                   It defaults to the directory that used to be hard-coded here,
 *                   so existing callers behave as before.
 */
void getParameters(int argc, char *argv[],
                   const std::string& configDir = "/home/mi/podlesny/software/dune/dune-tectonic/src/") {
  Dune::ParameterTreeParser::readINITree(configDir + "multi-body-problem.cfg", parset);
  Dune::ParameterTreeParser::readINITree(
      configDir + Dune::Fufem::formatString("multi-body-problem-%dD.cfg", dim), parset);
  Dune::ParameterTreeParser::readOptions(argc, argv, parset);
}

/**
 * \brief Solver-factory test driver.
 *
 * The driver proceeds as follows:
 * - build a stacked-blocks contact network and write every body to VTK;
 * - assemble the network and compute the initial conditions;
 * - run "timeSteps.timeSteps" time steps. In each fixed-point iteration,
 *   solveProblem() compares the SolverFactory solver with a reference TNNMG solver;
 * - print the geometric mean of the recorded reduction factors per iteration index.
 *
 * All std::cout output goes to path + outputFile while the test runs.
 *
 * \return 0 on completion, 1 if a Dune::Exception or std::exception escaped
 */
int main(int argc, char *argv[]) { try {
    Dune::MPIHelper::instance(argc, argv);

    std::ofstream out(path + outputFile);
    // Redirect std::cout to outputFile. The guard restores the saved buffer whenever
    // this scope is left, including when an exception propagates. Because it is
    // declared after `out`, it is destroyed before `out` releases the file buffer,
    // so std::cout never points at a dead buffer.
    struct CoutRedirect {
        std::streambuf* saved;
        ~CoutRedirect() { std::cout.rdbuf(saved); }
    };
    const CoutRedirect coutRedirect{std::cout.rdbuf(out.rdbuf())};

    std::cout << "-------------------------" << std::endl;
    std::cout << "-- SolverFactory Test: --" << std::endl;
    std::cout << "-------------------------" << std::endl << std::endl;

    getParameters(argc, argv);

    // ----------------------
    // set up contact network
    // ----------------------
    StackedBlocksFactory<Grid, Vector> stackedBlocksFactory(parset);
    using ContactNetwork = typename StackedBlocksFactory<Grid, Vector>::ContactNetwork;
    stackedBlocksFactory.build();

    ContactNetwork& contactNetwork = stackedBlocksFactory.contactNetwork();
    bodyCount = contactNetwork.nBodies();

    for (size_t i=0; i<contactNetwork.nLevels(); i++) {
        const auto& level = *contactNetwork.level(i);

        for (size_t j=0; j<level.nBodies(); j++) {
            writeToVTK(level.body(j)->gridView(), "debug_print/bodies/", "body_" + std::to_string(j) + "_level_" + std::to_string(i));
        }
    }

    for (size_t i=0; i<bodyCount; i++) {
        writeToVTK(contactNetwork.body(i)->gridView(), "debug_print/bodies/", "body_" + std::to_string(i) + "_leaf");
    }

    // ----------------------------
    // assemble contactNetwork
    // ----------------------------
    contactNetwork.assemble();

    //printMortarBasis<Vector>(contactNetwork.nBodyAssembler());

    // -----------------
    // init input/output
    // -----------------
    timeSteps = parset.get<size_t>("timeSteps.timeSteps");
    allReductionFactors.resize(timeSteps+1);

    setupInitialConditions(contactNetwork);

    auto& nBodyAssembler = contactNetwork.nBodyAssembler();
    for (size_t i=0; i<bodyCount; i++) {
      contactNetwork.body(i)->setDeformation(u[i]);
    }
    nBodyAssembler.assembleTransferOperator();
    nBodyAssembler.assembleObstacle();

    // ------------------------
    // assemble global friction
    // ------------------------
    contactNetwork.assembleFriction(parset.get<Config::FrictionModel>("boundary.friction.frictionModel"), weightedNormalStress);

    auto& globalFriction = contactNetwork.globalFriction();
    globalFriction.updateAlpha(alpha);

    using Assembler = MyAssembler<DefLeafGridView, dim>;
    using field_type = Matrix::field_type;
    using MyVertexBasis = typename Assembler::VertexBasis;
    using MyCellBasis = typename Assembler::CellBasis;
    std::vector<Vector> vertexCoordinates(bodyCount);
    std::vector<const MyVertexBasis* > vertexBases(bodyCount);
    std::vector<const MyCellBasis* > cellBases(bodyCount);

    for (size_t i=0; i<bodyCount; i++) {
      const auto& body = contactNetwork.body(i);
      vertexBases[i] = &(body->assembler()->vertexBasis);
      cellBases[i] = &(body->assembler()->cellBasis);

      auto& vertexCoords = vertexCoordinates[i];
      vertexCoords.resize(body->nVertices());

      Dune::MultipleCodimMultipleGeomTypeMapper<
          DefLeafGridView, Dune::MCMGVertexLayout> const vertexMapper(body->gridView(), Dune::mcmgVertexLayout());
      for (auto &&v : vertices(body->gridView()))
        vertexCoords[vertexMapper.index(v)] = geoToPoint(v.geometry());
    }

    const MyVTKWriter<MyVertexBasis, MyCellBasis> vtkWriter(cellBases, vertexBases, "/storage/mi/podlesny/software/dune/dune-tectonic/body");

    auto const report = [&](bool initial = false) {
      if (parset.get<bool>("io.printProgress"))
        std::cout << "timeStep = " << std::setw(6) << timeStep
                  << ", time = " << std::setw(12) << relativeTime
                  << ", tau = " << std::setw(12) << relativeTau
                  << std::endl;

      if (parset.get<bool>("io.vtk.write")) {
        std::vector<ScalarVector> stress(bodyCount);

        for (size_t i=0; i<bodyCount; i++) {
          const auto& body = contactNetwork.body(i);
          body->assembler()->assembleVonMisesStress(body->data()->getYoungModulus(),
                                           body->data()->getPoissonRatio(),
                                           u[i], stress[i]);

        }

        vtkWriter.write(timeStep, u, v, alpha, stress);
      }
    };
    report(true);

    // -------------------
    // Set up TNNMG solver
    // -------------------

    /*BitVector totalDirichletNodes;
    contactNetwork.totalNodes("dirichlet", totalDirichletNodes);

    print(totalDirichletNodes, "totalDirichletNodes:");*/

    //using Functional = Functional<Matrix&, Vector&, ZeroNonlinearity&, Vector&, Vector&, field_type>;
    //using Functional = Functional<Matrix&, Vector&, GlobalFriction<Matrix, Vector>&, Vector&, Vector&, field_type>;
    //using NonlinearFactory = SolverFactory<Functional, BitVector>;

    using BoundaryFunctions = typename ContactNetwork::BoundaryFunctions;
    using BoundaryNodes = typename ContactNetwork::BoundaryNodes;
    using Updaters = Updaters<RateUpdater<Vector, Matrix, BoundaryFunctions, BoundaryNodes>,
                               StateUpdater<ScalarVector, Vector>>;

    BoundaryFunctions velocityDirichletFunctions;
    contactNetwork.boundaryFunctions("dirichlet", velocityDirichletFunctions);

    BoundaryNodes dirichletNodes;
    contactNetwork.boundaryNodes("dirichlet", dirichletNodes);

    /*for (size_t i=0; i<dirichletNodes.size(); i++) {
        for (size_t j=0; j<dirichletNodes[i].size(); j++) {
        print(*dirichletNodes[i][j], "dirichletNodes_body_" + std::to_string(i) + "_boundary_" + std::to_string(j));
        }
    }*/


    /*for (size_t i=0; i<frictionNodes.size(); i++) {
        print(*frictionNodes[i], "frictionNodes_body_" + std::to_string(i));
    }*/

    Updaters current(
        initRateUpdater(
            parset.get<Config::scheme>("timeSteps.scheme"),
            velocityDirichletFunctions,
            dirichletNodes,
            contactNetwork.matrices(),
            u,
            v,
            a),
        initStateUpdater<ScalarVector, Vector>(
            parset.get<Config::stateModel>("boundary.friction.stateModel"),
            alpha,
            nBodyAssembler.getContactCouplings(),
            contactNetwork.couplings())
            );



    //const auto& stateEnergyNorms = contactNetwork.stateEnergyNorms();

    for (timeStep=1; timeStep<=timeSteps; timeStep++) {

      // advanceStep() already moves relativeTime forward by relativeTau; advancing
      // it here as well would move the clock by two steps per time step.
      advanceStep(current, contactNetwork);

      current.rate_->extractDisplacement(u);
      current.rate_->extractVelocity(v);
      current.rate_->extractAcceleration(a);
      current.state_->extractAlpha(alpha);

      contactNetwork.setDeformation(u);

      report();
    }

    // output reduction factors
    size_t count = 0;
    for (size_t i=0; i<allReductionFactors.size(); i++) {
        count = std::max(count, allReductionFactors[i].size());
    }
    std::vector<double> avgReductionFactors(count);
    for (size_t i=0; i<count; i++) {
        avgReductionFactors[i] = 1;
        size_t c = 0;

        for (size_t j=0; j<allReductionFactors.size(); j++) {
            if (!(i<allReductionFactors[j].size()))
                continue;

            avgReductionFactors[i] *= allReductionFactors[j][i];
            c++;
        }

        // c >= 1 here: count is the largest size, so some time step has an entry i.
        avgReductionFactors[i] = std::pow(avgReductionFactors[i], 1.0/static_cast<double>(c));
    }

    print(avgReductionFactors, "average reduction factors: ");

    bool passed = true;

    std::cout << "Overall, the test " << (passed ? "was successful!" : "failed!") << std::endl;

    // std::cout is restored to its original buffer by coutRedirect on return.
    return passed ? 0 : 1;

} catch (const Dune::Exception &e) {
    Dune::derr << "Dune reported error: " << e << std::endl;
    return 1;
} catch (const std::exception &e) {
    std::cerr << "Standard exception: " << e.what() << std::endl;
    return 1;
} // end try
} // end main