#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#ifdef HAVE_IPOPT
#undef HAVE_IPOPT
#endif

#ifndef srcdir
#error srcdir unset
#endif

#ifndef DIM
#error DIM unset
#endif

#ifndef HAVE_PYTHON
#error Python is required
#endif

#if !HAVE_ALUGRID
#error ALUGRID is required
#endif

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <exception>
#include <fstream>
#include <iostream>
#include <string>

#include <boost/format.hpp>

#include <Python.h>

#include <dune/common/bitsetvector.hh>
#include <dune/common/exceptions.hh>
#include <dune/common/fmatrix.hh>
#include <dune/common/function.hh>
#include <dune/common/fvector.hh>
#include <dune/common/parametertree.hh>
#include <dune/common/parametertreeparser.hh>
#include <dune/common/shared_ptr.hh>
#include <dune/common/timer.hh>
#include <dune/grid/alugrid.hh>
#include <dune/grid/common/mcmgmapper.hh>
#include <dune/grid/utility/structuredgridfactory.hh>
#include <dune/istl/bcrsmatrix.hh>
#include <dune/istl/bvector.hh>

#include <dune/fufem/assemblers/functionalassembler.hh>
#include <dune/fufem/assemblers/localassemblers/l2functionalassembler.hh>
#include <dune/fufem/assemblers/localassemblers/massassembler.hh>
#include <dune/fufem/assemblers/localassemblers/stvenantkirchhoffassembler.hh>
#include <dune/fufem/assemblers/localassemblers/vonmisesstressassembler.hh>
#include <dune/fufem/assemblers/operatorassembler.hh>
#include <dune/fufem/dunepython.hh>
#include <dune/fufem/functions/basisgridfunction.hh>
#include <dune/fufem/functions/constantfunction.hh>
#include <dune/fufem/functionspacebases/p0basis.hh>
#include <dune/fufem/functionspacebases/p1nodalbasis.hh>
#include <dune/fufem/sharedpointermap.hh>
#include <dune/solvers/norms/energynorm.hh>
#include <dune/solvers/solvers/loopsolver.hh>
#include <dune/solvers/solvers/solver.hh> // Solver::FULL

#include <dune/tectonic/myblockproblem.hh>
#include <dune/tectonic/myconvexproblem.hh>

#include "assemblers.hh"
#include "compute_state.hh"
#include "compute_state_ruina.hh"
#include "mysolver.hh"
#include "vtk.hh"

#include "enums.hh"
#include "enum_parser.cc"
#include "enum_state_model.cc"
#include "enum_scheme.cc"

#include "timestepping.cc"

// Spatial dimension of the problem, fixed at build time through the DIM macro
// (checked above); used as a template argument throughout this file.
constexpr int dim = DIM;

/*
 * Classify the vertices of gridView by their position on the boundary of the
 * box [lowerLeft, upperRight]:
 *  - top    (y == upperRight[1]): every component flagged in ignoreNodes;
 *  - bottom (y == lowerLeft[1]):  flagged in frictionalNodes, and only the
 *    y-component flagged in ignoreNodes (zero displacement in direction y);
 *  - any other vertex with x == lowerLeft[0] or x == upperRight[0]:
 *    flagged in neumannNodes.
 * The bit vectors are indexed by the vertex mapper's index, so they are
 * expected to hold one entry per grid vertex. Flags are only ever set here,
 * never cleared.
 *
 * NOTE(review): coordinates are compared with exact floating-point equality,
 * which assumes boundary vertices carry exactly the box's corner coordinates
 * (true for midpoint refinement of an axis-aligned box) — confirm for any
 * other grid this is used with.
 */
template <class GridView, class GridCorner>
void setup_boundary(GridView const &gridView,
                    Dune::BitSetVector<dim> &ignoreNodes,
                    Dune::BitSetVector<1> &neumannNodes,
                    Dune::BitSetVector<1> &frictionalNodes,
                    GridCorner const &lowerLeft, GridCorner const &upperRight) {
  Dune::MultipleCodimMultipleGeomTypeMapper<
      GridView, Dune::MCMGVertexLayout> const myVertexMapper(gridView);

  for (auto it = gridView.template begin<dim>();
       it != gridView.template end<dim>(); ++it) {
    // Entities of codimension dim are vertices: a single corner each
    assert(it->geometry().corners() == 1);
    Dune::FieldVector<double, dim> const coordinates = it->geometry().corner(0);
    size_t const id = myVertexMapper.map(*it);
    if (coordinates[1] == upperRight[1])
      ignoreNodes[id] = true; // Top: all components ignored by the solver
    else if (coordinates[1] == lowerLeft[1]) {
      frictionalNodes[id] = true;
      ignoreNodes[id][1] = true; // Zero displacement in direction y
    } else if (coordinates[0] == lowerLeft[0] ||
               coordinates[0] == upperRight[0])
      neumannNodes[id] = true;
  }
}

