#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include <dune/common/exceptions.hh>

#include <dune/matrix-vector/axpy.hh>

#include <dune/solvers/norms/energynorm.hh>
#include <dune/solvers/solvers/loopsolver.hh>

#include <dune/contact/assemblers/nbodyassembler.hh>
#include <dune/contact/common/dualbasisadapter.hh>

#include <dune/localfunctions/lagrange/pqkfactory.hh>

#include <dune/functions/gridfunctions/gridfunction.hh>

#include <dune/geometry/quadraturerules.hh>
#include <dune/geometry/type.hh>
#include <dune/geometry/referenceelements.hh>

#include <dune/fufem/functions/basisgridfunction.hh>

#include "../data-structures/enums.hh"
#include "../data-structures/enumparser.hh"

#include "fixedpointiterator.hh"

#include "../utils/tobool.hh"
#include "../utils/debugutils.hh"

#include <dune/solvers/solvers/loopsolver.hh>
#include <dune/solvers/iterationsteps/truncatedblockgsstep.hh>
#include <dune/solvers/iterationsteps/multigridstep.hh>

#include "../spatial-solving/preconditioners/nbodycontacttransfer.hh"

#include "tnnmg/functional.hh"
#include "tnnmg/zerononlinearity.hh"
#include "solverfactory.hh"

/// Adds the counts stored in `other` onto this counter, field by field.
void FixedPointIterationCounter::operator+=(
    FixedPointIterationCounter const &other) {
  // The two fields are independent of each other, so order is irrelevant.
  multigridIterations = multigridIterations + other.multigridIterations;
  iterations = iterations + other.iterations;
}

/**
 * Sets up a fixed-point iterator from a parameter tree and the problem data.
 *
 * All arguments are stored in the correspondingly named members. The
 * iteration controls are read once, here, from `parset`:
 *   - "v.fpi.maximumIterations", "v.fpi.tolerance", "v.fpi.lambda"
 *     for the outer fixed-point loop,
 *   - "v.solver.maximumIterations", "v.solver.tolerance",
 *     "v.solver.verbosity" for the inner velocity solver.
 *
 * \param parset                      parameter tree holding the keys above
 * \param contactNetwork              contact network queried by run()
 * \param ignoreNodes                 ignore bit field handed to the solver factory
 * \param globalFriction              friction object whose alpha is updated in run()
 * \param bodywiseNonmortarBoundaries one bit vector pointer per body, dereferenced
 *                                    in relativeVelocities()
 * \param errorNorms                  per-body norms used for the alpha stopping test
 */
template <class Factory, class ContactNetwork, class Updaters, class ErrorNorms>
FixedPointIterator<Factory, ContactNetwork, Updaters, ErrorNorms>::FixedPointIterator(
    Dune::ParameterTree const &parset,
    const ContactNetwork& contactNetwork,
    const IgnoreVector& ignoreNodes,
    GlobalFriction& globalFriction,
    const std::vector<const BitVector*>& bodywiseNonmortarBoundaries,
    const ErrorNorms& errorNorms)
    : parset_(parset),
      contactNetwork_(contactNetwork),
      ignoreNodes_(ignoreNodes),
      globalFriction_(globalFriction),
      bodywiseNonmortarBoundaries_(bodywiseNonmortarBoundaries),
      // outer fixed-point loop controls
      fixedPointMaxIterations_(parset.get<size_t>("v.fpi.maximumIterations")),
      fixedPointTolerance_(parset.get<double>("v.fpi.tolerance")),
      lambda_(parset.get<double>("v.fpi.lambda")),
      // inner velocity solver controls
      velocityMaxIterations_(parset.get<size_t>("v.solver.maximumIterations")),
      velocityTolerance_(parset.get<double>("v.solver.tolerance")),
      verbosity_(parset.get<Solver::VerbosityMode>("v.solver.verbosity")),
      errorNorms_(errorNorms) {}

/**
 * Runs the fixed-point iteration coupling the velocity and the state problem.
 *
 * Outline:
 *   1. Assemble the global Jacobian and right-hand side through the n-body
 *      assembler and split its total obstacles into lower/upper bounds.
 *   2. Turn those bounds into velocity bounds via
 *      `updaters.rate_->velocityObstacles`.
 *   3. Build a multigrid-based linear solver, the functional `J` and the
 *      nonlinear step obtained from the solver factory. Note that the
 *      functional is currently built with `ZeroNonlinearity`; the variant
 *      using `globalFriction_` / `Factory` is commented out below.
 *   4. Loop: update alpha in `globalFriction_`, solve the velocity problem,
 *      write the result back into `velocityIterates`, solve the state
 *      problem on the relative velocities and compare old and new alpha
 *      with the per-body error norms.
 *
 * The loop terminates successfully as soon as every body with a non-empty
 * alpha differs by less than `fixedPointTolerance_`, or unconditionally
 * after the first pass when `lambda_ < 1e-12`.
 *
 * \param updaters          object providing `rate_` and `state_` updaters
 * \param velocityMatrices  per-body matrices fed to `assembleJacobian`
 * \param velocityRHSs      per-body right-hand sides
 * \param velocityIterates  in: initial iterates; out: post-processed solution
 * \return number of fixed-point iterations performed and the accumulated
 *         iteration count of the inner velocity solver
 * \throws Dune::Exception if no pass of the loop met the stopping criterion
 *         within `fixedPointMaxIterations_` iterations
 */
