#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include <dune/istl/scaledidmatrix.hh>

#include <dune/fufem/assemblers/localassemblers/boundarymassassembler.hh>
#include <dune/fufem/assemblers/localassemblers/l2functionalassembler.hh>
#include <dune/fufem/assemblers/localassemblers/neumannboundaryassembler.hh>
#include <dune/fufem/assemblers/localassemblers/normalstressboundaryassembler.hh>
#include <dune/fufem/assemblers/localassemblers/stvenantkirchhoffassembler.hh>
#include <dune/fufem/assemblers/localassemblers/variablecoefficientviscosityassembler.hh>
#include <dune/fufem/assemblers/localassemblers/vonmisesstressassembler.hh>
#include <dune/fufem/assemblers/localassemblers/weightedmassassembler.hh>
#include <dune/fufem/functions/basisgridfunction.hh>
#include <dune/fufem/functions/constantfunction.hh>
#include <dune/fufem/functiontools/p0p1interpolation.hh>
#include <dune/fufem/quadraturerules/quadraturerulecache.hh>

#include <dune/tectonic/frictionpotential.hh>
#include <dune/tectonic/globalratestatefriction.hh>

#include "assemblers.hh"

template <class GridView, int dimension>
MyAssembler<GridView, dimension>::MyAssembler(GridView const &_gridView)
    // Build the cell-based and the vertex-based basis on the same grid view.
    : cellBasis(_gridView),
      vertexBasis(_gridView),
      gridView(_gridView),
      // Each global assembler uses one basis as both trial and test space.
      // They are initialised after the bases they refer to (NOTE(review):
      // this relies on the member declaration order in assemblers.hh).
      cellAssembler(cellBasis, cellBasis),
      vertexAssembler(vertexBasis, vertexBasis) {
  // All state is established in the member initializer list above.
}

template <class GridView, int dimension>
template <class LocalBoundaryFunctionalAssemblerType, class GlobalVectorType>
void MyAssembler<GridView, dimension>::assembleBoundaryFunctional(
    LocalBoundaryFunctionalAssemblerType &localAssembler, GlobalVectorType &b,
    BoundaryPatch<GridView> const &boundaryPatch,
    bool initializeVector) const {
  // Generic boundary functionals are assembled in the vertex basis.
  // All arguments, including the initializeVector flag, are forwarded as-is.
  auto const &assembler = vertexAssembler;
  assembler.assembleBoundaryFunctional(localAssembler, b, boundaryPatch,
                                       initializeVector);
}

template <class GridView, int dimension>
void MyAssembler<GridView, dimension>::assembleFrictionalBoundaryMass(
    BoundaryPatch<GridView> const &frictionalBoundary,
    ScalarMatrix &frictionalBoundaryMass) const {
  // Scalar (1x1-block) boundary mass matrix restricted to the frictional
  // boundary patch, assembled in the vertex basis.
  using LocalBoundaryMass =
      BoundaryMassAssembler<Grid, BoundaryPatch<GridView>, LocalVertexBasis,
                            LocalVertexBasis, Dune::FieldMatrix<double, 1, 1>>;

  LocalBoundaryMass const boundaryMass(frictionalBoundary);
  vertexAssembler.assembleOperator(boundaryMass, frictionalBoundaryMass);
}

template <class GridView, int dimension>
void MyAssembler<GridView, dimension>::assembleMass(
    Dune::VirtualFunction<LocalVector, LocalScalarVector> const &densityFunction,
    Matrix &M) const {
  using Density = Dune::VirtualFunction<LocalVector, LocalScalarVector>;
  using LocalBlock = Dune::ScaledIdentityMatrix<double, dimension>;

  // The density weight is treated as a constant function, hence the
  // quadrature key with second argument 0 (presumably the polynomial order of
  // the weight -- confirm against dune-fufem's QuadratureRuleKey).
  QuadratureRuleKey densityKey(dimension, 0);

  // Density-weighted mass matrix with scaled-identity blocks, assembled in the
  // vertex basis.
  WeightedMassAssembler<Grid, LocalVertexBasis, LocalVertexBasis, Density,
                        LocalBlock>
      weightedMass(gridView.grid(), densityFunction, densityKey);
  vertexAssembler.assembleOperator(weightedMass, M);
}

template <class GridView, int dimension>
void MyAssembler<GridView, dimension>::assembleElasticity(double E, double nu,
                                                          Matrix &A) const {
  // St. Venant-Kirchhoff stiffness matrix in the vertex basis. E and nu are
  // passed straight to the local assembler (presumably Young's modulus and
  // Poisson's ratio).
  using LocalStiffness =
      StVenantKirchhoffAssembler<Grid, LocalVertexBasis, LocalVertexBasis>;

  LocalStiffness const stVenantKirchhoff(E, nu);
  vertexAssembler.assembleOperator(stVenantKirchhoff, A);
}

