#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include <dune/common/exceptions.hh>

#include <dune/matrix-vector/axpy.hh>

#include <dune/solvers/norms/energynorm.hh>
#include <dune/solvers/solvers/loopsolver.hh>

#include <dune/contact/assemblers/nbodyassembler.hh>
#include <dune/contact/common/dualbasisadapter.hh>

#include <dune/localfunctions/lagrange/pqkfactory.hh>

#include <dune/functions/gridfunctions/gridfunction.hh>

#include <dune/geometry/quadraturerules.hh>
#include <dune/geometry/type.hh>
#include <dune/geometry/referenceelements.hh>

#include <dune/fufem/functions/basisgridfunction.hh>

#include "../enums.hh"
#include "../enumparser.hh"

#include "fixedpointiterator.hh"

void FixedPointIterationCounter::operator+=(
    FixedPointIterationCounter const &other) {
  // Accumulate both counters component-wise. The two fields are updated
  // independently, so the result is the same even when &other == this.
  multigridIterations = multigridIterations + other.multigridIterations;
  iterations = iterations + other.iterations;
}

// Construct the fixed-point iterator from a factory and a parameter tree.
//
// factory        -- stored as factory_; step_ is obtained from it via
//                   getStep() (run() also asks it for the N-body assembler).
// parset         -- stored as parset_; the following keys are read here:
//                     v.fpi.maximumIterations    (size_t) max fixed-point iterations
//                     v.fpi.tolerance            (double) fixed-point stopping tolerance
//                     v.fpi.lambda               (double) relaxation weight of the new velocity
//                     v.solver.maximumIterations (size_t) max iterations of the velocity solver
//                     v.solver.tolerance         (double) velocity solver tolerance
//                     v.solver.verbosity         (Solver::VerbosityMode)
// globalFriction -- per-body friction nonlinearities, stored as globalFriction_.
// errorNorm      -- stored as errorNorm_; presumably intended for the
//                   fixed-point convergence test on the state -- confirm.
template <class Factory, class Updaters, class ErrorNorm>
FixedPointIterator<Factory, Updaters, ErrorNorm>::FixedPointIterator(
    Factory &factory, Dune::ParameterTree const &parset,
    std::vector<std::shared_ptr<Nonlinearity>>& globalFriction, const ErrorNorm& errorNorm)
    : factory_(factory),
      // NOTE: initialized from the member factory_, so factory_ must be
      // declared before step_ in the class definition.
      step_(factory_.getStep()),
      parset_(parset),
      globalFriction_(globalFriction),
      fixedPointMaxIterations_(parset.get<size_t>("v.fpi.maximumIterations")),
      fixedPointTolerance_(parset.get<double>("v.fpi.tolerance")),
      lambda_(parset.get<double>("v.fpi.lambda")),
      velocityMaxIterations_(parset.get<size_t>("v.solver.maximumIterations")),
      velocityTolerance_(parset.get<double>("v.solver.tolerance")),
      verbosity_(parset.get<Solver::VerbosityMode>("v.solver.verbosity")),
      errorNorm_(errorNorm) {}