template <class Factory, class ContactNetwork, class Updaters, class ErrorNorms>
FixedPointIterationCounter
FixedPointIterator<Factory, ContactNetwork, Updaters, ErrorNorms>::run(
    Updaters updaters,
    const std::vector<Matrix>& velocityMatrices, const std::vector<Vector>& velocityRHSs,
    std::vector<Vector>& velocityIterates) {

  std::cout << "FixedPointIterator::run()" << std::endl;

  const auto& nBodyAssembler = contactNetwork_.nBodyAssembler();

  // debugging
  const auto& contactCouplings = nBodyAssembler.getContactCouplings();
  for (size_t i=0; i<contactCouplings.size(); i++) {
    print(*contactCouplings[i]->nonmortarBoundary().getVertices(), "nonmortarBoundaries:");
  }

  // size_t (as in relativeVelocities()) so that all index loops below compare
  // like-signed values.
  const size_t nBodies = nBodyAssembler.nGrids();

  std::vector<const Matrix*> matrices_ptr(nBodies);
  for (size_t i=0; i<nBodies; i++) {
      matrices_ptr[i] = &velocityMatrices[i];
  }

  // assemble full global contact problem
  Matrix bilinearForm;
  nBodyAssembler.assembleJacobian(matrices_ptr, bilinearForm);

  print(bilinearForm, "bilinearForm:");

  Vector totalRhs;
  nBodyAssembler.assembleRightHandSide(velocityRHSs, totalRhs);

  print(totalRhs, "totalRhs:");

  // get lower and upper obstacles: component [d][0] is the lower bound,
  // component [d][1] the upper bound
  const auto& totalObstacles = nBodyAssembler.totalObstacles_;
  Vector lower(totalObstacles.size());
  Vector upper(totalObstacles.size());

  for (size_t j=0; j<totalObstacles.size(); ++j) {
    const auto& totalObstaclesj = totalObstacles[j];
    auto& lowerj = lower[j];
    auto& upperj = upper[j];
    for (size_t d=0; d<dims; ++d) {
      lowerj[d] = totalObstaclesj[d][0];
      upperj[d] = totalObstaclesj[d][1];
    }
  }

  print(totalObstacles, "totalObstacles:");

  print(lower, "lower obstacles:");
  print(upper, "upper obstacles:");

  // compute velocity obstacles
  Vector vLower, vUpper;
  std::vector<Vector> u0, v0;
  updaters.rate_->extractOldVelocity(v0);
  updaters.rate_->extractOldDisplacement(u0);

  Vector totalu0, totalv0;
  nBodyAssembler.concatenateVectors(u0, totalu0);
  nBodyAssembler.concatenateVectors(v0, totalv0);

  updaters.rate_->velocityObstacles(totalu0, lower, totalv0, vLower);
  updaters.rate_->velocityObstacles(totalu0, upper, totalv0, vUpper);

  print(vLower, "vLower obstacles:");
  print(vUpper, "vUpper obstacles:");

  std::cout << "- Problem assembled: success" << std::endl;

  using LinearSolver = typename Dune::Solvers::LoopSolver<Vector, IgnoreVector>;
  using TransferOperator = NBodyContactTransfer<ContactNetwork, Vector>;
  using TransferOperators = std::vector<std::shared_ptr<TransferOperator>>;

  // one transfer operator between each pair of consecutive levels
  TransferOperators transfer(contactNetwork_.nLevels()-1);
  for (size_t i=0; i<transfer.size(); i++) {
      transfer[i] = std::make_shared<TransferOperator>();
      transfer[i]->setup(contactNetwork_, i, i+1);
  }

  // Remove any recompute field so that initially the full transfer operator is assembled
  // NOTE(review): the cast result is dereferenced unchecked; this presumes
  // TransferOperator derives from TruncatedMGTransfer<Vector> -- confirm.
  for (size_t i=0; i<transfer.size(); i++)
      std::dynamic_pointer_cast<TruncatedMGTransfer<Vector> >(transfer[i])->setRecomputeBitField(nullptr);

  auto smoother = TruncatedBlockGSStep<Matrix, Vector>{};
  auto linearMultigridStep = std::make_shared<Dune::Solvers::MultigridStep<Matrix, Vector> >();
  linearMultigridStep->setMGType(1, 3, 3);
  linearMultigridStep->setSmoother(smoother);
  linearMultigridStep->setTransferOperators(transfer);

  EnergyNorm<Matrix, Vector> mgNorm(*linearMultigridStep);
  LinearSolver mgSolver(linearMultigridStep, parset_.get<size_t>("solver.tnnmg.linear.maximumIterations"), parset_.get<double>("solver.tnnmg.linear.tolerance"), mgNorm, Solver::QUIET);

  print(ignoreNodes_, "ignoreNodes:");

  // set up functional and TNNMG solver
  using ZeroSolverFactory = SolverFactory<Functional, IgnoreVector>;
  Functional J(bilinearForm, totalRhs, ZeroNonlinearity(), vLower, vUpper);
  ZeroSolverFactory solverFactory(parset_.sub("solver.tnnmg"), J, mgSolver, ignoreNodes_);
  /*Functional J(bilinearForm, totalRhs, globalFriction_, vLower, vUpper);
  Factory solverFactory(parset_.sub("solver.tnnmg"), J, mgSolver, ignoreNodes_);*/
  auto step = solverFactory.step();

  std::cout << "- Functional and TNNMG step setup: success" << std::endl;

  EnergyNorm<Matrix, Vector> energyNorm(bilinearForm);
  LoopSolver<Vector> velocityProblemSolver(*step.get(), velocityMaxIterations_,
                                           velocityTolerance_, energyNorm,
                                           verbosity_, false); // absolute error

  size_t fixedPointIteration = 0;
  size_t multigridIterations = 0;
  // Set only when the stopping criterion was actually met. Comparing the
  // counter against fixedPointMaxIterations_ afterwards cannot tell
  // "converged on the last permitted iteration" from "ran out of iterations".
  bool converged = false;
  std::vector<ScalarVector> alpha(nBodies);
  updaters.state_->extractAlpha(alpha);
  for (; fixedPointIteration < fixedPointMaxIterations_;
       ++fixedPointIteration) {

    print(alpha, "alpha:");

    // contribution from nonlinearity
    globalFriction_.updateAlpha(alpha);

    Vector totalVelocityIterate;
    nBodyAssembler.nodalToTransformed(velocityIterates, totalVelocityIterate);

    //print(velocityIterates, "velocityIterates:");
    //print(totalVelocityIterate, "totalVelocityIterate:");
    std::cout << "- FixedPointIteration iterate" << std::endl;

    // solve a velocity problem
    solverFactory.setProblem(totalVelocityIterate);

    std::cout << "- Velocity problem set" << std::endl;

    velocityProblemSolver.preprocess();
    std::cout << "-- Preprocessed" << std::endl;
    velocityProblemSolver.solve();
    std::cout << "-- Solved" << std::endl;

    const auto& tnnmgSol = step->getSol();

    std::cout << "FixPointIterator: Energy of TNNMG solution: " << J(tnnmgSol) << std::endl;

    // write the solver result back into the per-body iterates
    nBodyAssembler.postprocess(tnnmgSol, velocityIterates);
    //nBodyAssembler.postprocess(totalVelocityIterate, velocityIterates);

    print(totalVelocityIterate, "totalVelocityIterate:");
    print(velocityIterates, "velocityIterates:");

    //DUNE_THROW(Dune::Exception, "Just need to stop here!");

    multigridIterations += velocityProblemSolver.getResult().iterations;

    // v_m = (1 - lambda) * v_old + lambda * v_new
    // NOTE(review): v_m is not consumed anywhere below (the state problem is
    // solved on v_rel instead) -- confirm whether this is intended.
    std::vector<Vector> v_m;
    updaters.rate_->extractOldVelocity(v_m);

    for (size_t i=0; i<v_m.size(); i++) {
      v_m[i] *= 1.0 - lambda_;
      Dune::MatrixVector::addProduct(v_m[i], lambda_, velocityIterates[i]);
    }

    // extract relative velocities in mortar basis
    std::vector<Vector> v_rel;
    relativeVelocities(tnnmgSol, v_rel);

    //print(v_rel, "v_rel");

    std::cout << "- State problem set" << std::endl;

    // solve a state problem
    updaters.state_->solve(v_rel);

    std::cout << "-- Solved" << std::endl;

    std::vector<ScalarVector> newAlpha(nBodies);
    updaters.state_->extractAlpha(newAlpha);

    // The criterion holds unless some body with non-empty alpha still differs
    // by at least the tolerance; bodies with an empty alpha are skipped.
    bool breakCriterion = true;
    for (size_t i=0; i<nBodies; i++) {
        if (alpha[i].size()==0 || newAlpha[i].size()==0)
            continue;

        print(alpha[i], "alpha i:");
        print(newAlpha[i], "new alpha i:");

        // evaluate the norm once and reuse it for the test and the log line
        const auto alphaDiff = errorNorms_[i]->diff(alpha[i], newAlpha[i]);
        if (alphaDiff >= fixedPointTolerance_) {
            breakCriterion = false;
            std::cout << "fixedPoint error: " << alphaDiff << std::endl;
            break;
        }
    }

    if (lambda_ < 1e-12 or breakCriterion) {
      std::cout << "-FixedPointIteration finished! " << (lambda_ < 1e-12 ? "lambda" : "breakCriterion") << std::endl;
      // count the pass that just completed
      fixedPointIteration++;
      converged = true;
      break;
    }
    alpha = newAlpha;
  }

  std::cout << "-FixedPointIteration finished! " << std::endl;

  if (!converged)
    DUNE_THROW(Dune::Exception, "FPI failed to converge");

  updaters.rate_->postProcess(velocityIterates);

  // Cannot use return { fixedPointIteration, multigridIterations };
  // with gcc 4.9.2, see also http://stackoverflow.com/a/37777814/179927
  FixedPointIterationCounter ret;
  ret.iterations = fixedPointIteration;
  ret.multigridIterations = multigridIterations;
  return ret;
}

