// -*- tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*-
// vi: set ts=8 sw=2 et sts=2:
#ifndef DUNE_TECTONIC_SPATIAL_SOLVING_FUNCTIONAL_HH
#define DUNE_TECTONIC_SPATIAL_SOLVING_FUNCTIONAL_HH

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <limits>
#include <type_traits>
#include <utility>
#include <vector>

#include <dune/solvers/common/wrapownshare.hh>
#include <dune/solvers/common/copyorreference.hh>
#include <dune/solvers/common/interval.hh>
#include <dune/solvers/common/resize.hh>

#include <dune/matrix-vector/algorithm.hh>
#include <dune/matrix-vector/axpy.hh>

#include <dune/tnnmg/functionals/boxconstrainedquadraticfunctional.hh>

#include "../../utils/debugutils.hh"

// Forward declarations of class templates that are defined further down in
// this header.
template<class M, class V, class N, class L, class U, class R>
class DirectionalRestriction;

template<class M, class E, class V, class N, class L, class U, class O, class R>
class ShiftedFunctional;

template <class M, class V, class N, class L, class U, class R>
class Functional;

/** \brief Box constrained quadratic functional with nonlinearity, evaluated
 *         relative to a fixed origin; mainly used for presmoothing in TNNMG algorithm
 *
 *  For a correction v the functional is evaluated at w = origin + v as
 *  0.5 * w^T A w - b^T w + phi(w), provided Dune::TNNMG::checkInsideBox
 *  accepts w for the stored obstacles; otherwise the largest finite value
 *  of the range type is returned.
 *
 *  The maximal eigenvalue data and the nonlinearity are held by reference,
 *  so the objects passed to the constructor must outlive this functional.
 *  All remaining data is stored through Dune::Solvers::ConstCopyOrReference.
 *
 *  \tparam M matrix type
 *  \tparam E type of the maximal eigenvalue data
 *  \tparam V vector type (linear part)
 *  \tparam N nonlinearity type
 *  \tparam L lower obstacle type
 *  \tparam U upper obstacle type
 *  \tparam O origin type
 *  \tparam R Range type
 */
template<class M, class E, class V, class N, class L, class U, class O, class R>
class ShiftedFunctional {

public:
  using Matrix = std::decay_t<M>;
  using Eigenvalues = E;
  using Vector = std::decay_t<V>;
  using Nonlinearity = std::decay_t<N>;
  using LowerObstacle = std::decay_t<L>;
  using UpperObstacle = std::decay_t<U>;
  using Origin = std::decay_t<O>;
  using Range = R;

  /** \brief Set up the functional from its data
   *
   *  \param matrix quadratic part A
   *  \param maxEigenvalues maximal eigenvalue data (kept by reference)
   *  \param linearPart linear part b
   *  \param phi nonlinearity (kept by reference)
   *  \param lower lower obstacle
   *  \param upper upper obstacle
   *  \param origin point the functional is shifted by
   */
  template <class MM, class VV, class LL, class UU, class OO>
  ShiftedFunctional(MM&& matrix, const Eigenvalues& maxEigenvalues, VV&& linearPart, const Nonlinearity& phi,
                    LL&& lower, UU&& upper, OO&& origin) :
      quadraticPart_(std::forward<MM>(matrix)),
      maxEigenvalues_(maxEigenvalues),
      originalLinearPart_(std::forward<VV>(linearPart)),
      phi_(phi),
      originalLowerObstacle_(std::forward<LL>(lower)),
      originalUpperObstacle_(std::forward<UU>(upper)),
      origin_(std::forward<OO>(origin))
  {}

  /** \brief Evaluate the functional at origin + v
   *
   *  \param v correction relative to the origin
   *  \returns the energy at origin + v, or std::numeric_limits<Range>::max()
   *           if the box check rejects origin + v
   */
  Range operator()(const Vector& v) const
  {
    auto shifted = origin();
    shifted += v;

    // Points rejected by the box check get the largest finite value.
    if (!Dune::TNNMG::checkInsideBox(shifted, originalLowerObstacle(), originalUpperObstacle()))
      return std::numeric_limits<Range>::max();

    // Build (0.5*A*w - b) first, then pair it with w.
    Vector halfAwMinusB;
    Dune::Solvers::resizeInitializeZero(halfAwMinusB, v);
    quadraticPart().umv(shifted, halfAwMinusB);
    halfAwMinusB *= 0.5;
    halfAwMinusB -= originalLinearPart();

    return halfAwMinusB*shifted + phi()(shifted);
  }

