#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#ifdef HAVE_IPOPT
#undef HAVE_IPOPT
#endif

#include <dune/common/bitsetvector.hh>

#include <dune/fufem/assemblers/transferoperatorassembler.hh>
#include <dune/solvers/solvers/solver.hh>

#include "solverfactory.hh"

template <class DeformedGridTEMPLATE, class MatrixType, class VectorType>
/**
 * Configure the contact multigrid iteration step for the given n-body problem.
 *
 * @param parset          Parameter tree. Reads "linear.maxiumumIterations" and
 *                        "linear.tolerance" (base solver), and "main.multi",
 *                        "main.pre", "main.post" (passed to setMGType()).
 * @param nBodyAssembler  Source of the grids, contact couplings, obstacles and
 *                        local coordinate systems used to build the step and
 *                        its transfer operators.
 * @param ignoreNodes     Node flags forwarded to the multigrid step via
 *                        setIgnore().
 */
SolverFactory<DeformedGridTEMPLATE, MatrixType, VectorType>::SolverFactory(
    Dune::ParameterTree const &parset, const Dune::Contact::NBodyAssembler<DeformedGrid, Vector>& nBodyAssembler,
    Dune::BitSetVector<dim> const &ignoreNodes)
    : nBodyAssembler_(nBodyAssembler),
      baseEnergyNorm_(baseSolverStep_),
      // NOTE(review): the key is spelled "maxiumumIterations"; it is kept
      // as-is so existing parameter files continue to be read correctly.
      baseSolver_(&baseSolverStep_,
                  parset.get<size_t>("linear.maxiumumIterations"),
                  parset.get<double>("linear.tolerance"), &baseEnergyNorm_,
                  Solver::QUIET),
      transferOperators_(nBodyAssembler.getGrids().at(0)->maxLevel()) {

  // Configure the (TNNMG) multigrid iteration step.
  // NOTE(review): multigridStep_ is not initialized in the initializer list
  // above; presumably the header gives it an in-class initializer -- confirm,
  // otherwise the calls below dereference a null pointer.
  multigridStep_->setMGType(parset.get<int>("main.multi"),
                            parset.get<int>("main.pre"),
                            parset.get<int>("main.post"));
  multigridStep_->setIgnore(ignoreNodes);
  multigridStep_->setBaseSolver(baseSolver_);
  multigridStep_->setSmoother(&presmoother_, &postsmoother_);
  multigridStep_->setHasObstacle(nBodyAssembler_.totalHasObstacle_);
  multigridStep_->setObstacles(nBodyAssembler_.totalObstacles_);

  // Build one contact transfer operator per pair of consecutive levels
  // (i, i+1); transferOperators_ was sized to the maxLevel() of grid 0.
  const int nCouplings = nBodyAssembler_.nCouplings();
  // Bind by reference: avoids copying the grid container when getGrids()
  // returns a reference, and is still valid (lifetime-extended) if it
  // returns by value.
  const auto &grids = nBodyAssembler_.getGrids();

  // The operators are owned through raw pointers that only the destructor
  // releases, and a destructor does not run when a constructor throws.
  // Free whatever has been allocated so far before propagating the error.
  // (Unfilled slots were value-initialized, so deleting them is a no-op.)
  try {
    for (size_t i = 0; i < transferOperators_.size(); i++) {
      std::vector<const Dune::BitSetVector<1>*> coarseHasObstacle(nCouplings);
      std::vector<const Dune::BitSetVector<1>*> fineHasObstacle(nCouplings);

      std::vector<const Matrix*> mortarTransfer(nCouplings);
      std::vector<std::array<int, 2> > gridIdx(nCouplings);

      // Gather, per coupling, the non-mortar boundary vertices on the
      // coarse (i) and fine (i+1) level, the level-i mortar matrix, and the
      // indices of the two grids involved.
      for (int j = 0; j < nCouplings; j++) {
        coarseHasObstacle[j] = nBodyAssembler_.nonmortarBoundary_[j][i].getVertices();
        fineHasObstacle[j]   = nBodyAssembler_.nonmortarBoundary_[j][i+1].getVertices();

        mortarTransfer[j] = &nBodyAssembler_.contactCoupling_[j].mortarLagrangeMatrix(i);
        gridIdx[j]        = nBodyAssembler_.coupling_[j].gridIdx_;
      }

      transferOperators_[i] = new Dune::Contact::ContactMGTransfer<Vector>;
      transferOperators_[i]->setup(grids, i, i+1,
                                   nBodyAssembler_.localCoordSystems_[i],
                                   nBodyAssembler_.localCoordSystems_[i+1],
                                   coarseHasObstacle, fineHasObstacle,
                                   mortarTransfer,
                                   gridIdx);
    }
  } catch (...) {
    for (auto &&op : transferOperators_) {
      delete op;
      op = nullptr;
    }
    throw;
  }

  multigridStep_->setTransferOperators(transferOperators_);
}

template <class DeformedGridTEMPLATE, class MatrixType, class VectorType>
/**
 * Release the contact transfer operators allocated in the constructor.
 *
 * NOTE(review): the class owns these operators through raw pointers. Unless
 * the header deletes or defines the copy operations, copying a SolverFactory
 * would lead to a double delete here.
 */
SolverFactory<DeformedGridTEMPLATE, MatrixType, VectorType>::~SolverFactory() {
  for (auto &&x : transferOperators_)
    delete x;
}

template <class DeformedGridTEMPLATE, class MatrixType, class VectorType>
/**
 * Access the configured multigrid iteration step.
 *
 * @return A shared handle to the step set up in the constructor. The caller
 *         shares ownership; the step's transfer operators remain owned by
 *         this factory and are deleted in its destructor.
 */
auto SolverFactory<DeformedGridTEMPLATE, MatrixType, VectorType>::getStep()
    -> std::shared_ptr<Step> {
  return multigridStep_;
}

#include "solverfactory_tmpl.cc"