// Copied from dune/tnnmg/problem-classes/directionalconvexfunction.hh

#ifndef MY_DIRECTIONAL_CONVEX_FUNCTION_HH
#define MY_DIRECTIONAL_CONVEX_FUNCTION_HH

#include <dune/fufem/interval.hh>

/**
 * Scalar restriction of a convex functional along a search direction.
 *
 * For a step size x, the represented function combines two parts:
 *  - a quadratic part with coefficients A (quadraticPart()) and b
 *    (linearPart());
 *  - the nonlinearity phi, evaluated at the point u + x * v.
 *
 * subDiff() returns phi's directional subdifferential at u + x * v, with both
 * endpoints shifted by A * x - b. That shift is the derivative of a quadratic
 * part of the form 0.5 * A * x^2 - b * x.
 *
 * Lifetime: phi, u and v are stored by reference, not copied. The caller must
 * keep them alive for as long as this object is used.
 *
 * Thread-safety: subDiff() is const but writes to a mutable scratch vector.
 * Concurrent calls on the same instance are therefore not safe.
 *
 * Requirements used below:
 *  - NonlinearityType must expose VectorType and MatrixType.
 *  - NonlinearityType must provide directionalDomain(u, v, Interval<double>&)
 *    and directionalSubDiff(u, v, Interval<double>&).
 *  - VectorType must be copy-constructible and copy-assignable, and must
 *    provide axpy(double, const VectorType&).
 */
template <class NonlinearityType> class MyDirectionalConvexFunction {
public:
  using VectorType = typename NonlinearityType::VectorType;
  using MatrixType = typename NonlinearityType::MatrixType;

  /**
   * @param quadratic  coefficient of the quadratic part (stored in A)
   * @param linear     coefficient of the linear part (stored in b)
   * @param phi        nonlinearity; held by reference
   * @param u          base point; held by reference (also copied once into
   *                   the scratch vector)
   * @param v          search direction; held by reference
   *
   * The interval reported by phi.directionalDomain(u, v, ...) is computed
   * once here and cached for domain().
   */
  MyDirectionalConvexFunction(double quadratic, double linear,
                              NonlinearityType const& phi,
                              VectorType const& u, VectorType const& v)
      : A(quadratic), b(linear), phi_(phi), u_(u), v_(v), temp_u_(u) {
    phi_.directionalDomain(u_, v_, dom_);
  }

  /// Coefficient of the quadratic part.
  double quadraticPart() const { return A; }

  /// Coefficient of the linear part.
  double linearPart() const { return b; }

  /**
   * Subdifferential of the restricted function at step size x.
   *
   * @param x  step size along v
   * @param D  output interval. It is filled by phi.directionalSubDiff at
   *           u + x * v, then both endpoints are shifted by A * x - b.
   */
  void subDiff(double x, Interval<double>& D) const {
    // Build the trial point u + x * v in the reusable scratch vector.
    temp_u_ = u_;
    temp_u_.axpy(x, v_);
    phi_.directionalSubDiff(temp_u_, v_, D);

    // Add the derivative of the quadratic part to both endpoints.
    double const shift = A * x - b;
    D[0] += shift;
    D[1] += shift;
  }

  /// Copies the interval cached by the constructor into `domain`.
  void domain(Interval<double>& domain) const {
    domain[0] = dom_[0];
    domain[1] = dom_[1];
  }

  // Public coefficients. They are declared before the private members; that
  // order matches the constructor's initializer list.
  double A;
  double b;

private:
  NonlinearityType const& phi_; // not owned
  VectorType const& u_;         // not owned
  VectorType const& v_;         // not owned

  // Scratch storage for u + x * v. It is mutable so that subDiff() can stay
  // const.
  mutable VectorType temp_u_;
  Interval<double> dom_; // filled once by phi_.directionalDomain in the ctor
};

#endif