  //! Quadratic part A of the functional
  const Matrix& quadraticPart() const
  {
    return quadraticPart_.get();
  }

  //! Linear part b as passed to the constructor
  const Vector& originalLinearPart() const
  {
    return originalLinearPart_.get();
  }

  //! Lower obstacle as passed to the constructor
  const LowerObstacle& originalLowerObstacle() const
  {
    return originalLowerObstacle_.get();
  }

  //! Upper obstacle as passed to the constructor
  const UpperObstacle& originalUpperObstacle() const
  {
    return originalUpperObstacle_.get();
  }

  //! Origin the functional is shifted by
  const Origin& origin() const
  {
    return origin_.get();
  }

  //! Intentionally empty: this class does nothing on origin updates
  void updateOrigin()
  {}

  //! Intentionally empty: this class does nothing on origin updates
  void updateOrigin(std::size_t i)
  {}

  //! Nonlinearity phi
  const Nonlinearity& phi() const {
    return phi_;
  }

  //! Maximal eigenvalue data passed to the constructor
  const Eigenvalues& maxEigenvalues() const {
    return maxEigenvalues_;
  }

protected:
  Dune::Solvers::ConstCopyOrReference<M> quadraticPart_;
  const Eigenvalues& maxEigenvalues_;

  Dune::Solvers::ConstCopyOrReference<V> originalLinearPart_;
  const Nonlinearity& phi_;

  Dune::Solvers::ConstCopyOrReference<L> originalLowerObstacle_;
  Dune::Solvers::ConstCopyOrReference<U> originalUpperObstacle_;

  Dune::Solvers::ConstCopyOrReference<O> origin_;
};

/** \brief Directional restriction of box constrained quadratic functional with nonlinearity;
 *         mainly used for line search step in TNNMG algorithm
 *
 *  Extends the box constrained quadratic directional restriction by a
 *  nonlinearity phi. The admissible interval of the step length is the
 *  directional domain reported by phi, clipped to the defect obstacles
 *  (this->defectLower_, this->defectUpper_) provided by the base class.
 *
 *  origin, direction and phi are stored by reference; the caller has to keep
 *  them alive for as long as this object is used.
 *
 *  \tparam M Global matrix type
 *  \tparam V Global vector type
 *  \tparam N nonlinearity type
 *  \tparam L Global lower obstacle type
 *  \tparam U Global upper obstacle type
 *  \tparam R Range type
 */
template<class M, class V, class N, class L, class U, class R=double>
class DirectionalRestriction :
  public Dune::TNNMG::BoxConstrainedQuadraticDirectionalRestriction<M,V,L,U,R>
{
  using Base = Dune::TNNMG::BoxConstrainedQuadraticDirectionalRestriction<M,V,L,U,R>;
  using Nonlinearity = N;
  using Interval = typename Dune::Solvers::Interval<R>;

public:
  using GlobalMatrix = typename Base::GlobalMatrix;
  using GlobalVector = typename Base::GlobalVector;
  using GlobalLowerObstacle = typename Base::GlobalLowerObstacle;
  using GlobalUpperObstacle = typename Base::GlobalUpperObstacle;

  using Matrix = typename Base::Matrix;
  using Vector = typename Base::Vector;

  /** \brief Construct the restriction along direction through origin
   *
   *  \param matrix global quadratic part
   *  \param linearTerm global linear part
   *  \param phi nonlinearity (kept by reference)
   *  \param lower global lower obstacle
   *  \param upper global upper obstacle
   *  \param origin point the line passes through (kept by reference)
   *  \param direction direction of the line (kept by reference)
   */
  DirectionalRestriction(const GlobalMatrix& matrix, const GlobalVector& linearTerm, const Nonlinearity& phi, const GlobalLowerObstacle& lower, const GlobalUpperObstacle& upper, const GlobalVector& origin, const GlobalVector& direction) :
    Base(matrix, linearTerm, lower, upper, origin, direction),
    origin_(origin),
    direction_(direction),
    phi_(phi)
  {
    phi_.directionalDomain(origin_, direction_, domain_);

    // Intersect the domain of phi with the interval given by the defect obstacles.
    if (domain_[0] < this->defectLower_) {
       domain_[0] = this->defectLower_;
    }

    if (domain_[1] > this->defectUpper_) {
       domain_[1] = this->defectUpper_;
    }
  }

  /** \brief Subdifferential of the restricted functional at step length x
   *
   *  \param x step length along the direction
   *  \returns interval obtained from phi's directional subdifferential at
   *           origin + x*direction, shifted by quadraticPart_*x - linearPart_
   */
  auto subDifferential(double x) const {
    Interval D;
    GlobalVector uxv = origin_;

    Dune::MatrixVector::addProduct(uxv, x, direction_);
    phi_.directionalSubDiff(uxv, direction_, D);

    // Derivative of the smooth (quadratic) part, added to both interval ends.
    auto const Axmb = this->quadraticPart_ * x - this->linearPart_;
    D[0] += Axmb;
    D[1] += Axmb;

    return D;
  }

  //! Admissible interval for the step length (returned by value)
  auto domain() const {
    return domain_;
  }

  //! Point the line passes through
  const GlobalVector& origin() const {
    return origin_;
  }

  //! Direction of the line
  const GlobalVector& direction() const {
    return direction_;
  }

protected:
  const GlobalVector& origin_;
  const GlobalVector& direction_;

  const Nonlinearity& phi_;
  Interval domain_;
};


