// -*- tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*-
// vi: set ts=8 sw=2 et sts=2:
#ifndef DUNE_TECTONIC_SPATIAL_SOLVING_FUNCTIONAL_HH
#define DUNE_TECTONIC_SPATIAL_SOLVING_FUNCTIONAL_HH

#include <algorithm>
#include <cstddef>
#include <limits>
#include <type_traits>
#include <utility>

#include <dune/solvers/common/wrapownshare.hh>
#include <dune/solvers/common/copyorreference.hh>
#include <dune/solvers/common/interval.hh>
#include <dune/solvers/common/resize.hh>

#include <dune/matrix-vector/algorithm.hh>
#include <dune/matrix-vector/axpy.hh>

/** \brief Box constrained quadratic functional with nonlinearity, shifted by an origin;
 *         mainly used for presmoothing in TNNMG algorithm
 *
 *  Evaluation at a correction v tests origin+v against the original (unshifted)
 *  obstacles and, if admissible, adds the nonlinearity evaluated at origin+v to
 *  the value returned by the evaluation operator of Base::Base.
 *
 *  \note The nonlinearity is only referenced, not copied; it has to outlive
 *        this object.
 *
 *  \tparam M global matrix type
 *  \tparam V global vector type
 *  \tparam N nonlinearity type
 *  \tparam R Range type
 */
template<class M, class V, class N, class R>
class ShiftedFunctional : public Dune::TNNMG::ShiftedBoxConstrainedQuadraticFunctional<M,V,R> {
  using Base = Dune::TNNMG::ShiftedBoxConstrainedQuadraticFunctional<M,V,R>;

public:
  using Matrix = M;
  using Vector = V;
  using Nonlinearity = N;
  using LowerObstacle = Vector;
  using UpperObstacle = Vector;
  using Range = R;

  /** \brief Set up the shifted functional
   *
   *  All arguments except phi are handed to the base class unchanged;
   *  phi is kept as a reference.
   */
  ShiftedFunctional(const Matrix& quadraticPart, const Vector& linearPart, const Nonlinearity& phi, const LowerObstacle& lowerObstacle, const UpperObstacle& upperObstacle, const Vector& origin) :
    Base(quadraticPart, linearPart, lowerObstacle, upperObstacle, origin),
    phi_(phi)
  {}

  /** \brief Evaluate the functional at the correction v
   *
   *  \returns std::numeric_limits<Range>::max() if origin+v violates the
   *           original obstacles, otherwise the base value plus phi(origin+v)
   */
  Range operator()(const Vector& v) const
  {
    auto temp = Base::Base::origin();
    temp += v;
    // Qualified call: this class lives outside of namespace Dune::TNNMG, so
    // the helper must be named explicitly instead of relying on ADL.
    if (Dune::TNNMG::checkInsideBox(temp, this->originalLowerObstacle(), this->originalUpperObstacle()))
      return Base::Base::operator()(v) + phi_.operator()(temp);
    return std::numeric_limits<Range>::max();
  }

  /** \brief Access the (unshifted) nonlinearity */
  const auto& phi() const {
    return phi_;
  }

protected:
  const Nonlinearity& phi_;
};





/** \brief Directional restriction of box constrained quadratic functional with nonlinearity;
 *         mainly used for line search step in TNNMG algorithm
 *
 *  \note origin, direction and the nonlinearity are only referenced, not
 *        copied; they have to outlive this object.
 *
 *  \tparam M Global matrix type
 *  \tparam V Global vector type
 *  \tparam N nonlinearity type
 *  \tparam L Global lower obstacle type
 *  \tparam U Global upper obstacle type
 *  \tparam R Range type
 */
