#ifndef SRC_TIMESTEPPING_HH
#define SRC_TIMESTEPPING_HH

#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>

#include <dune/common/bitsetvector.hh>

#include "enums.hh"
#include "matrices.hh"

/**
 * Abstract base class for time-stepping schemes (Newmark and BackwardEuler
 * are included after this declaration).
 *
 * Derived schemes implement setup(), postProcess() and clone(). The base
 * holds the state vectors u, v, a and u_o, v_o, a_o.
 * NOTE(review): u/v/a presumably denote displacement, velocity and
 * acceleration, and the *_o members their previous-step values -- confirm
 * against the out-of-line implementation.
 *
 * The matrices, the Dirichlet node mask and the Dirichlet function are
 * held by reference, so the referenced objects must outlive the scheme
 * and every clone of it.
 */
template <class Vector, class Matrix, class Function, size_t dim>
class TimeSteppingScheme {
protected:
  TimeSteppingScheme(Matrices<Matrix> const &_matrices,
                     Vector const &_u_initial, Vector const &_v_initial,
                     Vector const &_a_initial,
                     Dune::BitSetVector<dim> const &_dirichletNodes,
                     Function const &_dirichletFunction);

  // Copying/moving is only meaningful from derived classes (e.g. inside
  // clone()). Both are declared explicitly because the user-declared
  // virtual destructor below would otherwise suppress the implicit move
  // constructor, silently degrading moves of the Vector members to copies.
  // Assignment remains implicitly deleted because of the reference members.
  TimeSteppingScheme(TimeSteppingScheme const &) = default;
  TimeSteppingScheme(TimeSteppingScheme &&) = default;

public:
  // Polymorphic base: instances are handed out as shared_ptr to this type
  // (see clone()), so destruction through a base pointer must be
  // well-defined regardless of how the owning pointer was created.
  virtual ~TimeSteppingScheme() = default;

  // Advances the scheme by one step (defined out of line).
  void nextTimeStep();

  // Prepares the linear system for the current step: fills rhs, iterate and
  // AB from ell and the step size _tau.
  // NOTE(review): exact meaning of relativeTime (presumably the normalized
  // time used to evaluate the Dirichlet function) -- confirm in derived
  // classes.
  virtual void setup(Vector const &ell, double _tau, double relativeTime,
                     Vector &rhs, Vector &iterate, Matrix &AB) = 0;

  // Consumes the solved iterate of the system built by setup().
  virtual void postProcess(Vector const &iterate) = 0;

  // Copy the requested state into the caller-provided vector (defined out
  // of line).
  void extractDisplacement(Vector &displacement) const;
  void extractVelocity(Vector &velocity) const;
  void extractOldVelocity(Vector &velocity) const;

  // Polymorphic copy of the concrete scheme.
  virtual std::shared_ptr<TimeSteppingScheme<Vector, Matrix, Function, dim>>
  clone() const = 0;

protected:
  Matrices<Matrix> const &matrices;
  Vector u, v, a;
  Dune::BitSetVector<dim> const &dirichletNodes;
  Function const &dirichletFunction;
  // Initialized so that no code path can read an indeterminate value; a
  // constructor mem-initializer, if present, takes precedence.
  double dirichletValue = 0.0;

  Vector u_o, v_o, a_o;
  double tau = 0.0;

  // NOTE(review): presumably tracks whether postProcess() ran since the
  // last setup() -- confirm in the implementation.
  bool postProcessCalled = true;
};

#include "timestepping/newmark.hh"
#include "timestepping/backward_euler.hh"

/**
 * Factory that creates the time-stepping scheme selected by `scheme`.
 *
 * Config::Newmark yields a Newmark instance and Config::BackwardEuler a
 * BackwardEuler instance. Both are constructed from (matrices, u_initial,
 * v_initial, a_initial, velocityDirichletNodes, velocityDirichletFunction).
 * The constructed scheme keeps references to `matrices`,
 * `velocityDirichletNodes` and `velocityDirichletFunction`, so those must
 * outlive the returned object.
 *
 * `dimension` is an `int` (not `size_t` like the class template parameter)
 * so that it can be deduced from `Dune::BitSetVector<dimension>`, whose own
 * template parameter is an `int`.
 *
 * @throws std::invalid_argument if `scheme` is not a supported value
 *         (in debug builds the assert fires first).
 */
template <class Vector, class Matrix, class Function, int dimension>
std::shared_ptr<TimeSteppingScheme<Vector, Matrix, Function, dimension>>
initTimeStepper(Config::scheme scheme,
                Function const &velocityDirichletFunction,
                Dune::BitSetVector<dimension> const &velocityDirichletNodes,
                Matrices<Matrix> const &matrices, Vector const &u_initial,
                Vector const &v_initial, Vector const &a_initial) {
  switch (scheme) {
    case Config::Newmark:
      return std::make_shared<Newmark<Vector, Matrix, Function, dimension>>(
          matrices, u_initial, v_initial, a_initial, velocityDirichletNodes,
          velocityDirichletFunction);
    case Config::BackwardEuler:
      return std::make_shared<
          BackwardEuler<Vector, Matrix, Function, dimension>>(
          matrices, u_initial, v_initial, a_initial, velocityDirichletNodes,
          velocityDirichletFunction);
    default:
      assert(false);
      // With NDEBUG the assert vanishes; without this throw control would
      // flow off the end of a value-returning function (undefined behavior).
      throw std::invalid_argument(
          "initTimeStepper: unsupported time-stepping scheme");
  }
}
#endif