/** \brief Box constrained quadratic functional with nonlinearity
 *         Note: call setIgnore() to set up functional in affine subspace with ignore information
 *
 *  \tparam V Vector type
 *  \tparam N Nonlinearity type
 *  \tparam L Lower obstacle type
 *  \tparam U Upper obstacle type
 *  \tparam R Range type
 */
template<class N, class V, class R>
class FirstOrderFunctional {
    using Interval = typename Dune::Solvers::Interval<R>;

public:
    using Nonlinearity = std::decay_t<N>;
    using Vector = V;
    using LowerObstacle = V;
    using UpperObstacle = V;
    using Range = R;

    using field_type = typename V::field_type;
public:
    template <class LL, class UU, class OO, class DD>
    FirstOrderFunctional(
            const Range& maxEigenvalue,
            const Range& linearPart,
            const Nonlinearity& phi,
            LL&& lower,
            UU&& upper,
            OO&& origin,
            DD&& direction) :
        quadraticPart_(maxEigenvalue),
        linearPart_(linearPart),
        lower_(std::forward<LL>(lower)),
        upper_(std::forward<UU>(upper)),
        origin_(std::forward<OO>(origin)),
        direction_(std::forward<DD>(direction)),
        phi_(phi) {

        // set defect obstacles
        defectLower_ = -std::numeric_limits<field_type>::max();
        defectUpper_ = std::numeric_limits<field_type>::max();
        Dune::TNNMG::directionalObstacles(origin_.get(), direction_.get(), lower_.get(), upper_.get(), defectLower_, defectUpper_);

        // set domain
        phi_.directionalDomain(origin_.get(), direction_.get(), domain_);

        if (domain_[0] < defectLower_) {
           domain_[0] = defectLower_;
        }

        if (domain_[1] > defectUpper_) {
           domain_[1] = defectUpper_;
        }
    }

  Range operator()(const Vector& v) const
  {
    DUNE_THROW(Dune::NotImplemented, "Evaluation of FirstOrderFunctional not implemented");
  }

  auto subDifferential(double x) const {
    Interval Di;
    Vector uxv = origin_.get();

    Dune::MatrixVector::addProduct(uxv, x, direction_.get());
    phi_.directionalSubDiff(uxv, direction_.get(), Di);

    const auto Axmb = quadraticPart_ * x - linearPart_;
    Di[0] += Axmb;
    Di[1] += Axmb;

    return Di;
  }

  const Interval& domain() const {
    return domain_;
  }

  const auto& origin() const {
      return origin_.get();
  }

  const auto& direction() const {
      return direction_.get();
  }

  const auto& lowerObstacle() const {
      return defectLower_;
  }

  const auto& upperObstacle() const {
      return defectUpper_;
  }

