#ifndef TEST_GRADIENT_METHOD_NICEFUNCTION_HH
#define TEST_GRADIENT_METHOD_NICEFUNCTION_HH

#include <cassert>
#include <cmath>
#include <cstddef>

#include <dune/tectonic/nicefunction.hh>

namespace Dune {
// Intermediate base class for the test functions in this header.
//
// It defines second_deriv() and regularity() as placeholders: calling either
// one fails an assertion; if assertions are compiled out (NDEBUG) the call
// returns 0 instead.
class MyFunction : public NiceFunction {
  // Placeholder, see class comment.
  virtual double second_deriv(double) const {
    assert(false);
    return 0.0;
  }

  // Placeholder, see class comment.
  virtual double regularity(double) const {
    assert(false);
    return 0.0;
  }
};

// The linear function x -> a * x for a coefficient a fixed at construction.
class LinearFunction : public MyFunction {
public:
  // a: the factor every argument is multiplied with.
  LinearFunction(double a) : coefficient(a) {}

  // Stores coefficient * x in y.
  virtual void evaluate(double const &x, double &y) const {
    y = coefficient * x;
  }

  // Both one-sided differentials equal the coefficient; the point at which
  // they are requested is ignored.
  virtual double leftDifferential(double /*s*/) const { return coefficient; }

  virtual double rightDifferential(double /*s*/) const { return coefficient; }

  // Always 0, independent of the argument.
  virtual double second_deriv(double) const { return 0.0; }

  // Always 0, independent of the argument.
  virtual double regularity(double /*s*/) const { return 0.0; }

private:
  double const coefficient;
};

// Piecewise linear function with a single kink at 1:
//   x                    for x < 1
//   slope * (x - 1) + 1  otherwise
template <int slope> class SampleFunction : public MyFunction {
public:
  // Stores the function value at x in y.
  virtual void evaluate(double const &x, double &y) const {
    if (x < 1)
      y = x;
    else
      y = slope * (x - 1) + 1;
  }

  // At the kink s == 1 this still reports the slope of the left piece.
  virtual double leftDifferential(double s) const {
    if (s <= 1)
      return 1;
    return slope;
  }

  // At the kink s == 1 this already reports the slope of the right piece.
  virtual double rightDifferential(double s) const {
    if (s < 1)
      return 1;
    return slope;
  }
};

// The linear function x -> 100 * x.
class SteepFunction : public MyFunction {
public:
  // Stores 100 * x in y.
  virtual void evaluate(double const &x, double &y) const {
    y = 100 * x;
  }

  // Both one-sided differentials are 100 everywhere; the argument is ignored.
  virtual double leftDifferential(double /*s*/) const {
    return 100;
  }

  virtual double rightDifferential(double /*s*/) const {
    return 100;
  }
};

// Piecewise linear function with a kink at every integer:
// slope in [n-1,n] is n
class HorribleFunction : public MyFunction {
public:
  // Stores the function value at x in y.
  virtual void evaluate(double const &x, double &y) const {
    double const fl = std::floor(x);
    // Value at the kink fl; for fl >= 0 this is 1 + 2 + ... + fl.
    double const sum = fl * (fl + 1) / 2;
    // Continue from the kink with the slope fl + 1 of the piece containing x.
    y = sum + (fl + 1) * (x - fl);
  }

  // Slope of the piece ending at x. An x that lies less than 1e-14 above an
  // integer is treated as sitting on that kink, so the slope of the piece to
  // the left of the kink is returned.
  virtual double leftDifferential(double x) const {
    double const fl = std::floor(x);
    if (x - fl < 1e-14)
      return fl;
    else
      return fl + 1;
  }

  // Slope of the piece starting at x. An x that lies less than 1e-14 below an
  // integer is treated as sitting on that kink, so the slope of the piece to
  // the right of the kink is returned.
  virtual double rightDifferential(double x) const {
    double const c = std::ceil(x);
    if (c - x < 1e-14)
      return c + 1;
    else
      return c;
  }
};

// Piecewise linear function with a kink at every integer:
// slope in [n-1,n] is log(n+1)
//
// The slope formula only yields finite values for n >= 0, i.e. for
// arguments x >= -1.
class HorribleFunctionLogarithmic : public MyFunction {
public:
  // Stores the function value at x in y.
  virtual void evaluate(double const &x, double &y) const {
    // Keep the floor as a double: converting a negative floating-point value
    // to an unsigned integer type is undefined behaviour.
    double const fl = std::floor(x);

    // Value at the kink fl: log(2) + log(3) + ... + log(fl + 1).
    // Factorials grow too fast, so the logarithms are accumulated one by one
    // instead of taking the logarithm of a factorial. The sum is empty for
    // fl < 1.
    y = 0;
    for (std::size_t k = 2; static_cast<double>(k) <= fl + 1; ++k)
      y += std::log(static_cast<double>(k));

    // Continue from the kink with the slope log(fl + 2) of the piece
    // containing x.
    y += std::log(fl + 2) * (x - fl);
  }

  // Slope of the piece ending at x. An x that lies less than 1e-14 above an
  // integer is treated as sitting on that kink, so the slope of the piece to
  // the left of the kink is returned.
  virtual double leftDifferential(double x) const {
    double const fl = std::floor(x);
    if (x - fl < 1e-14)
      return std::log(fl + 1);
    else
      return std::log(fl + 2);
  }

  // Slope of the piece starting at x. An x that lies less than 1e-14 below an
  // integer is treated as sitting on that kink, so the slope of the piece to
  // the right of the kink is returned.
  virtual double rightDifferential(double x) const {
    double const c = std::ceil(x);
    if (c - x < 1e-14)
      return std::log(c + 2);
    else
      return std::log(c + 1);
  }
};
}
#endif