template<class M, class V, class N, class L, class U, class R=double>
class DirectionalRestriction :
  public Dune::TNNMG::BoxConstrainedQuadraticDirectionalRestriction<M,V,L,U,R>
{
  using Base = Dune::TNNMG::BoxConstrainedQuadraticDirectionalRestriction<M,V,L,U,R>;
  using Nonlinearity = N;
  using Interval = typename Dune::Solvers::Interval<R>;

public:
  using GlobalMatrix = typename Base::GlobalMatrix;
  using GlobalVector = typename Base::GlobalVector;
  using GlobalLowerObstacle = typename Base::GlobalLowerObstacle;
  using GlobalUpperObstacle = typename Base::GlobalUpperObstacle;

  using Matrix = typename Base::Matrix;
  using Vector = typename Base::Vector;

  /** \brief Restrict the functional to the line origin + t*direction
   *
   *  The admissible interval is the directional domain reported by phi,
   *  clipped to the defect obstacles [defectLower_, defectUpper_] of the
   *  base class.
   */
  DirectionalRestriction(const GlobalMatrix& matrix, const GlobalVector& linearTerm, const Nonlinearity& phi, const GlobalLowerObstacle& lower, const GlobalUpperObstacle& upper, const GlobalVector& origin, const GlobalVector& direction) :
    Base(matrix, linearTerm, lower, upper, origin, direction),
    origin_(origin),
    direction_(direction),
    phi_(phi)
  {
    phi_.directionalDomain(origin_, direction_, domain_);

    // never leave the box given by the obstacles
    if (domain_[0] < this->defectLower_) {
      domain_[0] = this->defectLower_;
    }

    if (domain_[1] > this->defectUpper_) {
      domain_[1] = this->defectUpper_;
    }
  }

  /** \brief Subdifferential of the restricted functional at step length x
   *
   *  Directional subdifferential of phi at origin + x*direction, with both
   *  interval ends shifted by quadraticPart_*x - linearPart_.
   */
  auto subDifferential(double x) const {
    Interval D;
    GlobalVector uxv = origin_;

    // uxv = origin + x*direction
    Dune::MatrixVector::addProduct(uxv, x, direction_);
    phi_.directionalSubDiff(uxv, direction_, D);
    auto const Axmb = this->quadraticPart_ * x - this->linearPart_;
    D[0] += Axmb;
    D[1] += Axmb;

    return D;
  }

  /** \brief Admissible step lengths (returned by value) */
  auto domain() const {
    return domain_;
  }

  /** \brief Point the line passes through */
  const GlobalVector& origin() const {
    return origin_;
  }

  /** \brief Direction of the line */
  const GlobalVector& direction() const {
    return direction_;
  }

protected:
  const GlobalVector& origin_;
  const GlobalVector& direction_;

  const Nonlinearity& phi_;
  Interval domain_;
};

/** \brief One-dimensional model functional along a direction, with nonlinearity
 *         and box constraints
 *
 *  The quadratic part of the model only uses the largest eigenvalue of the
 *  matrix: quadraticPart = maxEigenvalue * (direction*direction) and
 *  linearPart = linearTerm*direction - maxEigenvalue * (direction*origin).
 *
 *  Note: call getIgnoreFunctional() to obtain the functional restricted
 *        according to ignore information
 *
 *  All data members are held through Dune::Solvers::ConstCopyOrReference, so
 *  the template arguments presumably decide between storing a copy and
 *  storing a reference (see that class for the exact rules).
 *
 *  \tparam V Vector type (of the linear term)
 *  \tparam N Nonlinearity type
 *  \tparam L Lower obstacle type
 *  \tparam U Upper obstacle type
 *  \tparam O Origin type
 *  \tparam D Direction type
 *  \tparam R Range type
 */
template<class V, class N, class L, class U, class O, class D, class R>
class FirstOrderModelFunctional {
  using Interval = typename Dune::Solvers::Interval<R>;

public:
  using Vector = std::decay_t<V>;
  using Nonlinearity = std::decay_t<N>;
  using LowerObstacle = std::decay_t<L>;
  using UpperObstacle = std::decay_t<U>;
  using Range = R;

