#ifndef SRC_PROGRAM_STATE_HH
#define SRC_PROGRAM_STATE_HH

#include <dune/common/parametertree.hh>

#include <dune/matrix-vector/axpy.hh>

#include <dune/fufem/boundarypatch.hh>
#include <dune/tnnmg/problem-classes/blocknonlineartnnmgproblem.hh>

#include <dune/contact/assemblers/nbodyassembler.hh>

#include <dune/solvers/norms/energynorm.hh>
#include <dune/solvers/solvers/loopsolver.hh>
#include <dune/solvers/iterationsteps/cgstep.hh>

#include <dune/tectonic/bodydata.hh>

#include "../assemblers.hh"
#include "contactnetwork.hh"
#include "matrices.hh"
#include "../spatial-solving/preconditioners/multilevelpatchpreconditioner.hh"
#include "../spatial-solving/solverfactory.hh"
#include "../spatial-solving/tnnmg/functional.hh"
#include "../spatial-solving/tnnmg/zerononlinearity.hh"
#include "../utils/debugutils.hh"

/// Non-owning view onto the state vectors of a single body.
///
/// Holds pointers to the body's displacement (u), velocity (v),
/// acceleration (a), friction state (alpha) and weighted normal stress.
/// The pointers themselves cannot be reseated after construction; the
/// pointed-to vectors are owned elsewhere and must outlive this view.
template <class VectorTEMPLATE, class ScalarVectorTEMPLATE>
class BodyState {
public:
  using Vector = VectorTEMPLATE;
  using ScalarVector = ScalarVectorTEMPLATE;

  BodyState(Vector* uPtr, Vector* vPtr, Vector* aPtr,
            ScalarVector* alphaPtr, ScalarVector* stressPtr)
    : u{uPtr}, v{vPtr}, a{aPtr}, alpha{alphaPtr}, weightedNormalStress{stressPtr} {}

  Vector* const u;
  Vector* const v;
  Vector* const a;
  ScalarVector* const alpha;
  ScalarVector* const weightedNormalStress;
};