// Run the fixed-point iteration between the global (multi-body contact)
// velocity problem and the per-body state problems.
//
// Each iteration:
//   1. pushes the current state alpha into the friction nonlinearities,
//   2. solves the assembled global velocity problem,
//   3. maps the solution back to per-body nodal velocities,
//   4. relaxes: v_m = (1 - lambda) * v_old + lambda * v_new,
//   5. solves the state problems with v_m and compares the new state with
//      the previous one.
// The iteration stops when lambda is (numerically) zero or the state change
// is below v.fpi.tolerance for every body.
//
// velocityMatrices / velocityRHSs -- per-body velocity operators/right-hand
//                                    sides, assembled into one global problem.
// velocityIterates (in/out)       -- per-body initial guess on entry; the
//                                    final velocities (after rate postProcess)
//                                    on return.
// Returns the number of fixed-point iterations performed and the accumulated
// number of velocity-solver iterations.
// Throws Dune::Exception if no convergence is reached within
// v.fpi.maximumIterations iterations.
template <class Factory, class Updaters, class ErrorNorm>
FixedPointIterationCounter
FixedPointIterator<Factory, Updaters, ErrorNorm>::run(
    Updaters updaters, const std::vector<const Matrix*>& velocityMatrices, const std::vector<Vector>& velocityRHSs,
    std::vector<Vector>& velocityIterates) {

  // assemble full global contact problem
  auto contactAssembler = factory_.getNBodyAssembler();

  Matrix bilinearForm;
  contactAssembler.assembleJacobian(velocityMatrices, bilinearForm);

  Vector totalRhs;
  contactAssembler.assembleRightHandSide(velocityRHSs, totalRhs);

  Vector totalVelocityIterate;
  contactAssembler.nodalToTransformed(velocityIterates, totalVelocityIterate);

  // The linear solver iterates on the assembled global (transformed) problem,
  // so its error has to be measured in the energy norm of that same operator
  // (a single body's matrix does not even match the size of the iterate).
  // bilinearForm is declared before energyNorm and thus outlives it.
  EnergyNorm<Matrix, Vector> energyNorm(bilinearForm);
  LoopSolver<Vector> velocityProblemSolver(step_.get(), velocityMaxIterations_,
                                           velocityTolerance_, &energyNorm,
                                           verbosity_, false); // absolute error

  // TODO: contribution from nonlinearity
  // (add gradient to rhs, hessian to matrix?)

  std::vector<ScalarVector> alpha;
  updaters.state_->extractAlpha(alpha);

  size_t fixedPointIteration = 0;
  size_t multigridIterations = 0;
  // Explicit flag: convergence in the very last allowed iteration must not be
  // mistaken for running out of iterations.
  bool converged = false;
  while (fixedPointIteration < fixedPointMaxIterations_) {
    // solve a velocity problem with the current state
    for (size_t i = 0; i < alpha.size(); i++)
      globalFriction_[i]->updateAlpha(alpha[i]);

    step_->setProblem(bilinearForm, totalVelocityIterate, totalRhs);

    velocityProblemSolver.preprocess();
    velocityProblemSolver.solve();

    multigridIterations += velocityProblemSolver.getResult().iterations;

    // Transform the global solution back to per-body nodal velocities;
    // without this the solve would never reach velocityIterates.
    contactAssembler.postprocess(totalVelocityIterate, velocityIterates);

    // relaxation: v_m = (1 - lambda) * v_old + lambda * v_new
    std::vector<Vector> v_m;
    updaters.rate_->extractOldVelocity(v_m);
    for (size_t i = 0; i < v_m.size(); i++) {
      v_m[i] *= 1.0 - lambda_;
      Dune::MatrixVector::addProduct(v_m[i], lambda_, velocityIterates[i]);
    }

    // compute relative velocities on contact boundaries
    relativeVelocities(v_m);

    // solve a state problem
    updaters.state_->solve(v_m);
    std::vector<ScalarVector> newAlpha;
    updaters.state_->extractAlpha(newAlpha);

    ++fixedPointIteration;

    // The state has converged if it changed by less than the tolerance on
    // every body.
    // NOTE(review): the single errorNorm_ is applied to every body's state;
    // confirm it is valid for all bodies (or keep one norm per body).
    bool alphaConverged = (newAlpha.size() == alpha.size());
    for (size_t i = 0; alphaConverged && i < alpha.size(); i++)
      alphaConverged =
          errorNorm_.diff(alpha[i], newAlpha[i]) < fixedPointTolerance_;

    if (lambda_ < 1e-12 or alphaConverged) {
      converged = true;
      break;
    }
    alpha.swap(newAlpha);
  }
  if (!converged)
    DUNE_THROW(Dune::Exception, "FPI failed to converge");

  updaters.rate_->postProcess(velocityIterates);

  // Cannot use return { fixedPointIteration, multigridIterations };
  // with gcc 4.9.2, see also http://stackoverflow.com/a/37777814/179927
  FixedPointIterationCounter ret;
  ret.iterations = fixedPointIteration;
  ret.multigridIterations = multigridIterations;
  return ret;
}

// Print a counter as "(<fixed-point iterations>,<multigrid iterations>)".
std::ostream &operator<<(std::ostream &stream,
                         FixedPointIterationCounter const &fpic) {
  stream << "(" << fpic.iterations;
  stream << "," << fpic.multigridIterations;
  stream << ")";
  return stream;
}

