// Based on dune/tnnmg/problem-classes/blocknonlineartnnmgproblem.hh

#ifndef MY_BLOCK_PROBLEM_HH
#define MY_BLOCK_PROBLEM_HH

#include <cassert>
#include <cmath>
#include <cstddef>
#include <iomanip>
#include <sstream>
#include <string>

#include <dune/common/bitsetvector.hh>
#include <dune/common/nullptr.hh>
#include <dune/common/parametertree.hh>

// Just for debugging
#include "dune/solvers/computeenergy.hh"

#include <dune/fufem/arithmetic.hh>
#include <dune/tnnmg/problem-classes/bisection.hh>

#include "globalnonlinearity.hh"
#include "minimisation.hh"
#include "mydirectionalconvexfunction.hh"
#include "ellipticenergy.hh"

/** \brief Debugging helper: energy including the nonlinear contribution
 *
 * Evaluates the three-argument computeEnergy(A, x, b) (presumably the
 * quadratic energy from dune-solvers -- confirm against that header) and
 * adds the value of the nonlinearity at x.
 *
 * \param A   Matrix forwarded to the three-argument overload
 * \param x   Point at which the energy is evaluated
 * \param b   Right-hand side forwarded to the three-argument overload
 * \param phi Global nonlinearity, evaluated as phi(x)
 * \return Sum of both contributions
 */
template <class MatrixType, class VectorType>
double computeEnergy(
    MatrixType const &A, VectorType const &x, VectorType const &b,
    Dune::GlobalNonlinearity<MatrixType, VectorType> const &phi) {
  auto const smoothPart = computeEnergy(A, x, b);
  auto const nonlinearPart = phi(x);
  return smoothPart + nonlinearPart;
}

/** \brief Base class for problems where each block can be solved with a
 * modified gradient method
 *
 * The class wraps a convex problem, of which it reads the members A, f and
 * phi, and provides the operations used further down in this class:
 * projecting and damping a coarse correction, assembling a truncated
 * linearization, and handing out an IterateObject for the local solves.
 *
 * Both the parameter tree and the problem are stored by reference; they
 * have to outlive this object.
 */