/*
 * Driver for the one-body sample.
 *
 * Loads the boundary-condition functions exported as "Functions" by the Python
 * module one-body-sample (found in srcdir), and reads parameters from
 * srcdir/one-body-sample.parset, with command-line options merged in
 * afterwards. It then assembles mass, stiffness and gravity terms on a refined
 * simplex ALUGrid of [0,1]^dim.
 *
 * Finally it performs `timesteps` steps of the selected time stepping scheme.
 * Each step runs a fixed-point iteration between the TNNMG solve and the
 * friction state update at the frictional (bottom) nodes.
 *
 * Returns 0 on success, EXIT_FAILURE if a Dune or standard exception escapes.
 */
int main(int argc, char *argv[]) {
  try {
    typedef SharedPointerMap<std::string, Dune::VirtualFunction<double, double>>
    FunctionMap;
    FunctionMap functions;
    {
      Python::start();

      // Make srcdir importable, then pull the functions out of the module
      Python::run("import sys");
      Python::run("sys.path.append('" srcdir "')");

      Python::import("one-body-sample").get("Functions").toC<FunctionMap::Base>(
          functions);
    }

    Dune::ParameterTree parset;
    Dune::ParameterTreeParser::readINITree(srcdir "/one-body-sample.parset",
                                           parset);
    Dune::ParameterTreeParser::readOptions(argc, argv, parset);

    Dune::Timer timer;

    typedef Dune::FieldVector<double, dim> SmallVector;
    typedef Dune::FieldMatrix<double, dim, dim> SmallMatrix;
    typedef Dune::BCRSMatrix<SmallMatrix> MatrixType;
    typedef Dune::BlockVector<SmallVector> VectorType;
    typedef Dune::BlockVector<Dune::FieldVector<double, 1>> SingletonVectorType;

    auto const E = parset.get<double>("body.E");
    auto const nu = parset.get<double>("body.nu");
    auto const solver_tolerance = parset.get<double>("solver.tolerance");

    auto const refinements = parset.get<size_t>("grid.refinements");

    auto const verbose = parset.get<bool>("verbose");
    Solver::VerbosityMode const verbosity =
        verbose ? Solver::FULL : Solver::QUIET;
    auto const enable_timer = parset.get<bool>("enable_timer");

    // {{{ Set up grid
    typedef Dune::ALUGrid<dim, dim, Dune::simplex, Dune::nonconforming>
    GridType;
    Dune::FieldVector<typename GridType::ctype, dim> lowerLeft(0);
    Dune::FieldVector<typename GridType::ctype, dim> upperRight(1);
    Dune::array<unsigned int, dim> elements;
    std::fill(elements.begin(), elements.end(), 1);
    auto grid = Dune::StructuredGridFactory<GridType>::createSimplexGrid(
        lowerLeft, upperRight, elements);

    grid->globalRefine(refinements);
    size_t const finestSize = grid->size(grid->maxLevel(), dim);

    typedef GridType::LeafGridView GridView;
    GridView const leafView = grid->leafView();
    // }}}

    // Set up bases
    typedef P0Basis<GridView, double> P0Basis;
    typedef P1NodalBasis<GridView, double> P1Basis;
    P0Basis const p0Basis(leafView);
    P1Basis const p1Basis(leafView);

    MatrixType massMatrix;
    VectorType gravityFunctional;
    {
      timer.reset();

      // Assemble mass matrix and compute density
      MassAssembler<GridType, P1Basis::LocalFiniteElement,
                    P1Basis::LocalFiniteElement> const localMass;
      OperatorAssembler<P1Basis, P1Basis>(p1Basis, p1Basis)
          .assemble(localMass, massMatrix);

      // We have volume*gravity*density/area = normalstress (V*g*rho/A =
      // sigma_n)
      double volume = 1.0;
      for (int i = 0; i < dim; ++i)
        volume *= (upperRight[i] - lowerLeft[i]);

      // Area of the face orthogonal to direction y
      double area = 1.0;
      for (int i = 0; i < dim; ++i)
        if (i != 1)
          area *= (upperRight[i] - lowerLeft[i]);

      double const gravity = 9.81;
      double const normalStress =
          parset.get<double>("boundary.friction.normalstress");

      // rho    = sigma     * A       / V   / g
      // kg/m^d = N/m^(d-1) * m^(d-1) / m^d / (N/kg)
      double const density = normalStress * area / volume / gravity;

      massMatrix *= density;
      if (enable_timer)
        std::cerr << "Assembled mass matrix in " << timer.elapsed() << "s"
                  << std::endl;

      // Compute gravitational body force
      SmallVector weightedGravitationalDirection(0);
      weightedGravitationalDirection[1] = -density * gravity;
      ConstantFunction<SmallVector, SmallVector> const gravityFunction(
          weightedGravitationalDirection);
      L2FunctionalAssembler<GridType, SmallVector> gravityFunctionalAssembler(
          gravityFunction);
      FunctionalAssembler<P1Basis>(p1Basis)
          .assemble(gravityFunctionalAssembler, gravityFunctional, true);
    }

    // Assemble elastic force on the body
    MatrixType stiffnessMatrix;
    {
      timer.reset();
      StVenantKirchhoffAssembler<GridType, P1Basis::LocalFiniteElement,
                                 P1Basis::LocalFiniteElement> const
      localStiffness(E, nu);
      OperatorAssembler<P1Basis, P1Basis>(p1Basis, p1Basis)
          .assemble(localStiffness, stiffnessMatrix);
      if (enable_timer)
        std::cerr << "Assembled stiffness matrix in " << timer.elapsed() << "s"
                  << std::endl;
    }
    EnergyNorm<MatrixType, VectorType> energyNorm(stiffnessMatrix);

    // Set up the boundary
    Dune::BitSetVector<dim> ignoreNodes(finestSize, false);
    Dune::BitSetVector<1> neumannNodes(finestSize, false);
    Dune::BitSetVector<1> frictionalNodes(finestSize, false);
    setup_boundary(leafView, ignoreNodes, neumannNodes, frictionalNodes,
                   lowerLeft, upperRight);

    auto const nodalIntegrals =
        assemble_frictional<GridType, GridView, SmallVector, P1Basis>(
            leafView, p1Basis, frictionalNodes);

    // {{{ Initialise vectors
    VectorType u(finestSize);
    u = 0.0;
    VectorType u_old(finestSize);
    u_old = 0.0; // Has to be zero!
    VectorType u_old_old(finestSize);

    VectorType ud(finestSize);
    ud = 0.0;
    VectorType ud_old(finestSize);
    ud_old = 0.0;

    VectorType udd(finestSize);
    udd = 0.0;
    VectorType udd_old(finestSize);
    udd_old = 0.0;

    SingletonVectorType alpha_old(finestSize);
    alpha_old = parset.get<double>("boundary.friction.state.initial");
    SingletonVectorType alpha(alpha_old);

    SingletonVectorType vonMisesStress;
    // }}}

    typedef MyConvexProblem<MatrixType, VectorType> MyConvexProblemType;
    typedef MyBlockProblem<MyConvexProblemType> MyBlockProblemType;

    // Set up TNNMG solver
    MySolver<dim, MatrixType, VectorType, GridType, MyBlockProblemType>
    mySolver(parset.sub("solver.tnnmg"), refinements, solver_tolerance, *grid,
             ignoreNodes);

    std::fstream octave_writer("data", std::fstream::out);
    std::fstream coefficient_writer("coefficient", std::fstream::out);
    std::fstream velocity_stepping_writer("velocity_stepping",
                                          std::fstream::out);
    timer.reset();

    auto const L = parset.get<double>("boundary.friction.ruina.L");
    auto const a = parset.get<double>("boundary.friction.ruina.a");
    auto const b = parset.get<double>("boundary.friction.ruina.b");
    auto const eta = parset.get<double>("boundary.friction.eta");
    auto const mu = parset.get<double>("boundary.friction.mu");
    auto const timesteps = parset.get<size_t>("timesteps");
    double const tau = 1.0 / timesteps;

    // Loop-invariant settings, read once instead of once per time step,
    // fixed-point iteration or frictional node
    Config::scheme const schemeKind =
        parset.get<Config::scheme>("timeSteppingScheme");
    Config::state_model const stateModel =
        parset.get<Config::state_model>("boundary.friction.state.model");
    auto const state_fpi_max =
        parset.get<size_t>("solver.tnnmg.fixed_point_iterations");
    auto const fixed_point_tolerance =
        parset.get<double>("solver.tnnmg.fixed_point_tolerance");
    auto const maxiterations = parset.get<size_t>("solver.tnnmg.maxiterations");
    auto const printProgress = parset.get<bool>("printProgress");
    auto const writeEvolution = parset.get<bool>("writeEvolution");
    auto const printVelocitySteppingComparison =
        parset.get<bool>("printVelocitySteppingComparison");
    auto const printCoefficient = parset.get<bool>("printCoefficient");
    auto const printFrictionalVelocity =
        parset.get<bool>("printFrictionalVelocity");
    auto const writeVTK = parset.get<bool>("writeVTK");

    octave_writer << "# name: A" << std::endl << "# type: matrix" << std::endl
                  << "# rows: " << timesteps << std::endl << "# columns: 3"
                  << std::endl;

    velocity_stepping_writer << "# name: B" << std::endl << "# type: matrix"
                             << std::endl << "# rows: " << timesteps
                             << std::endl << "# columns: 2" << std::endl;

    auto const &dirichletFunction = functions.get("dirichletCondition");
    auto const &neumannFunction = functions.get("neumannCondition");

    // Find a (somewhat random) frictional node. The bound is checked before
    // indexing so the search never reads past the end of frictionalNodes.
    size_t first_frictional_node = 0;
    while (first_frictional_node < frictionalNodes.size() &&
           !frictionalNodes[first_frictional_node][0])
      ++first_frictional_node;
    // The per-node diagnostics below index u, ud and alpha with this node
    if (first_frictional_node == frictionalNodes.size() &&
        (writeEvolution || printVelocitySteppingComparison || printCoefficient))
      DUNE_THROW(Dune::Exception, "No frictional node found");

    for (size_t run = 1; run <= timesteps; ++run) {
      double const time = tau * run;
      {
        VectorType ell(finestSize);
        assemble_neumann<GridType, GridView, SmallVector, P1Basis>(
            leafView, p1Basis, neumannNodes, ell, neumannFunction, time);
        ell += gravityFunctional;

        MatrixType problem_A;
        VectorType problem_rhs(finestSize);
        VectorType problem_iterate(finestSize);

        // No second-to-last displacement exists in the first step
        VectorType *u_old_old_ptr = (run == 1) ? nullptr : &u_old_old;

        Dune::shared_ptr<TimeSteppingScheme<VectorType, MatrixType,
                                            decltype(dirichletFunction), dim>>
        timeSteppingScheme;
        {
          switch (schemeKind) {
            case Config::ImplicitTwoStep:
              if (run != 1) {
                timeSteppingScheme = Dune::make_shared<ImplicitTwoStep<
                    VectorType, MatrixType, decltype(dirichletFunction), dim>>(
                    ell, stiffnessMatrix, u_old, u_old_old_ptr, ignoreNodes,
                    dirichletFunction, time, tau);
                break;
              }
            // Fall through: the two-step scheme starts with an Euler step
            case Config::ImplicitEuler:
              timeSteppingScheme = Dune::make_shared<ImplicitEuler<
                  VectorType, MatrixType, decltype(dirichletFunction), dim>>(
                  ell, stiffnessMatrix, u_old, u_old_old_ptr, ignoreNodes,
                  dirichletFunction, time, tau);
              break;
            case Config::Newmark:
              timeSteppingScheme = Dune::make_shared<Newmark<
                  VectorType, MatrixType, decltype(dirichletFunction), dim>>(
                  ell, stiffnessMatrix, massMatrix, u_old, ud_old, udd_old,
                  ignoreNodes, dirichletFunction, time, tau);
              break;
          }
        }
        // Guard against an enum value the switch does not handle instead of
        // dereferencing a null pointer below
        if (!timeSteppingScheme)
          DUNE_THROW(Dune::Exception, "Unknown time stepping scheme");
        timeSteppingScheme->setup(problem_rhs, problem_iterate, problem_A);

        // Fixed-point iteration between the solve and the state update,
        // stopped once the velocity changes by less than the tolerance
        VectorType ud_saved = ud_old;
        for (size_t state_fpi = 1; state_fpi <= state_fpi_max; ++state_fpi) {
          auto myGlobalNonlinearity =
              assemble_nonlinearity<MatrixType, VectorType>(
                  parset.sub("boundary.friction"), *nodalIntegrals, alpha);

          MyConvexProblemType const myConvexProblem(
              problem_A, *myGlobalNonlinearity, problem_rhs);
          MyBlockProblemType myBlockProblem(parset, myConvexProblem);
          auto multigridStep = mySolver.getSolver();
          multigridStep->setProblem(problem_iterate, myBlockProblem);

          LoopSolver<VectorType> overallSolver(
              multigridStep, maxiterations, solver_tolerance, &energyNorm,
              verbosity,
              false); // absolute error
          overallSolver.solve();

          timeSteppingScheme->extractDisplacement(problem_iterate, u);
          timeSteppingScheme->extractVelocity(problem_iterate, ud);

          // udd = 2/tau * (ud - ud_old) - udd_old
          udd = 0;
          Arithmetic::addProduct(udd, 2.0 / tau, ud);
          Arithmetic::addProduct(udd, -2.0 / tau, ud_old);
          Arithmetic::addProduct(udd, -1.0, udd_old);

          // Update the state
          for (size_t i = 0; i < frictionalNodes.size(); ++i) {
            if (frictionalNodes[i][0]) {
              double const unorm = ud[i].two_norm() * tau;

              // // the (logarithmic) steady state corresponding to the
              // // current velocity
              // std::cout << std::log(L/unorm * tau) << std::endl;

              switch (stateModel) {
                case Config::Dieterich:
                  alpha[i] =
                      state_update_dieterich(tau, unorm / L, alpha_old[i]);
                  break;
                case Config::Ruina:
                  alpha[i] = state_update_ruina(tau, unorm / L, alpha_old[i]);
                  break;
              }
            }
          }
          if (printProgress) {
            std::cerr << '.';
            std::cerr.flush();
          }
          if (energyNorm.diff(ud_saved, ud) < fixed_point_tolerance)
            break;
          else
            ud_saved = ud;

          if (state_fpi == state_fpi_max)
            std::cerr << "[ref = " << refinements
                      << "]: FPI did not converge after " << state_fpi_max
                      << " iterations" << std::endl;
        }
        if (printProgress)
          std::cerr << std::endl;

        // Record the state, (scaled) displacement, and Neumann
        // condition at a fixed node
        if (writeEvolution) {
          double out;
          neumannFunction.evaluate(time, out);
          octave_writer << alpha[first_frictional_node][0] << " "
                        << u[first_frictional_node][0] * 1e6 << " " << out
                        << std::endl;
        }

        // Comparison with the analytic solution of a velocity stepping test
        // with the Ruina state evolution law.
        // Jumps at 120 and 360 timesteps; v1 = .6 * v2;
        if (printVelocitySteppingComparison) {
          double const v = ud[first_frictional_node].two_norm() / L;

          double const euler = alpha[first_frictional_node];
          double direct;
          if (run < 120) {
            direct = euler;
          } else if (run < 360) {
            double const v2 = v;
            double const v1 = 0.6 * v2;
            direct = std::log(
                1.0 / v2 *
                std::pow((v2 / v1), std::exp(-v2 * (run - 120) * tau)));
          } else {
            double const v1 = v;
            double const v2 = v1 / 0.6;
            direct = std::log(
                1.0 / v1 *
                std::pow((v1 / v2), std::exp(-v1 * (run - 360) * tau)));
          }
          velocity_stepping_writer << euler << " " << direct << std::endl;
        }

        // Record the coefficient of friction at a fixed node
        if (printCoefficient) {
          double const V = ud[first_frictional_node].two_norm();
          double const state = alpha[first_frictional_node];

          coefficient_writer << (mu + a * std::log(V * eta) +
                                 b * (state - std::log(eta * L))) << std::endl;
        }
      }

      if (printFrictionalVelocity) {
        for (size_t i = 0; i < frictionalNodes.size(); ++i)
          if (frictionalNodes[i][0])
            std::cout << ud[i][0] << " ";
        std::cout << std::endl;
      }
      // Shift the history for the next time step
      u_old_old = u_old;
      u_old = u;
      ud_old = ud;
      udd_old = udd;
      alpha_old = alpha;

      // Compute von Mises stress and write everything to a file
      if (writeVTK) {
        VonMisesStressAssembler<GridType> localStressAssembler(
            E, nu,
            Dune::make_shared<BasisGridFunction<P1Basis, VectorType> const>(
                p1Basis, u));
        FunctionalAssembler<P0Basis>(p0Basis)
            .assemble(localStressAssembler, vonMisesStress, true);

        writeVtk<P1Basis, P0Basis, VectorType, SingletonVectorType, GridView>(
            p1Basis, u, alpha, p0Basis, vonMisesStress, leafView,
            (boost::format("obs%d") % run).str());
      }
    }
    if (enable_timer)
      std::cerr << std::endl << "Making " << timesteps << " time steps took "
                << timer.elapsed() << "s" << std::endl;

    octave_writer.close();
    coefficient_writer.close();
    velocity_stepping_writer.close();

    Python::stop();
  }
  catch (Dune::Exception const &e) {
    Dune::derr << "Dune reported error: " << e << std::endl;
    return EXIT_FAILURE;
  }
  catch (std::exception const &e) {
    std::cerr << "Standard exception: " << e.what() << std::endl;
    return EXIT_FAILURE;
  }
}