// Intended to compute relative velocities on the contact boundaries from the
// per-body velocities v_m (see the call site in run()).
//
// NOTE(review): currently a no-op stub -- everything inside the body is a
// commented-out draft (adapted from DualMortarCoupling::setup()), so v_m is
// neither read nor modified. The draft is incomplete (e.g. the grid indices
// below are left unassigned) and references names that are not members of
// this class (nBodyAssembler_, nonmortarBoundary_, mortarBoundary_, ...).
// TODO: implement, or remove the call from run().
template <class Factory, class Updaters, class ErrorNorm>
void FixedPointIterator<Factory, Updaters, ErrorNorm>::relativeVelocities(std::vector<Vector>& v_m) const {
  // needs assemblers to obtain basis
    /*
  std::vector<std::shared_ptr<MyAssembler>> assemblers(bodyCount);

  using field_type = typename Factory::Matrix::field_type;


  // adaptation of DualMortarCoupling::setup()

  const size_t dim = DeformedGrid::dimension;
  typedef typename DeformedGrid::LeafGridView GridView;

  //cache of local bases
  typedef Dune::PQkLocalFiniteElementCache<typename DeformedGrid::ctype, field_type, dim,1> FiniteElementCache1;
  FiniteElementCache1 cache1;

  // cache for the dual functions on the boundary
  using DualCache = Dune::Contact::DualBasisAdapter<GridView, field_type>;
  std::unique_ptr<DualCache> dualCache;
  dualCache = std::make_unique< Dune::Contact::DualBasisAdapterGlobal<GridView, field_type> >();

  // define FE grid functions
  std::vector<BasisGridFunction<VertexBasis, Vector>* > gridFunctions(v_m.size());
  for (size_t i=0; i<gridFunctions.size(); i++) {
    gridFunctions[i] = new BasisGridFunction<MyAssembler::VertexBasis, Vector>(assemblers[i]->vertexBasis, v_m[i]);
  }
 */

  /*
  for (size_t i=0; i<nBodyAssembler_.nCouplings(); i++) {
    const auto& coupling = nBodyAssembler_.getCoupling(i);
    auto glue = coupling.backend();

    const std::array<int, 2> gridIdx = coupling.gridIdx_;
    const int nonmortarGridIdx = ;
    const int mortarGridIdx = ;

    // loop over all intersections
    for (const auto& rIs : intersections(glue)) {
        const auto& inside = rIs.inside();

        if (!nonmortarBoundary_.contains(rIs.inside(),rIs.indexInInside()))
            continue;

        const auto& outside = rIs.outside();

        // types of the elements supporting the boundary segments in question
        Dune::GeometryType nonmortarEType = inside.type();
        Dune::GeometryType mortarEType    = outside.type();

        const auto& domainRefElement = Dune::ReferenceElements<ctype, dim>::general(nonmortarEType);
        const auto& targetRefElement = Dune::ReferenceElements<ctype, dim>::general(mortarEType);

        int noOfMortarVec = targetRefElement.size(dim);

        Dune::GeometryType nmFaceType = domainRefElement.type(rIs.indexInInside(),1);
        Dune::GeometryType mFaceType  = targetRefElement.type(rIs.indexInOutside(),1);

        // Select a quadrature rule
        // 2 in 2d and for integration over triangles in 3d.  If one (or both) of the two faces involved
        // are quadrilaterals, then the quad order has to be risen to 3 (4).
        int quadOrder = 2 + (!nmFaceType.isSimplex()) + (!mFaceType.isSimplex());
        const auto& quadRule = Dune::QuadratureRules<ctype, dim-1>::rule(rIs.type(), quadOrder);

        const auto& mortarFiniteElement = cache1.get(mortarEType);
        dualCache->bind(inside, rIs.indexInInside());

        std::vector<Dune::FieldVector<field_type,1> > mortarQuadValues, dualQuadValues;

        const auto& rGeom = rIs.geometry();
        const auto& rGeomOutside = rIs.geometryOutside();
        const auto& rGeomInInside = rIs.geometryInInside();
        const auto& rGeomInOutside = rIs.geometryInOutside();

        int nNonmortarFaceNodes = domainRefElement.size(rIs.indexInInside(),1,dim);
        std::vector<int> nonmortarFaceNodes;
        for (int i=0; i<nNonmortarFaceNodes; i++) {
          int faceIdxi = domainRefElement.subEntity(rIs.indexInInside(), 1, i, dim);
          nonmortarFaceNodes.push_back(faceIdxi);
        }

        for (const auto& quadPt : quadRule) {

            // compute integration element of overlap
            ctype integrationElement = rGeom.integrationElement(quadPt.position());

            // quadrature point positions on the reference element
            Dune::FieldVector<ctype,dim> nonmortarQuadPos = rGeomInInside.global(quadPt.position());
            Dune::FieldVector<ctype,dim> mortarQuadPos    = rGeomInOutside.global(quadPt.position());

            // The current quadrature point in world coordinates
            Dune::FieldVector<field_type,dim> nonmortarQpWorld = rGeom.global(quadPt.position());
            Dune::FieldVector<field_type,dim> mortarQpWorld    = rGeomOutside.global(quadPt.position());;

            // the gap direction (normal * gapValue)
            Dune::FieldVector<field_type,dim> gapVector = mortarQpWorld  - nonmortarQpWorld;

            //evaluate all shapefunctions at the quadrature point
            //nonmortarFiniteElement.localBasis().evaluateFunction(nonmortarQuadPos,nonmortarQuadValues);
            mortarFiniteElement.localBasis().evaluateFunction(mortarQuadPos,mortarQuadValues);
            dualCache->evaluateFunction(nonmortarQuadPos,dualQuadValues);

            // loop over all Lagrange multiplier shape functions
            for (int j=0; j<nNonmortarFaceNodes; j++) {

                int globalDomainIdx = indexSet0.subIndex(inside,nonmortarFaceNodes[j],dim);
                int rowIdx = globalToLocal[globalDomainIdx];

                weakObstacle_[rowIdx][0] += integrationElement * quadPt.weight()
                    * dualQuadValues[nonmortarFaceNodes[j]] * (gapVector*avNormals[globalDomainIdx]);

                // loop over all mortar shape functions
                for (int k=0; k<noOfMortarVec; k++) {

                    int colIdx  = indexSet1.subIndex(outside, k, dim);
                    if (!mortarBoundary_.containsVertex(colIdx))
                        continue;

                    // Integrate over the product of two shape functions
                    field_type mortarEntry =  integrationElement* quadPt.weight()* dualQuadValues[nonmortarFaceNodes[j]]* mortarQuadValues[k];

                    Dune::MatrixVector::addToDiagonal(mortarLagrangeMatrix_[rowIdx][colIdx], mortarEntry);

                }

            }

        }






  }


      // Create mapping from the global set of block dofs to the ones on the contact boundary
      std::vector<int> globalToLocal;
      nonmortarBoundary_.makeGlobalToLocal(globalToLocal);

      // loop over all intersections
      for (const auto& rIs : intersections(glue)) {

          if (!nonmortarBoundary_.contains(rIs.inside(),rIs.indexInInside()))
              continue;

          const auto& inside = rIs.inside();
          const auto& outside = rIs.outside();

          const auto& domainRefElement = Dune::ReferenceElements<ctype, dim>::general(inside.type());
          const auto& targetRefElement = Dune::ReferenceElements<ctype, dim>::general(outside.type());

          int nDomainVertices = domainRefElement.size(dim);
          int nTargetVertices = targetRefElement.size(dim);

          for (int j=0; j<nDomainVertices; j++) {

              int localDomainIdx = globalToLocal[indexSet0.subIndex(inside,j,dim)];

              // if the vertex is not contained in the restricted contact boundary then dismiss it
              if (localDomainIdx == -1)
                  continue;

              for (int k=0; k<nTargetVertices; k++) {
                  int globalTargetIdx = indexSet1.subIndex(outside,k,dim);
                  if (!mortarBoundary_.containsVertex(globalTargetIdx))
                      continue;

                  mortarIndices.add(localDomainIdx, globalTargetIdx);
              }
          }
      }
    */

}

#include "fixedpointiterator_tmpl.cc"