template <class MyConvexProblemTypeTEMPLATE> class MyBlockProblem {
public:
  using MyConvexProblemType = MyConvexProblemTypeTEMPLATE;
  using VectorType = typename MyConvexProblemType::VectorType;
  using MatrixType = typename MyConvexProblemType::MatrixType;
  using LocalVectorType = typename MyConvexProblemType::LocalVectorType;
  using LocalMatrixType = typename MyConvexProblemType::LocalMatrixType;

  int static const block_size = MyConvexProblemType::block_size;
  int static const coarse_block_size = block_size;

  /** \brief Solves one local system using a modified gradient method */
  class IterateObject;

  /** \brief Data of the linearized problem filled in by assembleTruncate() */
  struct Linearization {
    static const int block_size = coarse_block_size;

    using LocalMatrixType =
        typename MyBlockProblem<MyConvexProblemType>::LocalMatrixType;
    using MatrixType =
        Dune::BCRSMatrix<typename Linearization::LocalMatrixType>;
    using VectorType =
        Dune::BlockVector<Dune::FieldVector<double, Linearization::block_size>>;
    using BitVectorType = Dune::BitSetVector<Linearization::block_size>;

    // Hessian with truncated rows/columns zeroed
    typename Linearization::MatrixType A;
    // Negative gradient with truncated entries zeroed
    typename Linearization::VectorType b;
    // Copy of the ignore pattern passed to assembleTruncate()
    typename Linearization::BitVectorType ignore;

    // Degrees of freedom that were truncated (includes all ignored ones)
    Dune::BitSetVector<Linearization::block_size> truncation;
  };

  /** \brief Set up the block problem
   *
   * \param parset  Parameters; "bisection.acceptFactor" and
   *                "bisection.requiredResidual" are read here,
   *                further keys are read by the IterateObject
   * \param problem The convex problem (quadratic part and nonlinearity)
   */
  MyBlockProblem(Dune::ParameterTree const &parset,
                 MyConvexProblemType const &problem)
      : parset(parset),
        problem(problem),
        bisection(0.0, // acceptError: Stop if the search interval has
                       // become smaller than this number
                  parset.get<double>("bisection.acceptFactor"),
                  parset.get<double>("bisection.requiredResidual"),
                  true, // fastQuadratic
                  0)    // acceptance factor for inexact minimization
  {}

  /** \brief Return and clear the buffered diagnostic output
   *
   * assembleTruncate() appends one truncation count per block component to
   * an internal buffer. This method hands that buffer out and empties it.
   *
   * \param header If true, the buffered content is discarded and a line of
   *               column titles ("trunc" followed by the component index)
   *               is returned instead
   */
  std::string getOutput(bool header = false) const {
    if (header) {
      outStream.str("");
      for (int j = 0; j < block_size; ++j)
        outStream << "  trunc" << std::setw(2) << j;
    }
    std::string s = outStream.str();
    outStream.str("");
    return s;
  }

  /** \brief Copy a coarse correction and zero its truncated components
   *
   * \param u                  Current iterate (not used by this
   *                           implementation)
   * \param v                  Coarse correction
   * \param[out] projected_v   v with every entry flagged in
   *                           linearization.truncation set to zero
   * \param linearization      Provides the truncation pattern
   */
  void projectCoarseCorrection(VectorType const &u,
                               typename Linearization::VectorType const &v,
                               VectorType &projected_v,
                               Linearization const &linearization) const {
    projected_v = v;
    for (size_t i = 0; i < v.size(); ++i)
      for (int j = 0; j < block_size; ++j)
        if (linearization.truncation[i][j])
          projected_v[i][j] = 0;
  }

  /** \brief Compute a damping factor for the correction projected_v
   *
   * The energy is restricted to the line through u along the normalised
   * correction and minimised there by bisection.
   *
   * \param u           Current iterate
   * \param projected_v Correction to be damped
   * \return 1.0 if the correction has zero norm; 0 if the upper end of the
   *         subdifferential at step length zero is positive; otherwise the
   *         bisection result divided by the norm of the correction
   */
  double computeDampingParameter(VectorType const &u,
                                 VectorType const &projected_v) const {
    VectorType v = projected_v;

    double const vnorm = v.two_norm();
    if (vnorm <= 0)
      return 1.0;

    v /= vnorm; // Rescale for numerical stability

    VectorType tmp = problem.f;
    Arithmetic::addProduct(tmp, -1.0, problem.A, u);
    double const localb = tmp * v;

    problem.A.mv(v, tmp);
    double const localA = tmp * v;

    /*
      1/2 <A(u + hv),u + hv> - <b, u + hv>
      = 1/2 <Av,v> h^2 - <b - Au, v> h + const.

      localA = <Av,v>
      localb = <b - Au, v>
    */

    MyDirectionalConvexFunction<
        Dune::GlobalNonlinearity<MatrixType, VectorType>> const psi(localA,
                                                                    localb,
                                                                    problem.phi,
                                                                    u, v);

    Interval<double> D;
    psi.subDiff(0, D);
    // NOTE: Numerical instability can actually get us here
    if (D[1] > 0)
      return 0;

    int bisectionsteps = 0;
    // Uses hard-coded settings rather than the member configured from the
    // parameter tree
    Bisection bisection(0.0, 1.0, 1e-12, true, 0);                      // TODO
    return bisection.minimize(psi, vnorm, 0.0, bisectionsteps) / vnorm; // TODO
  }

  /** \brief Assemble the truncated linearization at u
   *
   * Fills linearization.A with problem.A plus the Hessian contribution of
   * the nonlinearity, and linearization.b with the negative gradient. Rows
   * and columns of A and entries of b that belong to truncated degrees of
   * freedom are then set to zero. A degree of freedom is truncated if it is
   * ignored, or if the regularity reported by the nonlinearity for its
   * block exceeds a fixed threshold (then the whole block is truncated).
   *
   * As a side effect, the number of truncated entries per block component
   * is appended to the buffer returned by getOutput().
   *
   * \param u                  Point of linearization
   * \param[out] linearization Receives A, b, ignore and truncation
   * \param ignore             Degrees of freedom to leave untouched
   */
  void assembleTruncate(VectorType const &u, Linearization &linearization,
                        Dune::BitSetVector<block_size> const &ignore) const {
    // we can just copy the ignore information
    linearization.ignore = ignore;

    // determine truncation pattern
    linearization.truncation.resize(u.size());
    linearization.truncation.unsetAll();
    for (size_t i = 0; i < u.size(); ++i) {
      if (problem.phi.regularity(i, u[i]) > 1e8) { // TODO: Make customisable
        linearization.truncation[i] = true;
        continue;
      }

      for (int j = 0; j < block_size; ++j)
        if (linearization.ignore[i][j])
          linearization.truncation[i][j] = true;
    }

    // construct sparsity pattern for linearization
    Dune::MatrixIndexSet indices(problem.A.N(), problem.A.M());
    indices.import(problem.A);
    problem.phi.addHessianIndices(indices);

    // construct matrix from pattern and initialize it
    indices.exportIdx(linearization.A);
    linearization.A = 0.0;

    // compute quadratic part of hessian (linearization.A += problem.A)
    for (size_t i = 0; i < problem.A.N(); ++i) {
      auto const end = problem.A[i].end();
      for (auto it = problem.A[i].begin(); it != end; ++it)
        Arithmetic::addProduct(linearization.A[i][it.index()], 1.0, *it);
    }

    // compute nonlinearity part of hessian
    problem.phi.addHessian(u, linearization.A);

    // compute quadratic part of gradient
    linearization.b.resize(u.size());
    problem.A.mv(u, linearization.b);
    linearization.b -= problem.f;

    // compute nonlinearity part of gradient
    problem.phi.addGradient(u, linearization.b);

    // -grad is needed for Newton step
    linearization.b *= -1.0;

#ifndef NDEBUG
    // Sanity check (debug builds only): b should be a descent direction.
    //
    // The energy is restricted to the line through u along b in the same
    // way as in computeDampingParameter(): the quadratic coefficients come
    // from problem.A and problem.f, while the nonlinearity is handled by
    // the directional function itself. Using the linearized quantities
    // here would account for the nonlinearity twice.
    {
      VectorType const direction = linearization.b;

      VectorType tmp = problem.f;                      //  f
      Arithmetic::addProduct(tmp, -1.0, problem.A, u); //  f - Au
      double const localb = tmp * direction;           // <f - Au,v>

      problem.A.mv(direction, tmp);          //  Av
      double const localA = tmp * direction; // <Av,v>

      MyDirectionalConvexFunction<
          Dune::GlobalNonlinearity<MatrixType, VectorType>> const
      psi(localA, localb, problem.phi, u, direction);

      Interval<double> D;
      psi.subDiff(0, D);
      // A NaN bound carries no information and is tolerated
      assert(std::isnan(D[1]) || D[1] <= 0);
    }
#endif

    // apply truncation to stiffness matrix and rhs
    for (size_t row = 0; row < linearization.A.N(); ++row) {
      auto const col_end = linearization.A[row].end();
      for (auto col_it = linearization.A[row].begin(); col_it != col_end;
           ++col_it) {
        int const col = col_it.index();
        for (size_t i = 0; i < col_it->N(); ++i) {
          auto const blockEnd = (*col_it)[i].end();
          for (auto blockIt = (*col_it)[i].begin(); blockIt != blockEnd;
               ++blockIt)
            // an entry vanishes if its row or its column is truncated
            if (linearization.truncation[row][i] ||
                linearization.truncation[col][blockIt.index()])
              *blockIt = 0.0;
        }
      }

      for (int j = 0; j < block_size; ++j)
        if (linearization.truncation[row][j])
          linearization.b[row][j] = 0.0;
    }

    // record the number of truncated entries per component for getOutput()
    for (int j = 0; j < block_size; ++j)
      outStream << std::setw(9) << linearization.truncation.countmasked(j);
  }

  /** \brief Constructs and returns an iterate object */
  IterateObject getIterateObject() {
    return IterateObject(parset, bisection, problem);
  }

private:
  Dune::ParameterTree const &parset;

  // problem data
  MyConvexProblemType const &problem;

  // commonly used minimization stuff
  Bisection bisection;

  // diagnostic buffer; mutable because const members write to it
  mutable std::ostringstream outStream;
};