  const auto& quadraticPart() const {
      return quadraticPart_;
  }

  const auto& linearPart() const {
      return linearPart_;
  }
private:
  const Range quadraticPart_;
  const Range linearPart_;

  Dune::Solvers::ConstCopyOrReference<LowerObstacle> lower_;
  Dune::Solvers::ConstCopyOrReference<UpperObstacle> upper_;

  Dune::Solvers::ConstCopyOrReference<Vector> origin_;
  Dune::Solvers::ConstCopyOrReference<Vector> direction_;

  const Nonlinearity& phi_;

  Interval domain_;

  field_type defectLower_;
  field_type defectUpper_;
};

// \ToDo This should be an inline friend of ShiftedFunctional
// but gcc-4.9.2 shipped with debian jessie does not accept
// inline friends with auto return type due to bug-59766.
// Notice, that this is fixed in gcc-4.9.3.
/** \brief Restrict a shifted functional to its i-th block coordinate
 *
 *  Builds a local ShiftedFunctional from the diagonal block A_ii, a local
 *  linear part assembled from row i of the quadratic part and the origin,
 *  the i-th maximal eigenvalue, the restriction of phi to i, and the i-th
 *  blocks of the obstacles and of the origin.
 *
 *  The returned object refers to data owned by f (diagonal block, obstacles,
 *  origin, eigenvalue, restricted nonlinearity), so f must outlive it. Only
 *  the local linear part is moved into the result.
 *
 *  \param f global shifted functional
 *  \param i block index; row i of the quadratic part must contain a diagonal entry
 *  \returns the local functional for block i
 */
template<class GlobalShiftedFunctional, class Index>
auto coordinateRestriction(const GlobalShiftedFunctional& f, const Index& i)
{
  using Range = typename GlobalShiftedFunctional::Range;
  using LocalMatrix = std::decay_t<decltype(f.quadraticPart()[i][i])>;
  using LocalVector = std::decay_t<decltype(f.originalLinearPart()[i])>;
  using LocalLowerObstacle = std::decay_t<decltype(f.originalLowerObstacle()[i])>;
  using LocalUpperObstacle = std::decay_t<decltype(f.originalUpperObstacle()[i])>;

  using namespace Dune::MatrixVector;
  namespace H = Dune::Hybrid;

  // Set while traversing row i once the diagonal entry is encountered.
  const LocalMatrix* Aii_p = nullptr;

  // Start from the i-th block of the linear part; each stored entry of row i
  // is then applied to the matching origin block via mmv (presumably
  // ri -= Aij * origin[j], following the usual mmv convention -- TODO confirm).
  LocalVector ri = f.originalLinearPart()[i];
  const auto& Ai = f.quadraticPart()[i];
  sparseRangeFor(Ai, [&](auto&& Aij, auto&& j) {
      // TODO Here we must implement a wrapper to guarantee that this will work with proxy matrices!
      H::ifElse(H::equals(j, i), [&](auto&& id){
        Aii_p = id(&Aij);
      });
      Dune::TNNMG::Imp::mmv(Aij, f.origin()[j], ri, Dune::PriorityTag<1>());
  });

  // Without a diagonal entry in row i the pointer is still null and the
  // dereference below would be undefined behaviour.
  assert(Aii_p != nullptr && "coordinateRestriction: row i of the quadratic part has no diagonal block");

  auto& phii = f.phi().restriction(i);

  return ShiftedFunctional<LocalMatrix&, Range, LocalVector, std::decay_t<decltype(phii)>, LocalLowerObstacle&, LocalUpperObstacle&, LocalVector&, Range>(*Aii_p, f.maxEigenvalues()[i], std::move(ri), phii, f.originalLowerObstacle()[i], f.originalUpperObstacle()[i], f.origin()[i]);
}


/** \brief Box constrained quadratic functional with nonlinearity
 *
 *  Evaluates to the value of the quadratic functional plus phi(v) when
 *  Dune::TNNMG::checkInsideBox accepts v for the obstacles, and to
 *  std::numeric_limits<Range>::max() otherwise. On construction the largest
 *  eigenvalue of every diagonal block of the matrix is computed and stored.
 *
 *  \tparam M Matrix type
 *  \tparam V Vector type
 *  \tparam N Nonlinearity type
 *  \tparam L Lower obstacle type
 *  \tparam U Upper obstacle type
 *  \tparam R Range type
 */
