#ifndef DUNE_TECTONIC_TIMESTEPPING_HH
#define DUNE_TECTONIC_TIMESTEPPING_HH
#include <dune/common/bitsetvector.hh>

template <class VectorType, class MatrixType, class FunctionType, int dim>
class TimeSteppingScheme {
public:
  // clone() hands out TimeSteppingScheme pointers, so a scheme may be
  // destroyed through a pointer to this base. Without a virtual destructor
  // `delete` through such a pointer is undefined behaviour and the derived
  // members (vectors, etc.) would not be destroyed.
  virtual ~TimeSteppingScheme() = default;

  /// Advance the scheme's stored state to the next time step.
  virtual void nextTimeStep() = 0;

  /// Prepare a step of size _tau at the given time, writing into the
  /// caller-provided right-hand side, iterate and matrix.
  virtual void setup(VectorType const &ell, double _tau, double time,
                     VectorType &problem_rhs, VectorType &problem_iterate,
                     MatrixType &problem_A) = 0;

  /// Consume the solution (iterate) of the problem assembled by setup().
  virtual void postProcess(VectorType const &problem_iterate) = 0;

  /// Copy the scheme's current displacement into the output argument.
  virtual void extractDisplacement(VectorType &displacement) const = 0;

  /// Copy the scheme's current velocity into the output argument.
  virtual void extractVelocity(VectorType &velocity) const = 0;

  /// Polymorphic copy. Returns a raw pointer; presumably the caller owns it
  /// and deletes it through this base type (NOTE(review): confirm ownership).
  virtual TimeSteppingScheme<VectorType, MatrixType, FunctionType, dim> *
  clone() = 0;
};

template <class VectorType, class MatrixType, class FunctionType, int dim>
class ImplicitEuler
    : public TimeSteppingScheme<VectorType, MatrixType, FunctionType, dim> {
public:
  // Parameter names avoid a leading underscore followed by an uppercase
  // letter (such as `_A`): those identifiers are reserved for the
  // implementation and may collide with library macros. Names in a
  // declaration need not match the ones used in the out-of-line definition.
  //
  // A, dirichletNodes and dirichletFunction are held by reference (see the
  // private members). Presumably they are bound to these arguments, in which
  // case the referenced objects must outlive this scheme — TODO confirm
  // against the definition.
  ImplicitEuler(MatrixType const &A_, VectorType const &u_initial,
                VectorType const &ud_initial,
                Dune::BitSetVector<dim> const &dirichletNodes_,
                FunctionType const &dirichletFunction_);

  // `override` already implies `virtual`, so the keyword is omitted.
  void nextTimeStep() override;
  void setup(VectorType const &ell, double _tau, double time,
             VectorType &problem_rhs, VectorType &problem_iterate,
             MatrixType &problem_A) override;
  void postProcess(VectorType const &problem_iterate) override;
  void extractDisplacement(VectorType &displacement) const override;
  void extractVelocity(VectorType &velocity) const override;

  TimeSteppingScheme<VectorType, MatrixType, FunctionType, dim> *
  clone() override;

private:
  MatrixType const &A;
  VectorType u;
  VectorType ud;
  Dune::BitSetVector<dim> const &dirichletNodes;
  FunctionType const &dirichletFunction;

  // State from the previous time step.
  VectorType u_old;
  VectorType ud_old;

  // Presumably assigned in setup(). It is value-initialised so that copying
  // a not-yet-set-up scheme (e.g. via clone()) does not read an
  // indeterminate double.
  double tau = 0.0;

  // NOTE(review): looks like a guard that postProcess() ran before the
  // extract* accessors are used — confirm in the definition.
  bool postProcessCalled = false;
};