  using field_type = typename Vector::field_type;

private:
  /** \brief Compute defect obstacles, domain and the scalar model coefficients
   *
   *  Requires maxEigenvalue_ and all stored arguments to be set already.
   */
  void init() {
    // bind by reference, the stored objects do not have to be copied here
    const auto& origin = origin_.get();
    const auto& direction = direction_.get();

    // set defect obstacles
    defectLower_ = -std::numeric_limits<field_type>::max();
    defectUpper_ = std::numeric_limits<field_type>::max();
    Dune::TNNMG::directionalObstacles(origin, direction, lower_.get(), upper_.get(), defectLower_, defectUpper_);

    // set domain: directional domain of phi, clipped to the defect obstacles
    phi_.get().directionalDomain(origin, direction, domain_);

    if (domain_[0] < defectLower_) {
      domain_[0] = defectLower_;
    }

    if (domain_[1] > defectUpper_) {
      domain_[1] = defectUpper_;
    }

    // set quadratic and linear parts of functional
    quadraticPart_ = direction*direction;
    quadraticPart_ *= maxEigenvalue_;

    linearPart_ = linearTerm_.get()*direction - maxEigenvalue_*(direction*origin);
  }

public:
  /** \brief Construct from a (local) matrix; its largest eigenvalue is computed here
   *
   *  The eigenvalues are obtained via Dune::FMatrixHelp::eigenValues and
   *  stored in a Vector, so matrix and Vector have to be compatible with
   *  that routine.
   */
  template <class MM, class VV1, class NN, class LL, class UU, class VV2, class VV3>
  FirstOrderModelFunctional(
          const MM& matrix,
          VV1&& linearTerm,
          NN&& phi,
          LL&& lower,
          UU&& upper,
          VV2&& origin,
          VV3&& direction) :
    linearTerm_(std::forward<VV1>(linearTerm)),
    lower_(std::forward<LL>(lower)),
    upper_(std::forward<UU>(upper)),
    origin_(std::forward<VV2>(origin)),
    direction_(std::forward<VV3>(direction)),
    phi_(std::forward<NN>(phi)) {

    // set maxEigenvalue_ from matrix
    Vector eigenvalues;
    Dune::FMatrixHelp::eigenValues(matrix, eigenvalues);
    maxEigenvalue_ = *std::max_element(std::begin(eigenvalues), std::end(eigenvalues));

    init();
  }

  /** \brief Construct from an already known largest eigenvalue
   *
   *  This overload is only selected if the first argument is exactly of
   *  type Range; any other type ends up in the matrix overload.
   */
  template <class VV1, class NN, class LL, class UU, class VV2, class VV3>
  FirstOrderModelFunctional(
          const Range& maxEigenvalue,
          VV1&& linearTerm,
          NN&& phi,
          LL&& lower,
          UU&& upper,
          VV2&& origin,
          VV3&& direction) :
    linearTerm_(std::forward<VV1>(linearTerm)),
    lower_(std::forward<LL>(lower)),
    upper_(std::forward<UU>(upper)),
    origin_(std::forward<VV2>(origin)),
    direction_(std::forward<VV3>(direction)),
    phi_(std::forward<NN>(phi)),
    maxEigenvalue_(maxEigenvalue) {

    init();
  }

  /** \brief Create a copy of this functional that respects ignore information
   *
   *  Works on copies of origin and direction: for every ignored entry j the
   *  direction entry is zeroed, for every other entry the origin entry is
   *  zeroed. The result reuses maxEigenvalue_ and the stored linear term,
   *  nonlinearity and obstacles.
   *
   *  \param ignore bit vector, indexed like direction/origin
   */
  template <class BitVector>
  auto getIgnoreFunctional(const BitVector& ignore) const {
    Vector direction = direction_.get();
    Vector origin = origin_.get();

    for (std::size_t j = 0; j < ignore.size(); ++j) {
      if (ignore[j])
        direction[j] = 0; // makes sure result remains in subspace after correction
      else
        origin[j] = 0; // shift affine offset
    }

    return FirstOrderModelFunctional<Vector, Nonlinearity&, LowerObstacle, UpperObstacle, Vector, Vector, Range>(maxEigenvalue_, linearTerm_, phi_, lower_, upper_, std::move(origin), std::move(direction));
  }