template<class M, class V, class N, class L, class U, class R>
class Functional : public Dune::TNNMG::BoxConstrainedQuadraticFunctional<M, V, L, U, R>
{
private:
  using Base = Dune::TNNMG::BoxConstrainedQuadraticFunctional<M, V, L, U, R>;


public:
  using Nonlinearity = std::decay_t<N>;

  using Matrix = typename Base::Matrix;
  using Vector = typename Base::Vector;
  using Range = typename Base::Range;
  using LowerObstacle = typename Base::LowerObstacle;
  using UpperObstacle = typename Base::UpperObstacle;

  using Eigenvalues = std::vector<Range>;

private:
    Dune::Solvers::ConstCopyOrReference<N> phi_;
    Eigenvalues maxEigenvalues_;

public:

  /** \brief Set up the functional and precompute block-wise maximal eigenvalues
   *
   *  \param matrix quadratic part
   *  \param linearPart linear part
   *  \param phi nonlinearity
   *  \param lower lower obstacle
   *  \param upper upper obstacle
   */
  template <class MM, class VV, class NN, class LL, class UU>
  Functional(
          MM&& matrix,
          VV&& linearPart,
          NN&& phi,
          LL&& lower,
          UU&& upper) :
    Base(std::forward<MM>(matrix), std::forward<VV>(linearPart), std::forward<LL>(lower), std::forward<UU>(upper)),
    phi_(std::forward<NN>(phi)),
    // One entry per matrix block row: the loop below indexes by row, so the
    // storage is sized from the matrix rather than from the linear part.
    maxEigenvalues_(this->quadraticPart().N())
  {
     for (std::size_t i=0; i<maxEigenvalues_.size(); ++i) {
       // Eigenvalues of the diagonal block as computed by Dune::FMatrixHelp::eigenValues.
       typename Vector::block_type eigenvalues;
       Dune::FMatrixHelp::eigenValues(this->quadraticPart()[i][i], eigenvalues);
       maxEigenvalues_[i] =
           *std::max_element(std::begin(eigenvalues), std::end(eigenvalues));
     }
  }

  //! Nonlinearity phi
  const Nonlinearity& phi() const {
    return phi_.get();
  }

  //! Largest eigenvalue of each diagonal block, indexed by block row
  const auto& maxEigenvalues() const {
    return maxEigenvalues_;
  }

  /** \brief Evaluate the functional at v
   *
   *  \returns quadratic functional value plus phi(v) if the box check
   *           accepts v, std::numeric_limits<Range>::max() otherwise
   */
  Range operator()(const Vector& v) const
  {
    if (Dune::TNNMG::checkInsideBox(v, this->lowerObstacle(), this->upperObstacle()))
      return Dune::TNNMG::QuadraticFunctional<M,V,R>::operator()(v) + phi_.get()(v);
    return std::numeric_limits<Range>::max();
  }

  //! Restriction of f to the line through origin along direction; the result refers to f, origin and direction
  friend auto directionalRestriction(const Functional& f, const Vector& origin, const Vector& direction)
    -> DirectionalRestriction<Matrix, Vector, Nonlinearity, LowerObstacle, UpperObstacle, Range>
  {
    return DirectionalRestriction<Matrix, Vector, Nonlinearity, LowerObstacle, UpperObstacle, Range>(f.quadraticPart(), f.linearPart(), f.phi(), f.lowerObstacle(), f.upperObstacle(), origin, direction);
  }

  //! Functional f shifted by origin; all data is passed on using reference template arguments
  friend auto shift(const Functional& f, const Vector& origin)
    -> ShiftedFunctional<Matrix&, Eigenvalues, Vector&, Nonlinearity&, LowerObstacle&, UpperObstacle&, Vector&, Range>
  {
    return ShiftedFunctional<Matrix&, Eigenvalues, Vector&, Nonlinearity&, LowerObstacle&, UpperObstacle&, Vector&, Range>(f.quadraticPart(), f.maxEigenvalues(), f.linearPart(), f.phi(), f.lowerObstacle(), f.upperObstacle(), origin);
  }
};


#endif