#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include <stdexcept>
#include <string>

#include <dune/fufem/boundarypatch.hh>
#include <dune/fufem/functions/constantfunction.hh>
#include <dune/fufem/assemblers/localassemblers/neumannboundaryassembler.hh>
#include <dune/fufem/assemblers/boundaryfunctionalassembler.hh>

#include <dune/tectonic/globallaursennonlinearity.hh>
#include <dune/tectonic/globalruinanonlinearity.hh>

#include "assemblers.hh"

/**
 * Assembles the Neumann boundary term into f.
 *
 * The scalar function `neumann` is sampled once at `time`. That value becomes
 * the first component of a vector that is constant over the whole Neumann
 * boundary; all remaining components are zero.
 *
 * \param gridView      grid view on which the boundary patch is built
 * \param feBasis       finite element basis used for assembly
 * \param neumannNodes  flags selecting the Neumann boundary
 * \param f             output vector; resized and zeroed before assembly
 * \param neumann       scalar function of time giving the first component
 * \param time          time at which `neumann` is evaluated
 */
template <class GridType, class GridView, class LocalVectorType, class FEBasis>
void assemble_neumann(GridView const &gridView, FEBasis const &feBasis,
                      Dune::BitSetVector<1> const &neumannNodes,
                      Dune::BlockVector<LocalVectorType> &f,
                      Dune::VirtualFunction<double, double> const &neumann,
                      double time) {
  BoundaryPatch<GridView> const neumannBoundary(gridView, neumannNodes);

  // The scalar constructor fills every component with 0 (as for
  // Dune::FieldVector), so only the first component has to be set. No other
  // index is written, which keeps this valid for one-component vectors too.
  LocalVectorType sampleVector(0);
  neumann.evaluate(time, sampleVector[0]);

  ConstantFunction<LocalVectorType, LocalVectorType> const fNeumann(
      sampleVector);
  NeumannBoundaryAssembler<GridType, LocalVectorType> neumannBoundaryAssembler(
      fNeumann);

  BoundaryFunctionalAssembler<FEBasis>(feBasis, neumannBoundary)
      .assemble(neumannBoundaryAssembler, f, true); // resize and zero f
}

/**
 * Assembles the constant function 1 over the frictional boundary.
 *
 * \param gridView         grid view on which the boundary patch is built
 * \param feBasis          finite element basis used for assembly
 * \param frictionalNodes  flags selecting the frictional boundary
 * \return newly allocated vector of nodal integrals; the assembler resizes
 *         and zeroes it before accumulating
 */
template <class GridType, class GridView, class LocalVectorType, class FEBasis>
Dune::shared_ptr<Dune::BlockVector<Dune::FieldVector<double, 1>>>
assemble_frictional(GridView const &gridView, FEBasis const &feBasis,
                    Dune::BitSetVector<1> const &frictionalNodes) {
  typedef Dune::FieldVector<double, 1> Singleton;
  typedef Dune::BlockVector<Singleton> SingletonVector;

  BoundaryPatch<GridView> const patch(gridView, frictionalNodes);

  // Local assembler integrating the constant one-function.
  ConstantFunction<LocalVectorType, Singleton> const one(1);
  NeumannBoundaryAssembler<GridType, Singleton> localAssembler(one);

  auto const result = Dune::make_shared<SingletonVector>();

  BoundaryFunctionalAssembler<FEBasis> globalAssembler(feBasis, patch);
  globalAssembler.assemble(localAssembler, *result,
                           true); // resize and zero output vector

  return result;
}

/**
 * Builds the global friction nonlinearity selected by the
 * "boundary.friction.model" entry of parset.
 *
 * Every coefficient is read from parset as a single scalar and replicated
 * into a vector of length `size`.
 *
 * \param size            length of each coefficient vector
 * \param parset          parameter tree holding the "boundary.friction.*" keys
 * \param nodalIntegrals  nodal integrals over the frictional boundary
 * \param state           state vector; only forwarded to the Ruina model
 * \param h               only forwarded to the Ruina model (presumably the
 *                        time step size — TODO confirm)
 * \throws std::runtime_error if the model is neither "Ruina" nor "Laursen"
 */
template <class VectorType, class MatrixType>
Dune::shared_ptr<Dune::GlobalNonlinearity<VectorType, MatrixType> const>
assemble_nonlinearity(
    int size, Dune::ParameterTree const &parset,
    Dune::shared_ptr<Dune::BlockVector<Dune::FieldVector<double, 1>>>
        nodalIntegrals,
    Dune::shared_ptr<Dune::BlockVector<Dune::FieldVector<double, 1>>> state,
    double h) {
  typedef Dune::BlockVector<Dune::FieldVector<double, 1>> SingletonVectorType;

  // Reads the scalar parameter `key` and returns a vector of length `size`
  // whose entries all equal that value.
  auto const constantVector =
      [&](std::string const &key) -> Dune::shared_ptr<SingletonVectorType> {
    auto vector = Dune::make_shared<SingletonVectorType>(size);
    *vector = parset.get<double>(key);
    return vector;
  };

  // {{{ Assemble terms for the nonlinearity
  auto mu = constantVector("boundary.friction.mu");
  auto normalStress = constantVector("boundary.friction.normalstress");

  std::string const friction_model =
      parset.get<std::string>("boundary.friction.model");

  if (friction_model == "Ruina") {
    auto a = constantVector("boundary.friction.ruina.a");
    auto eta = constantVector("boundary.friction.eta");
    auto b = constantVector("boundary.friction.ruina.b");
    auto L = constantVector("boundary.friction.ruina.L");

    return Dune::make_shared<
        Dune::GlobalRuinaNonlinearity<VectorType, MatrixType> const>(
        nodalIntegrals, a, mu, eta, normalStress, b, state, L, h);
  }

  if (friction_model == "Laursen") {
    // TODO: take state and h into account
    // FIXME: We should be using a quadratic rather than a linear function
    // here!
    return Dune::make_shared<
        Dune::GlobalLaursenNonlinearity<Dune::LinearFunction, VectorType,
                                        MatrixType> const>(mu, normalStress,
                                                           nodalIntegrals);
  }

  // An assert(false) here would let NDEBUG builds fall off the end of this
  // value-returning function (undefined behaviour); report the error instead.
  throw std::runtime_error("Unknown friction model: " + friction_model);
}

#include "assemblers_tmpl.cc"