/// Writes the counter as "(iterations,multigridIterations)" and returns the
/// stream to allow chaining.
std::ostream &operator<<(std::ostream &stream,
                         FixedPointIterationCounter const &fpic) {
  stream << "(" << fpic.iterations;
  stream << "," << fpic.multigridIterations;
  stream << ")";
  return stream;
}

/**
 * Splits the global vector `v` into one vector per body, keeping only the
 * entries flagged in the body's nonmortar boundary bit vector.
 *
 * `v` is traversed with a single running index that advances once per node
 * of every body, in body order. For each body, `v_rel[body]` is resized to
 * the size of that body's bit vector; flagged entries are copied from `v`,
 * all other entries are left exactly as `resize()` produced them.
 *
 * \param v      global vector indexed by the running node index
 * \param v_rel  output, resized to one vector per body
 */
template <class Factory, class ContactNetwork, class Updaters, class ErrorNorms>
void FixedPointIterator<Factory, ContactNetwork, Updaters, ErrorNorms>::relativeVelocities(
    const Vector& v,
    std::vector<Vector>& v_rel) const {

    const size_t bodyCount = contactNetwork_.nBodyAssembler().nGrids();
    v_rel.resize(bodyCount);

    // position in v; advanced for every node, flagged or not
    size_t offset = 0;
    for (size_t body = 0; body < bodyCount; ++body) {
        const auto& boundaryFlags = *bodywiseNonmortarBoundaries_[body];
        auto& bodyVelocities = v_rel[body];

        bodyVelocities.resize(boundaryFlags.size());
        const size_t nodeCount = bodyVelocities.size();

        for (size_t node = 0; node < nodeCount; ++node, ++offset) {
            if (not toBool(boundaryFlags[node]))
                continue;

            bodyVelocities[node] = v[offset];
        }
    }


   /* Unfinished alternative based on the contact couplings, kept for reference:

        boundaryNodes =

        const auto gridView = contactCouplings[couplingIndices[0]]->nonmortarBoundary().gridView();

        Dune::MultipleCodimMultipleGeomTypeMapper<
            decltype(gridView), Dune::MCMGVertexLayout> const vertexMapper(gridView, Dune::mcmgVertexLayout());

        for (auto it = gridView.template begin<block_size>(); it != gridView.template end<block_size>(); ++it) {
            const auto i = vertexMapper.index(*it);

            for (size_t j=0; j<couplingIndices.size(); j++) {
                const auto couplingIdx = couplingIndices[j];

                if (not contactCouplings[couplingIdx]->nonmortarBoundary().containsVertex(i))
                  continue;

                localToGlobal_.emplace_back(i);
                restrictions_.emplace_back(weights[bodyIdx][i], weightedNormalStress[bodyIdx][i],
                                          couplings[i]->frictionData()(geoToPoint(it->geometry())));
                break;
            }

            globalIdx++;
        }
        maxIndex_[bodyIdx] = globalIdx;
    }*/
}


#include "fixedpointiterator_tmpl.cc"