// Based on dune/tnnmg/problem-classes/blocknonlineargsproblem.hh

#ifndef MY_BLOCK_PROBLEM_HH
#define MY_BLOCK_PROBLEM_HH

// #include <dune/common/parametertree.hh>
#include <dune/common/bitsetvector.hh>

#include <dune/tnnmg/problem-classes/bisection.hh>
#include <dune/tnnmg/problem-classes/nonlinearity.hh>
#include <dune/tnnmg/problem-classes/onedconvexfunction.hh>

#include "mynonlinearity.hh"
#include "nicefunction.hh"
#include "samplefunctional.hh"

/** \brief Base class for problems where each block can be solved with a
 * modified gradient method */
template <class MyConvexProblemTypeTEMPLATE> class MyBlockProblem {
public:
  typedef MyConvexProblemTypeTEMPLATE MyConvexProblemType;
  typedef typename MyConvexProblemType::NonlinearityType NonlinearityType;
  typedef typename MyConvexProblemType::VectorType VectorType;
  typedef typename MyConvexProblemType::MatrixType MatrixType;
  typedef typename MyConvexProblemType::LocalVectorType LocalVectorType;
  typedef typename MyConvexProblemType::LocalMatrixType LocalMatrixType;

  static int const block_size = MyConvexProblemType::block_size;

  /** \brief Solves one local system using a modified gradient method */
  class IterateObject;

  MyBlockProblem(MyConvexProblemType& problem) : problem(problem) {
    bisection = Bisection(0.0, 1.0, 1e-12, true, 0);
  };

  /** \brief Constructs and returns an iterate object */
  IterateObject getIterateObject() { return IterateObject(bisection, problem); }

private:
  // problem data
  MyConvexProblemType& problem;

  // commonly used minimization stuff
  Bisection bisection;
};

/** \brief Solves one local system using a modified gradient method
 *
 * The object stores its own copy of the current iterate (see setIterate and
 * updateIterate) and uses it in solveLocalProblem to assemble the local
 * right-hand side of a single block.
 */
template <class MyConvexProblemTypeTEMPLATE>
class MyBlockProblem<MyConvexProblemTypeTEMPLATE>::IterateObject {
  friend class MyBlockProblem;

protected:
  /** \brief Constructor, protected so only friends can instantiate it
   * \param bisection The class used to do a scalar bisection (copied)
   * \param problem The problem including quadratic part and nonlinear part
   *        (stored by reference)
   */
  IterateObject(const Bisection& bisection, MyConvexProblemType& problem)
      : problem(problem), bisection(bisection) {}

public:
  /** \brief Set the current iterate
   * \param u The new iterate; a copy of it is stored, u itself is not
   *        modified
   */
  void setIterate(const VectorType& u) { this->u = u; }

  /** \brief Update the i-th block of the current iterate
   * \param ui New value of the block
   * \param i Block number
   */
  void updateIterate(const LocalVectorType& ui, int i) { u[i] = ui; }

  /** \brief Minimise a local problem using a modified gradient method
   *
   * Builds the local functional for block m from the diagonal block
   * A[m][m], the right-hand side f[m] - sum_{j != m} A[m][j] * u[j] and a
   * (currently hardcoded) nonlinearity, then calls Dune::minimise on it.
   *
   * \param[in,out] ui The solution. It is passed to Dune::minimise without
   *        being initialised here, so presumably its incoming value serves
   *        as the starting point -- TODO confirm
   * \param m Block number
   * \param ignore Set of degrees of freedom to leave untouched; note that
   *        this argument is currently not read at all
   */
  void solveLocalProblem(
      LocalVectorType& ui, int m,
      const typename Dune::BitSetVector<block_size>::const_reference ignore) {
    // Note: ignore is currently ignored (what's it used for anyway?)
    LocalMatrixType const* localA = NULL;
    LocalVectorType localb(problem.f[m]);

    typename MatrixType::row_type::ConstIterator it;
    typename MatrixType::row_type::ConstIterator end = problem.A[m].end();
    for (it = problem.A[m].begin(); it != end; ++it) {
      int const j = it.index();
      if (j == m)
        localA = &(*it); // localA = A[m][m]
      else
        it->mmv(u[j], localb); // localb -= A[m][j] * u[j]
    }
    // The row has to contain a diagonal entry, otherwise there is no local
    // matrix to build the functional from
    assert(localA != NULL);

    // FIXME: Hardcoding a fixed function here for now
    Dune::TrivialFunction func;
    Dune::MyNonlinearity<block_size> phi(func);
    Dune::SampleFunctional<block_size> localJ(*localA, localb, phi);

    Dune::minimise(localJ, ui, 10, bisection); // FIXME: hardcoded value
  }

private:
  // problem data (not owned)
  MyConvexProblemType& problem;

  // commonly used minimization stuff
  Bisection bisection;

  // state data for smoothing procedure used by:
  // setIterate, updateIterate, solveLocalProblem
  VectorType u;
};

#endif