  /** \brief Evaluation is not available; always throws Dune::NotImplemented */
  Range operator()(const Vector& /*v*/) const
  {
    DUNE_THROW(Dune::NotImplemented, "Evaluation of FirstOrderModelFunctional not implemented");
  }

  /** \brief Subdifferential of the model at step length x
   *
   *  Directional subdifferential of phi at origin + x*direction, with both
   *  interval ends shifted by quadraticPart_*x - linearPart_.
   */
  auto subDifferential(double x) const {
    Interval Di;
    Vector uxv = origin_.get();

    // uxv = origin + x*direction
    Dune::MatrixVector::addProduct(uxv, x, direction_.get());
    phi_.get().directionalSubDiff(uxv, direction_.get(), Di);

    const auto Axmb = quadraticPart_ * x - linearPart_;
    Di[0] += Axmb;
    Di[1] += Axmb;

    return Di;
  }

  /** \brief Admissible step lengths as computed in init() */
  const Interval& domain() const {
    return domain_;
  }

  const Vector& origin() const {
    return origin_.get();
  }

  const Vector& direction() const {
    return direction_.get();
  }

  /** \brief Lower defect obstacle along the direction */
  const field_type& defectLower() const {
    return defectLower_;
  }

  /** \brief Upper defect obstacle along the direction */
  const field_type& defectUpper() const {
    return defectUpper_;
  }

private:
  // NOTE: the constructors' initializer lists follow this declaration order
  Dune::Solvers::ConstCopyOrReference<V> linearTerm_;
  Dune::Solvers::ConstCopyOrReference<L> lower_;
  Dune::Solvers::ConstCopyOrReference<U> upper_;

  Dune::Solvers::ConstCopyOrReference<O> origin_;
  Dune::Solvers::ConstCopyOrReference<D> direction_;

  Dune::Solvers::ConstCopyOrReference<N> phi_;

  Interval domain_;

  Range maxEigenvalue_;

  field_type quadraticPart_;
  field_type linearPart_;

  field_type defectLower_;
  field_type defectUpper_;
};

// \ToDo This should be an inline friend of ShiftedBoxConstrainedQuadraticFunctional
// but gcc-4.9.2 shipped with debian jessie does not accept
// inline friends with auto return type due to bug-59766.
// Notice, that this is fixed in gcc-4.9.3.
template<class Index, class M, class V, class Nonlinearity, class R>
auto coordinateRestriction(const ShiftedFunctional<M, V, Nonlinearity, R>& f, const Index& i)
{
  using Range = R;
  using LocalMatrix = std::decay_t<decltype(f.quadraticPart()[i][i])>;
  using LocalVector = std::decay_t<decltype(f.originalLinearPart()[i])>;
  using LocalLowerObstacle = std::decay_t<decltype(f.originalLowerObstacle()[i])>;
  using LocalUpperObstacle = std::decay_t<decltype(f.originalUpperObstacle()[i])>;

  using namespace Dune::MatrixVector;
  namespace H = Dune::Hybrid;

  const LocalMatrix* Aii_p = nullptr;

  LocalVector ri = f.originalLinearPart()[i];
  const auto& Ai = f.quadraticPart()[i];
  sparseRangeFor(Ai, [&](auto&& Aij, auto&& j) {
      // TODO Here we must implement a wrapper to guarantee that this will work with proxy matrices!
      H::ifElse(H::equals(j, i), [&](auto&& id){
        Aii_p = id(&Aij);
      });
      Dune::TNNMG::Imp::mmv(Aij, f.origin()[j], ri, Dune::PriorityTag<1>());
  });

  LocalLowerObstacle dli = f.originalLowerObstacle()[i];
  LocalUpperObstacle dui = f.originalUpperObstacle()[i];
  dli -= f.origin()[i];
  dui -= f.origin()[i];

  auto&& phii = f.phi().restriction(i);

  auto v = ri;
  double const vnorm = v.two_norm();
  if (vnorm > 1.0)
        v /= vnorm;

  return FirstOrderModelFunctional<LocalVector, decltype(phii), LocalLowerObstacle, LocalUpperObstacle, LocalVector, LocalVector, Range>(*Aii_p, std::move(ri), std::move(phii), std::move(dli), std::move(dui), std::move(f.origin()[i]), std::move(v));
}