template <class VectorType, class MatrixType, class FunctionType, int dim>
class Newmark
    : public TimeSteppingScheme<VectorType, MatrixType, FunctionType, dim> {
public:
  // Parameter names avoid a leading underscore followed by an uppercase
  // letter (`_A`, `_B`): those identifiers are reserved for the
  // implementation. `_B`, for example, is a macro in newlib's <ctype.h>.
  // Names in a declaration need not match the ones used in the out-of-line
  // definition.
  //
  // A, B, dirichletNodes and dirichletFunction are held by reference (see
  // the private members). Presumably they are bound to these arguments, in
  // which case the referenced objects must outlive this scheme — TODO
  // confirm against the definition.
  Newmark(MatrixType const &A_, MatrixType const &B_,
          VectorType const &u_initial, VectorType const &ud_initial,
          VectorType const &udd_initial,
          Dune::BitSetVector<dim> const &dirichletNodes_,
          FunctionType const &dirichletFunction_);

  // `override` already implies `virtual`, so the keyword is omitted.
  void nextTimeStep() override;
  void setup(VectorType const &ell, double _tau, double time,
             VectorType &problem_rhs, VectorType &problem_iterate,
             MatrixType &problem_A) override;
  void postProcess(VectorType const &problem_iterate) override;
  void extractDisplacement(VectorType &displacement) const override;
  void extractVelocity(VectorType &velocity) const override;

  TimeSteppingScheme<VectorType, MatrixType, FunctionType, dim> *
  clone() override;

private:
  MatrixType const &A;
  MatrixType const &B;
  VectorType u;
  VectorType ud;
  VectorType udd;
  Dune::BitSetVector<dim> const &dirichletNodes;
  FunctionType const &dirichletFunction;

  // State from the previous time step.
  VectorType u_old;
  VectorType ud_old;
  VectorType udd_old;

  // Presumably assigned in setup(). It is value-initialised so that copying
  // a not-yet-set-up scheme (e.g. via clone()) does not read an
  // indeterminate double.
  double tau = 0.0;

  // NOTE(review): looks like a guard that postProcess() ran before the
  // extract* accessors are used — confirm in the definition.
  bool postProcessCalled = false;
};

template <class VectorType, class MatrixType, class FunctionType, int dim>
class EulerPair
    : public TimeSteppingScheme<VectorType, MatrixType, FunctionType, dim> {
public:
  // Parameter names avoid a leading underscore followed by an uppercase
  // letter (`_A`, `_B`): those identifiers are reserved for the
  // implementation. `_B`, for example, is a macro in newlib's <ctype.h>.
  // Names in a declaration need not match the ones used in the out-of-line
  // definition.
  //
  // A, B, dirichletNodes and dirichletFunction are held by reference (see
  // the private members). Presumably they are bound to these arguments, in
  // which case the referenced objects must outlive this scheme — TODO
  // confirm against the definition.
  EulerPair(MatrixType const &A_, MatrixType const &B_,
            VectorType const &u_initial, VectorType const &ud_initial,
            Dune::BitSetVector<dim> const &dirichletNodes_,
            FunctionType const &dirichletFunction_);

  // `override` already implies `virtual`, so the keyword is omitted.
  void nextTimeStep() override;
  void setup(VectorType const &ell, double _tau, double time,
             VectorType &problem_rhs, VectorType &problem_iterate,
             MatrixType &problem_A) override;
  void postProcess(VectorType const &problem_iterate) override;
  void extractDisplacement(VectorType &displacement) const override;
  void extractVelocity(VectorType &velocity) const override;

  TimeSteppingScheme<VectorType, MatrixType, FunctionType, dim> *
  clone() override;

private:
  MatrixType const &A;
  MatrixType const &B;
  VectorType u;
  VectorType ud;
  Dune::BitSetVector<dim> const &dirichletNodes;
  FunctionType const &dirichletFunction;

  // State from the previous time step.
  VectorType u_old;
  VectorType ud_old;

  // Presumably assigned in setup(). It is value-initialised so that copying
  // a not-yet-set-up scheme (e.g. via clone()) does not read an
  // indeterminate double.
  double tau = 0.0;

  // NOTE(review): looks like a guard that postProcess() ran before the
  // extract* accessors are used — confirm in the definition.
  bool postProcessCalled = false;
};

#endif