template <class VectorType, class MatrixType, class FunctionType, int dim>
class TimeSteppingScheme {
public:
  TimeSteppingScheme(VectorType const &_ell, MatrixType const &_A,
                     VectorType const &_u_old, VectorType const *_u_old_old,
                     Dune::BitSetVector<dim> const &_dirichletNodes,
                     FunctionType const &_dirichletFunction, double _time)
      : ell(_ell),
        A(_A),
        u_old(_u_old),
        u_old_old(_u_old_old),
        dirichletNodes(_dirichletNodes),
        dirichletFunction(_dirichletFunction),
        time(_time) {}

  virtual ~TimeSteppingScheme() {}

  void virtual setup(VectorType &problem_rhs, VectorType &problem_iterate,
                     MatrixType &problem_A) const = 0;

  void virtual extractSolution(VectorType const &problem_iterate,
                               VectorType &solution) const = 0;

  void virtual extractDifference(VectorType const &problem_iterate,
                                 VectorType &velocity) const = 0;

protected:
  VectorType const &ell;
  MatrixType const &A;
  VectorType const &u_old;
  VectorType const *u_old_old;
  Dune::BitSetVector<dim> const &dirichletNodes;
  FunctionType const &dirichletFunction;
  double time;
};

// Implicit Euler: solve the variational inequality
//
// a(Delta u_new, v - Delta u_new) + j_tau(v) - j_tau(Delta u_new)
// >= l(v - Delta u_new) - a(u_old, v - Delta u_new)
//
// for the increment Delta u_new = u_new - u_old.
template <class VectorType, class MatrixType, class FunctionType, int dim>
class ImplicitEuler
    : public TimeSteppingScheme<VectorType, MatrixType, FunctionType, dim> {
public:
  // Inherit the base-class constructor (C++11). This replaces a greedy
  // variadic template that did not forward and outcompeted the copy
  // constructor for non-const lvalues.
  using TimeSteppingScheme<VectorType, MatrixType, FunctionType,
                           dim>::TimeSteppingScheme;

  /// Assemble the increment problem  A x = ell - A u_old,  x = u_new - u_old.
  ///
  /// @param problem_rhs     receives ell - A * u_old
  /// @param problem_iterate receives the starting guess for x: u_old - u_old_old
  ///                        when u_old_old is given (otherwise u_old itself),
  ///                        with Dirichlet components overwritten
  /// @param problem_A       receives a copy of A
  void setup(VectorType &problem_rhs, VectorType &problem_iterate,
             MatrixType &problem_A) const override {
    problem_A = this->A;
    problem_rhs = this->ell;
    problem_A.mmv(this->u_old, problem_rhs); // rhs -= A * u_old

    // Use the previous increment as the initial guess when it is available.
    problem_iterate = this->u_old;
    if (this->u_old_old)
      problem_iterate -= *this->u_old_old;

    // The prescribed value does not depend on the node, so evaluate it once
    // instead of once per constrained node.
    double dirichletValue;
    this->dirichletFunction.evaluate(this->time, dirichletValue);

    for (size_t i = 0; i < this->dirichletNodes.size(); ++i)
      if (this->dirichletNodes[i].count() == dim) {
        problem_iterate[i] = 0; // Everything prescribed
        problem_iterate[i][0] =
            dirichletValue - this->u_old[i][0]; // Time-dependent X direction
      } else if (dim > 1 && this->dirichletNodes[i][1])
        // dim > 1 guard: with a single component there is no bit/entry 1.
        problem_iterate[i][1] = 0; // Y direction prescribed
  }

  /// u_new = u_old + x.
  void extractSolution(VectorType const &problem_iterate,
                       VectorType &solution) const override {
    solution = this->u_old;
    solution += problem_iterate;
  }

  /// The unknown x is the increment itself, so it is returned unchanged.
  void extractDifference(VectorType const &problem_iterate,
                         VectorType &velocity) const override {
    velocity = problem_iterate;
  }
};