/** \brief Solves one local system at a time
 *
 * Keeps its own copy of the current iterate. For a given block the
 * contributions of all other blocks are moved to the right-hand side and
 * the remaining local energy is handed to Dune::minimise.
 */
template <class MyConvexProblemTypeTEMPLATE>
class MyBlockProblem<MyConvexProblemTypeTEMPLATE>::IterateObject {
  friend class MyBlockProblem;

protected:
  /** \brief Constructor, protected so only friends can instantiate it
   * \param parset Parameters; "localsolver.steps" is read from here
   * \param bisection The class used to do a scalar bisection
   * \param problem The problem including quadratic part and nonlinear part
   */
  IterateObject(Dune::ParameterTree const &parset, Bisection const &bisection,
                MyConvexProblemType const &problem)
      : problem(problem),
        bisection(bisection),
        localsteps(parset.get<size_t>("localsolver.steps")) {}

public:
  /** \brief Set the current iterate (the vector is copied) */
  void setIterate(VectorType const &u) { this->u = u; }

  /** \brief Update the i-th block of the current iterate */
  void updateIterate(LocalVectorType const &ui, int i) { u[i] = ui; }

  /** \brief Minimise a local problem using a modified gradient method
   *
   * Supported ignore patterns are: nothing ignored, exactly one component
   * ignored, or the whole block ignored (in which case ui is left as is).
   * Any other pattern triggers an assertion; if assertions are disabled,
   * the local problem is solved as if nothing were ignored.
   *
   * \param[in,out] ui Passed to the minimiser; holds the solution on return
   * \param m Block number
   * \param ignore Set of degrees of freedom to leave untouched
   */
  void solveLocalProblem(
      LocalVectorType &ui, int m,
      typename Dune::BitSetVector<block_size>::const_reference ignore) {
    size_t const ignoredCount = ignore.count();

    // Ignore the whole node. This is tested before the single-component
    // case so that block_size == 1 is handled (and compiles) as well.
    if (ignoredCount == static_cast<size_t>(block_size))
      return;

    // Index of the ignored component; block_size is a special value that
    // indicates nothing should be ignored
    int ic = block_size;
    if (ignoredCount == 1) {
      for (ic = 0; ic < block_size; ++ic)
        if (ignore[ic])
          break;
    } else if (ignoredCount != 0) {
      assert(false && "unsupported ignore pattern");
    }

    LocalMatrixType const *localA = nullptr;
    LocalVectorType localb(problem.f[m]);

    // localb = f[m] - sum_{j != m} A[m][j] u[j]
    auto const end = problem.A[m].end();
    for (auto it = problem.A[m].begin(); it != end; ++it) {
      int const j = it.index();
      if (j == m)
        localA = &(*it); // localA = A[m][m]
      else
        Arithmetic::addProduct(localb, -1.0, *it, u[j]);
    }
    // the diagonal block has to be part of the sparsity pattern
    assert(localA != nullptr);

    auto const phi = problem.phi.restriction(m);
    Dune::EllipticEnergy<block_size> localJ(*localA, localb, phi, ic);
    Dune::minimise(localJ, ui, localsteps, bisection);
  }

private:
  // problem data
  MyConvexProblemType const &problem;

  // commonly used minimization stuff
  Bisection bisection;

  // state data for smoothing procedure used by:
  // setIterate, updateIterate, solveLocalProblem
  VectorType u;

  size_t const localsteps;
};

#endif