#ifndef NICE_FUNCTION_HH
#define NICE_FUNCTION_HH

#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

#include <dune/common/exceptions.hh>
#include <dune/common/function.hh>

#include "frictiondata.hh"

namespace Dune {
// Abstract interface for a scalar function that is described through its
// one-sided derivatives, second derivative and a regularity measure, and that
// may carry a list of kink locations.
class NiceFunction {
public:
  // Builds a function without any kinks.
  NiceFunction() : kinks() {}

  virtual ~NiceFunction() {}

  // Left- and right-sided derivatives at s.
  virtual double leftDifferential(double s) const = 0;
  virtual double rightDifferential(double s) const = 0;

  // Second derivative at x.
  virtual double second_deriv(double x) const = 0;

  // Regularity measure at s; its precise meaning is fixed by each subclass.
  virtual double regularity(double s) const = 0;

  // Whether H(|.|) is smooth at zero. The conservative default is "no".
  virtual bool smoothesNorm() const { return false; }

  // Value at x. The base version throws Dune::NotImplemented; subclasses
  // that can be evaluated override it.
  virtual double evaluate(double x) const {
    DUNE_THROW(NotImplemented, "evaluation not implemented");
  }

  // Read-only view of the kink locations given at construction (empty when
  // the default constructor was used).
  std::vector<double> const &get_kinks() const { return kinks; }

protected:
  // For subclasses that have kinks; the given locations are copied.
  NiceFunction(std::vector<double> const &_kinks) : kinks(_kinks) {}

private:
  std::vector<double> const kinks;
};

// phi(V) = V log(V/V_m) - V + V_m  if V >= V_m
//        = 0                       otherwise
// with V_m = V_0 exp(-K/a),
// i.e.   K = -a log(V_m / V_0) = mu_0 + b log(V_0 / L) + b alpha
//
// Every quantity returned below is additionally scaled by
// coefficientProduct = coefficient * a * normalStress.
class RuinaFunction : public NiceFunction {
public:
  // coefficient    - extra scalar prefactor, multiplied with fd.a and
  //                  fd.normalStress
  // fd             - friction parameters; reads a, b, mu0, V0, L and
  //                  normalStress
  // state          - the state variable alpha, assumed to be logarithmic
  // _kinkTolerance - regularity() reports infinity whenever
  //                  |V - V_m| < _kinkTolerance; the default keeps the
  //                  previously hard-coded value
  RuinaFunction(double coefficient, FrictionData const &fd, double state,
                double _kinkTolerance = 1e-14)
      : NiceFunction(),
        coefficientProduct(coefficient * fd.a * fd.normalStress),
        // state is assumed to be logarithmic
        V_m(fd.V0 *
            std::exp(-(fd.mu0 + fd.b * (state + std::log(fd.V0 / fd.L))) /
                     fd.a)),
        // We could also compute V_m as
        // V_0 * std::exp(-(mu_0 + b * state)/a)
        //     * std::pow(V_0 / L, -b/a)
        // which would avoid the std::exp(std::log())
        kinkTolerance(_kinkTolerance) {}

  // coefficientProduct * (V log(V/V_m) - V + V_m)  if V > V_m
  // 0                                               otherwise
  // Requires V >= 0.
  double virtual evaluate(double V) const {
    assert(V >= 0);
    if (V <= V_m)
      return 0;

    return coefficientProduct * (V * std::log(V / V_m) - V + V_m);
  }

  // coefficientProduct * log(V/V_m)  if V > V_m
  // 0                                otherwise
  // Both branches give 0 at V == V_m, so the derivative is continuous there.
  // Requires V >= 0.
  double virtual leftDifferential(double V) const {
    assert(V >= 0);
    if (V <= V_m)
      return 0;

    return coefficientProduct * std::log(V / V_m);
  }

  // The derivative is continuous, so the right derivative equals the left.
  double virtual rightDifferential(double V) const {
    return leftDifferential(V);
  }

  // coefficientProduct / V  if V >  V_m
  // undefined               if V == V_m (0 is returned)
  // 0                       if V <  V_m
  // Requires V >= 0.
  double virtual second_deriv(double V) const {
    assert(V >= 0);
    if (V <= V_m)
      return 0;

    return coefficientProduct / V;
  }

  // |second_deriv(V)|, or infinity when V is within kinkTolerance of the
  // kink V_m, where the second derivative jumps.
  // NOTE(review): kinkTolerance is absolute. If V_m is smaller than it, every
  // V in [0, V_m + kinkTolerance), including V == 0, is treated as the kink.
  double virtual regularity(double V) const {
    assert(V >= 0);
    if (std::abs(V - V_m) < kinkTolerance)
      return std::numeric_limits<double>::infinity();

    return std::abs(second_deriv(V));
  }

  // phi vanishes on [0, V_m], so H(|.|) is reported as smooth at zero.
  bool virtual smoothesNorm() const { return true; }

private:
  double const coefficientProduct;
  double const V_m;
  double const kinkTolerance;
};

// The identically-zero function: its value, both one-sided derivatives, its
// second derivative and its regularity are 0 everywhere.
class TrivialFunction : public NiceFunction {
public:
  TrivialFunction() {}

  virtual double evaluate(double) const { return 0.0; }

  virtual double leftDifferential(double) const { return 0.0; }
  virtual double rightDifferential(double) const { return 0.0; }

  virtual double second_deriv(double) const { return 0.0; }

  virtual double regularity(double) const { return 0.0; }

  // The zero function composed with |.| is trivially smooth at zero.
  virtual bool smoothesNorm() const { return true; }
};
}
#endif