// Two-step implicit scheme (coefficients 3/2, -2, 1/2 on u_new, u_old, u_old_old); requires u_old_old.
template <class VectorType, class MatrixType, class FunctionType, int dim>
class ImplicitTwoStep
    : public TimeSteppingScheme<VectorType, MatrixType, FunctionType, dim> {
public:
  // Inherit the base-class constructor (C++11). This replaces a greedy
  // variadic template that did not forward and outcompeted the copy
  // constructor for non-const lvalues.
  using TimeSteppingScheme<VectorType, MatrixType, FunctionType,
                           dim>::TimeSteppingScheme;

  /// Assemble the problem for the unknown
  ///   x = 1.5 u_new - 2 u_old + 0.5 u_old_old.
  /// Since u_new = (x + 2 u_old - 0.5 u_old_old) / 1.5, the system reads
  ///   (A/1.5) x = ell - (A/1.5) (2 u_old - 0.5 u_old_old).
  ///
  /// @param problem_rhs     receives the right-hand side above
  /// @param problem_iterate receives the starting guess u_old - u_old_old,
  ///                        with Dirichlet components overwritten
  /// @param problem_A       receives A / 1.5
  /// Requires u_old_old to be non-null.
  void setup(VectorType &problem_rhs, VectorType &problem_iterate,
             MatrixType &problem_A) const override {
    // This scheme always needs the state from two steps back.
    assert(this->u_old_old != nullptr);

    problem_A = this->A;
    problem_A /= 1.5;
    problem_rhs = this->ell;
    problem_A.usmv(-2, this->u_old, problem_rhs);
    problem_A.usmv(.5, *this->u_old_old, problem_rhs);

    // The finite difference makes a good start
    problem_iterate = this->u_old;
    problem_iterate -= *this->u_old_old;

    // The prescribed value does not depend on the node, so evaluate it once
    // instead of once per constrained node.
    double dirichletValue;
    this->dirichletFunction.evaluate(this->time, dirichletValue);

    // NOTE(review): constrained components other than X are pinned to
    // u_new = 0 here, whereas ImplicitEuler keeps u_new = u_old. The two agree
    // only while those components of u_old / u_old_old are zero; confirm intent.
    for (size_t i = 0; i < this->dirichletNodes.size(); ++i)
      if (this->dirichletNodes[i].count() == dim) {
        // x for u_new[i] = (dirichletValue, 0, ..., 0)
        problem_iterate[i] = 0;
        problem_iterate[i].axpy(-2, this->u_old[i]);
        problem_iterate[i].axpy(.5, (*this->u_old_old)[i]);
        problem_iterate[i][0] = 1.5 * dirichletValue - 2 * this->u_old[i][0] +
                                .5 * (*this->u_old_old)[i][0];
      } else if (dim > 1 && this->dirichletNodes[i][1])
        // Y direction prescribed (u_new[i][1] = 0); the dim > 1 guard avoids
        // reading a nonexistent bit/entry 1 when there is a single component.
        problem_iterate[i][1] =
            -2 * this->u_old[i][1] + .5 * (*this->u_old_old)[i][1];
  }

  /// u_new = (2/3) (x + 2 u_old - 0.5 u_old_old). Requires u_old_old.
  void extractSolution(VectorType const &problem_iterate,
                       VectorType &solution) const override {
    assert(this->u_old_old != nullptr);

    solution = problem_iterate;
    solution.axpy(2, this->u_old);
    solution.axpy(-.5, *this->u_old_old);
    solution *= 2.0 / 3.0;

#ifndef NDEBUG
    // Check that we split correctly: 1.5 u_new - 2 u_old + 0.5 u_old_old must
    // reproduce x. Compiled out in release builds so the temporary vector and
    // the axpys are not paid for when assert() is a no-op.
    {
      VectorType test = problem_iterate;
      test.axpy(-1.5, solution);
      test.axpy(+2, this->u_old);
      test.axpy(-.5, *this->u_old_old);
      assert(test.two_norm() < 1e-10);
    }
#endif
  }

  /// The unknown x (1.5 u_new - 2 u_old + 0.5 u_old_old) is returned unchanged.
  void extractDifference(VectorType const &problem_iterate,
                         VectorType &velocity) const override {
    velocity = problem_iterate;
  }
};