template <class GridView, int dimension>
void MyAssembler<GridView, dimension>::assembleViscosity(
    Dune::VirtualFunction<LocalVector, LocalScalarVector> const &shearViscosity,
    Dune::VirtualFunction<LocalVector, LocalScalarVector> const &bulkViscosity,
    Matrix &C) const {
  using Coefficient = Dune::VirtualFunction<LocalVector, LocalScalarVector>;
  using LocalViscosity =
      VariableCoefficientViscosityAssembler<Grid, LocalVertexBasis,
                                            LocalVertexBasis, Coefficient>;

  // Both viscosity coefficients are treated as constant functions, so each
  // gets a quadrature key with second argument 0.
  QuadratureRuleKey shearKey(dimension, 0);
  QuadratureRuleKey bulkKey(dimension, 0);

  LocalViscosity const viscosity(gridView.grid(), shearViscosity,
                                 bulkViscosity, shearKey, bulkKey);
  vertexAssembler.assembleOperator(viscosity, C);
}

template <class GridView, int dimension>
void MyAssembler<GridView, dimension>::assembleBodyForce(
    Dune::VirtualFunction<LocalVector, LocalVector> const &gravityField,
    Vector &f) const {
  // L2 load functional of the vector-valued gravity field, assembled in the
  // vertex basis.
  using BodyForceAssembler =
      L2FunctionalAssembler<Grid, LocalVertexBasis, LocalVector>;

  BodyForceAssembler bodyForce(gravityField);
  vertexAssembler.assembleFunctional(bodyForce, f);
}

template <class GridView, int dimension>
void MyAssembler<GridView, dimension>::assembleNeumann(
    BoundaryPatch<GridView> const &neumannBoundary, Vector &f,
    Dune::VirtualFunction<double, double> const &neumann,
    double relativeTime) const {
  // Build a spatially constant boundary load. The vector is constructed from
  // 0 and only its first component is overwritten, with the value of
  // `neumann` at the given relative time.
  LocalVector boundaryLoad(0);
  neumann.evaluate(relativeTime, boundaryLoad[0]);

  using ConstantLoad = ConstantFunction<LocalVector, LocalVector>;
  NeumannBoundaryAssembler<Grid, LocalVector> neumannAssembler(
      std::make_shared<ConstantLoad>(boundaryLoad));

  // Integrate the load over the Neumann patch in the vertex basis.
  vertexAssembler.assembleBoundaryFunctional(neumannAssembler, f,
                                             neumannBoundary);
}

template <class GridView, int dimension>
void MyAssembler<GridView, dimension>::assembleWeightedNormalStress(
    BoundaryPatch<GridView> const &frictionalBoundary,
    ScalarVector &weightedNormalStress, double youngModulus,
    double poissonRatio, Vector const &displacement) const {
  // Step 1: cell-wise boundary traction from the current displacement.
  BasisGridFunction<VertexBasis, Vector> displacementFunction(vertexBasis,
                                                              displacement);
  NormalStressBoundaryAssembler<Grid> tractionAssembler(
      youngModulus, poissonRatio, &displacementFunction, 1);
  Vector traction(cellBasis.size());
  cellAssembler.assembleBoundaryFunctional(tractionAssembler, traction,
                                           frictionalBoundary);

  // Step 2: average the cell values onto the vertices of the patch.
  auto const nodalTractionAverage =
      interpolateP0ToP1(frictionalBoundary, traction);

  // Step 3: nodal weights, i.e. the boundary functional of the constant
  // function 1 over the frictional patch, in the vertex basis.
  using ScalarBlock = typename ScalarVector::block_type;
  using UnitFunction = ConstantFunction<LocalVector, ScalarBlock>;
  NeumannBoundaryAssembler<Grid, ScalarBlock> unitAssembler(
      std::make_shared<UnitFunction>(1));
  ScalarVector weights;
  vertexAssembler.assembleBoundaryFunctional(unitAssembler, weights,
                                             frictionalBoundary);

  // Step 4: project the averaged traction onto the boundary normal. Positive
  // values are clipped to zero (presumably tension), then the result is
  // scaled by the nodal weight. NOTE(review): assumes weightedNormalStress
  // already holds at least vertexBasis.size() entries.
  auto const normals = frictionalBoundary.getNormals();
  for (size_t vertex = 0; vertex < vertexBasis.size(); ++vertex) {
    auto const normalStress = normals[vertex] * nodalTractionAverage[vertex];
    weightedNormalStress[vertex] =
        std::fmin(normalStress, 0) * weights[vertex];
  }
}

template <class GridView, int dimension>
void MyAssembler<GridView, dimension>::assembleVonMisesStress(
    double youngModulus, double poissonRatio, Vector const &u,
    ScalarVector &stress) const {
  // Wrap the vertex-basis displacement as a grid function.
  using DisplacementFunction = BasisGridFunction<VertexBasis, Vector>;
  auto const displacementFunction =
      std::make_shared<DisplacementFunction const>(vertexBasis, u);

  // Evaluate the von Mises stress as a functional in the cell basis.
  VonMisesStressAssembler<Grid, LocalCellBasis> vonMises(
      youngModulus, poissonRatio, displacementFunction);
  cellAssembler.assembleFunctional(vonMises, stress);
}

#include "assemblers_tmpl.cc"