/** \brief Box constrained quadratic functional with nonlinearity
 *
 *  The nonlinearity is held through Dune::Solvers::ConstCopyOrReference<N>.
 *
 *  \tparam M Matrix type
 *  \tparam V Vector type
 *  \tparam N Nonlinearity type
 *  \tparam L Lower obstacle type
 *  \tparam U Upper obstacle type
 *  \tparam R Range type
 */
template<class M, class V, class N, class L, class U, class R>
class Functional : public Dune::TNNMG::BoxConstrainedQuadraticFunctional<M, V, L, U, R>
{
private:
  using Base = Dune::TNNMG::BoxConstrainedQuadraticFunctional<M, V, L, U, R>;


public:
  using Nonlinearity = std::decay_t<N>;

  using Matrix = typename Base::Matrix;
  using Vector = typename Base::Vector;
  using Range = typename Base::Range;
  using LowerObstacle = typename Base::LowerObstacle;
  using UpperObstacle = typename Base::UpperObstacle;

private:
  Dune::Solvers::ConstCopyOrReference<N> phi_;

public:

  /** \brief Construct the functional
   *
   *  All arguments are perfectly forwarded: matrix, linear part and obstacles
   *  to the base class, phi to the local storage. The obstacles are taken as
   *  forwarding references so that lvalues are passed on as lvalues (and are
   *  never moved from) while temporaries are accepted as well.
   */
  template <class MM, class VV, class NN, class LL, class UU>
  Functional(
          MM&& matrix,
          VV&& linearPart,
          NN&& phi,
          LL&& lower,
          UU&& upper) :
    Base(std::forward<MM>(matrix), std::forward<VV>(linearPart), std::forward<LL>(lower), std::forward<UU>(upper)),
    phi_(std::forward<NN>(phi))
  {}

  /** \brief Access the nonlinearity */
  const Nonlinearity& phi() const {
    return phi_.get();
  }

  /** \brief Evaluate the functional at v
   *
   *  \returns std::numeric_limits<Range>::max() if v violates the obstacles,
   *           otherwise the value of Base::Base::operator() plus phi(v)
   */
  Range operator()(const Vector& v) const
  {
    if (Dune::TNNMG::checkInsideBox(v, this->lower_.get(), this->upper_.get()))
      return Base::Base::operator()(v) + phi_.get().operator()(v);
    return std::numeric_limits<Range>::max();
  }

  /** \brief Restriction of f to the line origin + t*direction
   *
   *  The result references origin, direction and the data of f.
   */
  friend auto directionalRestriction(const Functional& f, const Vector& origin, const Vector& direction)
    -> DirectionalRestriction<Matrix, Vector, Nonlinearity, LowerObstacle, UpperObstacle, Range>
  {
    return DirectionalRestriction<Matrix, Vector, Nonlinearity, LowerObstacle, UpperObstacle, Range>(f.quadraticPart(), f.linearPart(), f.phi(), f.lowerObstacle(), f.upperObstacle(), origin, direction);
  }

  /** \brief Functional f shifted by origin */
  friend auto shift(const Functional& f, const Vector& origin)
    -> ShiftedFunctional<Matrix, Vector, Nonlinearity, Range>
  {
    return ShiftedFunctional<Matrix, Vector, Nonlinearity, Range>(f.quadraticPart(), f.linearPart(), f.phi(), f.lowerObstacle(), f.upperObstacle(), origin);
  }
};


#endif