// Adapted from dune/tnnmg/problem-classes/directionalconvexfunction.hh.
// Unlike the original, this version holds phi by const reference, so it can
// be used with a const nonlinearity.

#ifndef MY_DIRECTIONAL_CONVEX_FUNCTION_HH
#define MY_DIRECTIONAL_CONVEX_FUNCTION_HH

#include <dune/fufem/interval.hh>

template <class NonlinearityType> class MyDirectionalConvexFunction {
public:
  using VectorType = typename NonlinearityType::VectorType;
  using MatrixType = typename NonlinearityType::MatrixType;

  MyDirectionalConvexFunction(double A, double b, NonlinearityType const &phi,
                              VectorType const &u, VectorType const &v)
      : A(A), b(b), phi(phi), u(u), v(v) {
    phi.directionalDomain(u, v, dom);
  }

  /* Just for debugging */
  double operator()(double x) const {
    VectorType tmp = v;
    tmp *= x;
    return (0.5 * A * x * x) - (b * x) + phi(tmp);
  }

  double quadraticPart() const { return A; }

  double linearPart() const { return b; }

  void subDiff(double x, Interval<double> &D) const {
    VectorType tmp = u;
    tmp.axpy(x, v);
    phi.directionalSubDiff(tmp, v, D);
    D[0] += A * x - b;
    D[1] += A * x - b;
  }

  void domain(Interval<double> &domain) const {
    domain[0] = this->dom[0];
    domain[1] = this->dom[1];
  }

  double A;
  double b;

private:
  NonlinearityType const &phi;
  VectorType const &u;
  VectorType const &v;

  Interval<double> dom;
};

#endif