template <class VectorTEMPLATE, class ScalarVectorTEMPLATE> class ProgramState {
public:
  using Vector = VectorTEMPLATE;
  using ScalarVector = ScalarVectorTEMPLATE;
  using BodyState = BodyState<Vector, ScalarVector>;

private:
    using LocalVector = typename Vector::block_type;
    //using LocalMatrix = typename Matrix::block_type;
    const static int dims = LocalVector::dimension;
    using BitVector = Dune::BitSetVector<dims>;

public:
  ProgramState(const std::vector<size_t>& leafVertexCounts)
    : bodyCount_(leafVertexCounts.size()),
      bodyStates(bodyCount_),
      u(bodyCount_),
      v(bodyCount_),
      a(bodyCount_),
      alpha(bodyCount_),
      weightedNormalStress(bodyCount_) {
    for (size_t i=0; i<bodyCount_; i++) {
      size_t leafVertexCount = leafVertexCounts[i];

      u[i].resize(leafVertexCount),
      v[i].resize(leafVertexCount),
      a[i].resize(leafVertexCount),
      alpha[i].resize(leafVertexCount),
      weightedNormalStress[i].resize(leafVertexCount),

      bodyStates[i] = new BodyState(&u[i], &v[i], &a[i], &alpha[i], &weightedNormalStress[i]);
    }
  }

  ~ProgramState() {
    for (size_t i=0; i<bodyCount_; i++) {
      delete bodyStates[i];
    }
  }

  size_t size() const {
      return bodyCount_;
  }

  // Set up initial conditions
  template <class ContactNetwork>
  void setupInitialConditions(const Dune::ParameterTree& parset, const ContactNetwork& contactNetwork) {
    using Matrix = typename ContactNetwork::Matrix;
    const auto& nBodyAssembler = contactNetwork.nBodyAssembler();

    const auto& preconditionerParset = parset.sub("solver.tnnmg.linear.preconditioner");

    using Preconditioner = MultilevelPatchPreconditioner<ContactNetwork, Matrix, Vector>;
    Dune::BitSetVector<1> activeLevels(contactNetwork.nLevels(), true);
    Preconditioner preconditioner(preconditionerParset, contactNetwork, activeLevels);
    preconditioner.build();

    using LinearSolver = typename Dune::Solvers::LoopSolver<Vector, BitVector>;
    using LinearSolverStep = typename Dune::Solvers::CGStep<Matrix, Vector, BitVector>;

    LinearSolverStep linearSolverStep;
    linearSolverStep.setPreconditioner(preconditioner);

    EnergyNorm<Matrix, Vector> energyNorm(linearSolverStep);
    LinearSolver linearSolver(linearSolverStep, parset.get<size_t>("solver.tnnmg.linear.maximumIterations"), parset.get<double>("solver.tnnmg.linear.tolerance"), energyNorm, Solver::QUIET);

    // Solving a linear problem with a multigrid solver
    auto const solveLinearProblem = [&](
        const BitVector& _dirichletNodes, const std::vector<std::shared_ptr<Matrix>>& _matrices,
        const std::vector<Vector>& _rhs, std::vector<Vector>& _x,
        Dune::ParameterTree const &_localParset) {

      std::vector<const Matrix*> matrices_ptr(_matrices.size());
      for (size_t i=0; i<matrices_ptr.size(); i++) {
            matrices_ptr[i] = _matrices[i].get();
      }

      /*std::vector<Matrix> matrices(velocityMatrices.size());
          std::vector<Vector> rhs(velocityRHSs.size());
          for (size_t i=0; i<globalFriction_.size(); i++) {
            matrices[i] = velocityMatrices[i];
            rhs[i] = velocityRHSs[i];

            globalFriction_[i]->addHessian(v_rel[i], matrices[i]);
            globalFriction_[i]->addGradient(v_rel[i], rhs[i]);

            matrices_ptr[i] = &matrices[i];
          }*/

      // assemble full global contact problem
      Matrix bilinearForm;

      nBodyAssembler.assembleJacobian(matrices_ptr, bilinearForm);

      Vector totalRhs, oldTotalRhs;
      nBodyAssembler.assembleRightHandSide(_rhs, totalRhs);
      oldTotalRhs = totalRhs;

      Vector totalX, oldTotalX;
      nBodyAssembler.nodalToTransformed(_x, totalX);
      oldTotalX = totalX;

      // get lower and upper obstacles
      const auto& totalObstacles = nBodyAssembler.totalObstacles_;
      Vector lower(totalObstacles.size());
      Vector upper(totalObstacles.size());

      for (size_t j=0; j<totalObstacles.size(); ++j) {
          const auto& totalObstaclesj = totalObstacles[j];
          auto& lowerj = lower[j];
          auto& upperj = upper[j];
        for (size_t d=0; d<dims; ++d) {
            lowerj[d] = totalObstaclesj[d][0];
            upperj[d] = totalObstaclesj[d][1];
        }
      }

      // print problem
     /* print(bilinearForm, "bilinearForm");
      print(totalRhs, "totalRhs");
      print(_dirichletNodes, "ignore");
      print(totalObstacles, "totalObstacles");
      print(lower, "lower");
      print(upper, "upper");*/

      // set up functional
      using Functional = Functional<Matrix&, Vector&, ZeroNonlinearity&, Vector&, Vector&, typename Matrix::field_type>;
      Functional J(bilinearForm, totalRhs, ZeroNonlinearity(), lower, upper); //TODO

      // set up TNMMG solver
      using Factory = SolverFactory<Functional, BitVector>;
      Factory factory(parset.sub("solver.tnnmg"), J, linearSolver, _dirichletNodes);

     /* std::vector<BitVector> bodyDirichletNodes;
      nBodyAssembler.postprocess(_dirichletNodes, bodyDirichletNodes);
      for (size_t i=0; i<bodyDirichletNodes.size(); i++) {
        print(bodyDirichletNodes[i], "bodyDirichletNodes_" + std::to_string(i) + ": ");
      }*/

     /* print(bilinearForm, "matrix: ");
      print(totalX, "totalX: ");
      print(totalRhs, "totalRhs: ");*/

      auto tnnmgStep = factory.step();
      factory.setProblem(totalX);

      const EnergyNorm<Matrix, Vector> norm(bilinearForm);

      LoopSolver<Vector> solver(
          tnnmgStep.get(), _localParset.get<size_t>("maximumIterations"),
          _localParset.get<double>("tolerance"), &norm,
          _localParset.get<Solver::VerbosityMode>("verbosity"),
          false); // absolute error

      solver.preprocess();
      solver.solve();

      nBodyAssembler.postprocess(tnnmgStep->getSol(), _x);

      Vector res = totalRhs;
      bilinearForm.mmv(tnnmgStep->getSol(), res);
      std::cout << "TNNMG Res - energy norm: " << norm.operator ()(res) << std::endl;

      // TODO: remove after debugging
  /*    using DeformedGridType = typename LevelContactNetwork<GridType, dims>::DeformedGridType;
      using OldLinearFactory = SolverFactoryOld<DeformedGridType, GlobalFriction<Matrix, Vector>, Matrix, Vector>;
      OldLinearFactory oldFactory(parset.sub("solver.tnnmg"), nBodyAssembler, _dirichletNodes);

      auto oldStep = oldFactory.getStep();
      oldStep->setProblem(bilinearForm, oldTotalX, oldTotalRhs);

          LoopSolver<Vector> oldSolver(
              oldStep.get(), _localParset.get<size_t>("maximumIterations"),
              _localParset.get<double>("tolerance"), &norm,
              _localParset.get<Solver::VerbosityMode>("verbosity"),
              false); // absolute error

          oldSolver.preprocess();
          oldSolver.solve();

      Vector oldRes = totalRhs;
      bilinearForm.mmv(oldStep->getSol(), oldRes);
      std::cout << "Old Res - energy norm: " << norm.operator ()(oldRes) << std::endl;*/

   //   print(tnnmgStep->getSol(), "TNNMG Solution: ");
   /*   print(oldStep->getSol(), "Old Solution: ");
      auto diff = tnnmgStep->getSol();
      diff -= oldStep->getSol();
      std::cout << "Energy norm: " << norm.operator ()(diff) << std::endl;*/
    };

    timeStep = parset.get<size_t>("initialTime.timeStep");
    relativeTime = parset.get<double>("initialTime.relativeTime");
    relativeTau = parset.get<double>("initialTime.relativeTau");

    std::vector<Vector> ell0(bodyCount_);
    for (size_t i=0; i<bodyCount_; i++) {
      // Initial velocity
      u[i] = 0.0;
      v[i] = 0.0;

      ell0[i].resize(u[i].size());
      ell0[i] = 0.0;

      contactNetwork.body(i)->externalForce()(relativeTime, ell0[i]);
    }

    // Initial displacement: Start from a situation of minimal stress,
    // which is automatically attained in the case [v = 0 = a].
    // Assuming dPhi(v = 0) = 0, we thus only have to solve Au = ell0
    BitVector dirichletNodes;
    contactNetwork.totalNodes("dirichlet", dirichletNodes);
    /*for (size_t i=0; i<dirichletNodes.size(); i++) {
        bool val = false;
        for (size_t d=0; d<dims; d++) {
            val = val || dirichletNodes[i][d];
        }

        dirichletNodes[i] = val;
        for (size_t d=0; d<dims; d++) {
            dirichletNodes[i][d] = val;
        }
    }*/

    solveLinearProblem(dirichletNodes, contactNetwork.matrices().elasticity, ell0, u,
                       parset.sub("u0.solver"));

    print(u, "initial u:");

    // Initial acceleration: Computed in agreement with Ma = ell0 - Au
    // (without Dirichlet constraints), again assuming dPhi(v = 0) = 0
    std::vector<Vector> accelerationRHS = ell0;
    for (size_t i=0; i<bodyCount_; i++) {
      // Initial state
      alpha[i] = parset.get<double>("boundary.friction.initialAlpha");

      // Initial normal stress

      const auto& body = contactNetwork.body(i);
      std::vector<std::shared_ptr<typename ContactNetwork::LeafBody::BoundaryCondition>> frictionBoundaryConditions;
      body->boundaryConditions("friction", frictionBoundaryConditions);
      for (size_t j=0; j<frictionBoundaryConditions.size(); j++) {
          ScalarVector frictionBoundaryStress(weightedNormalStress[i].size());

          body->assembler()->assembleWeightedNormalStress(
            *frictionBoundaryConditions[j]->boundaryPatch(), frictionBoundaryStress, body->data()->getYoungModulus(),
            body->data()->getPoissonRatio(), u[i]);

          weightedNormalStress[i] += frictionBoundaryStress;
      }

      Dune::MatrixVector::subtractProduct(accelerationRHS[i], *body->matrices().elasticity, u[i]);
    }

    BitVector noNodes(dirichletNodes.size(), false);
    solveLinearProblem(noNodes, contactNetwork.matrices().mass, accelerationRHS, a,
                       parset.sub("a0.solver"));

    print(a, "initial a:");
  }

private:
  const size_t bodyCount_;

public:
  std::vector<BodyState* > bodyStates;
  std::vector<Vector> u;
  std::vector<Vector> v;
  std::vector<Vector> a;
  std::vector<ScalarVector> alpha;
  std::vector<ScalarVector> weightedNormalStress;
  double relativeTime;
  double relativeTau;
  size_t timeStep;


};
#endif