#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include <dune/fufem/arithmetic.hh>
#include "timestepping.hh"

// Construct an implicit Euler time stepper.
//
// Each argument initialises the member of the same name (minus the leading
// underscore):
//   _A                  matrix used by setup() to build the system matrix
//                       and the right-hand side
//   _u_initial          initial value of u
//   _ud_initial         initial value of ud
//   _dirichletNodes     per-node bit flags consulted in setup()
//   _dirichletFunction  function evaluated at the current time in setup()
//
// NOTE(review): whether A, dirichletNodes and dirichletFunction are copied or
// held by reference depends on the member declarations in timestepping.hh,
// which are not visible here. If they are references, the arguments have to
// outlive this object -- confirm against the header.
template <class VectorType, class MatrixType, class FunctionType, int dim>
ImplicitEuler<VectorType, MatrixType, FunctionType, dim>::ImplicitEuler(
    MatrixType const &_A, VectorType const &_u_initial,
    VectorType const &_ud_initial,
    Dune::BitSetVector<dim> const &_dirichletNodes,
    FunctionType const &_dirichletFunction)
    : A(_A),
      u(_u_initial),
      ud(_ud_initial),
      dirichletNodes(_dirichletNodes),
      dirichletFunction(_dirichletFunction) {}

// Advance to the next time step: the current u and ud become the "old"
// values that setup() and postProcess() read.
template <class VectorType, class MatrixType, class FunctionType, int dim>
void ImplicitEuler<VectorType, MatrixType, FunctionType, dim>::nextTimeStep() {
  // The two copies are independent of each other.
  this->u_old = this->u;
  this->ud_old = this->ud;
}

// Assemble the linear problem for one implicit Euler step.
//
// On return:
//   problem_rhs     = ell - A * u_old
//   problem_A       = tau * A
//   problem_iterate = ud_old, with the Dirichlet components overwritten
// postProcess() later takes the solution of this problem as the new ud.
//
// Parameters:
//   ell              vector the right-hand side is started from
//   _tau             time step size; stored in tau for postProcess()
//   time             time handed to dirichletFunction.evaluate()
//   problem_rhs      output, overwritten
//   problem_iterate  output, overwritten
//   problem_A        output, overwritten
//
// Also resets postProcessCalled, so the extract*() methods throw until
// postProcess() has been called again.
//
// The Dirichlet handling is an if/else chain rather than a switch with the
// labels `case dim:` and `case 1:`; those are duplicate labels, and therefore
// ill-formed, as soon as the template is instantiated with dim == 1.
template <class VectorType, class MatrixType, class FunctionType, int dim>
void ImplicitEuler<VectorType, MatrixType, FunctionType, dim>::setup(
    VectorType const &ell, double _tau, double time, VectorType &problem_rhs,
    VectorType &problem_iterate, MatrixType &problem_A) {
  postProcessCalled = false;

  tau = _tau;

  problem_rhs = ell;
  A.mmv(u_old, problem_rhs); // problem_rhs -= A * u_old

  // For fixed tau, we'd only really have to do this once
  problem_A = A;
  problem_A *= tau;

  // ud_old makes a good initial iterate; we could use anything, though
  problem_iterate = ud_old;

  for (size_t i = 0; i < dirichletNodes.size(); ++i) {
    size_t const constrained = dirichletNodes[i].count();
    if (constrained == 0)
      continue; // nothing prescribed at this node

    if (constrained == static_cast<size_t>(dim)) {
      // Every component is flagged: zero the whole block, then let the
      // Dirichlet function fill in component 0.
      problem_iterate[i] = 0;
      dirichletFunction.evaluate(time, problem_iterate[i][0]);
    } else if (constrained == 1 && dirichletNodes[i][0]) {
      // Only component 0 is flagged: it gets the Dirichlet value.
      dirichletFunction.evaluate(time, problem_iterate[i][0]);
    } else if (constrained == 1 && dirichletNodes[i][1]) {
      // Only component 1 is flagged: it is fixed to zero.
      problem_iterate[i][1] = 0;
    } else {
      // Any other flag pattern is unsupported. NOTE: with NDEBUG defined the
      // assert vanishes and such a node is silently left at its ud_old value.
      assert(false);
    }
  }
}

// Take the solution of the problem assembled by setup() as the new ud and
// update u accordingly:
//   ud = problem_iterate
//   u  = u_old + tau * ud
// Afterwards extractDisplacement() and extractVelocity() may be called.
template <class VectorType, class MatrixType, class FunctionType, int dim>
void ImplicitEuler<VectorType, MatrixType, FunctionType, dim>::postProcess(
    VectorType const &problem_iterate) {
  this->postProcessCalled = true;

  this->ud = problem_iterate;

  // Start from the previous u, then add the increment tau * ud.
  this->u = this->u_old;
  Arithmetic::addProduct(this->u, this->tau, this->ud);
}

// Copy the current u into `displacement`.
// Throws Dune::Exception while postProcessCalled is false, i.e. after a
// setup() that has not yet been followed by postProcess().
template <class VectorType, class MatrixType, class FunctionType, int dim>
void ImplicitEuler<VectorType, MatrixType, FunctionType,
                   dim>::extractDisplacement(VectorType &displacement) const {
  if (!this->postProcessCalled) {
    DUNE_THROW(Dune::Exception, "It seems you forgot to call postProcess!");
  }

  displacement = this->u;
}

// Copy the current ud into `velocity`.
// Throws Dune::Exception while postProcessCalled is false, i.e. after a
// setup() that has not yet been followed by postProcess().
template <class VectorType, class MatrixType, class FunctionType, int dim>
void ImplicitEuler<VectorType, MatrixType, FunctionType, dim>::extractVelocity(
    VectorType &velocity) const {
  if (!this->postProcessCalled) {
    DUNE_THROW(Dune::Exception, "It seems you forgot to call postProcess!");
  }

  velocity = this->ud;
}

// Return a heap-allocated copy of this object, made with the copy
// constructor. The result is a raw pointer obtained from `new`; presumably
// the caller is responsible for deleting it -- confirm against the callers.
template <class VectorType, class MatrixType, class FunctionType, int dim>
TimeSteppingScheme<VectorType, MatrixType, FunctionType, dim> *
ImplicitEuler<VectorType, MatrixType, FunctionType, dim>::clone() {
  // Inside the member, the injected class name denotes this specialisation.
  return new ImplicitEuler(*this);
}
// Construct an implicit two-step time stepper.
//
// Each argument initialises the member of the same name (minus the leading
// underscore):
//   _A                  matrix used by setup() to build the system matrix
//                       and the right-hand side
//   _u_initial          initial value of u
//   _ud_initial         initial value of ud
//   _dirichletNodes     per-node bit flags consulted in setup()
//   _dirichletFunction  function evaluated at the current time in setup()
//
// NOTE(review): `state` and `postProcessCalled` are not initialised here;
// presumably they have default member initialisers in timestepping.hh --
// confirm. The same header decides whether A, dirichletNodes and
// dirichletFunction are copied or held by reference.
template <class VectorType, class MatrixType, class FunctionType, int dim>
ImplicitTwoStep<VectorType, MatrixType, FunctionType, dim>::ImplicitTwoStep(
    MatrixType const &_A, VectorType const &_u_initial,
    VectorType const &_ud_initial,
    Dune::BitSetVector<dim> const &_dirichletNodes,
    FunctionType const &_dirichletFunction)
    : A(_A),
      u(_u_initial),
      ud(_ud_initial),
      dirichletNodes(_dirichletNodes),
      dirichletFunction(_dirichletFunction) {}

// Advance to the next time step by shifting the history: u_old moves to
// u_old_old, and the current u and ud become u_old and ud_old.
template <class VectorType, class MatrixType, class FunctionType, int dim>
void
ImplicitTwoStep<VectorType, MatrixType, FunctionType, dim>::nextTimeStep() {
  this->ud_old = this->ud;

  // Order matters: u_old has to be saved before it is overwritten.
  this->u_old_old = this->u_old;
  this->u_old = this->u;
}

// Assemble the linear problem for one step of the two-step scheme.
//
// What is assembled depends on `state`:
//   NO_SETUP (first call): state becomes FIRST_SETUP and
//     problem_rhs = ell - A * u_old
//     problem_A   = tau * A
//   FIRST_SETUP / SECOND_SETUP (later calls): state becomes SECOND_SETUP and
//     problem_rhs = ell - 4/3 * A * u_old + 1/3 * A * u_old_old
//     problem_A   = 2/3 * tau * A
// In both cases problem_iterate = ud_old, with the Dirichlet components
// overwritten. postProcess() later takes the solution as the new ud.
//
// Parameters:
//   ell              vector the right-hand side is started from
//   _tau             time step size; stored in tau for postProcess()
//   time             time handed to dirichletFunction.evaluate()
//   problem_rhs      output, overwritten
//   problem_iterate  output, overwritten
//   problem_A        output, overwritten
//
// Also resets postProcessCalled, so the extract*() methods throw until
// postProcess() has been called again.
//
// The Dirichlet handling is an if/else chain rather than a switch with the
// labels `case dim:` and `case 1:`; those are duplicate labels, and therefore
// ill-formed, as soon as the template is instantiated with dim == 1.
template <class VectorType, class MatrixType, class FunctionType, int dim>
void ImplicitTwoStep<VectorType, MatrixType, FunctionType, dim>::setup(
    VectorType const &ell, double _tau, double time, VectorType &problem_rhs,
    VectorType &problem_iterate, MatrixType &problem_A) {
  postProcessCalled = false;

  tau = _tau;

  switch (state) {
    // Perform an implicit Euler step since we lack information
    case NO_SETUP:
      state = FIRST_SETUP;

      problem_rhs = ell;
      A.mmv(u_old, problem_rhs); // problem_rhs -= A * u_old

      problem_A = A;
      problem_A *= tau;
      break;
    case FIRST_SETUP:
      state = SECOND_SETUP;
    // FALLTHROUGH
    case SECOND_SETUP:
      problem_rhs = ell;
      // usmv(alpha, x, y): y += alpha * A * x
      A.usmv(-4.0 / 3.0, u_old, problem_rhs);
      A.usmv(+1.0 / 3.0, u_old_old, problem_rhs);

      // For fixed tau, we'd only really have to do this once
      problem_A = A;
      problem_A *= 2.0 / 3.0 * tau;
      break;
    default:
      assert(false);
  }

  // ud_old makes a good initial iterate; we could use anything, though
  problem_iterate = ud_old;

  for (size_t i = 0; i < dirichletNodes.size(); ++i) {
    size_t const constrained = dirichletNodes[i].count();
    if (constrained == 0)
      continue; // nothing prescribed at this node

    if (constrained == static_cast<size_t>(dim)) {
      // Every component is flagged: zero the whole block, then let the
      // Dirichlet function fill in component 0.
      problem_iterate[i] = 0;
      dirichletFunction.evaluate(time, problem_iterate[i][0]);
    } else if (constrained == 1 && dirichletNodes[i][0]) {
      // Only component 0 is flagged: it gets the Dirichlet value.
      dirichletFunction.evaluate(time, problem_iterate[i][0]);
    } else if (constrained == 1 && dirichletNodes[i][1]) {
      // Only component 1 is flagged: it is fixed to zero.
      problem_iterate[i][1] = 0;
    } else {
      // Any other flag pattern is unsupported. NOTE: with NDEBUG defined the
      // assert vanishes and such a node is silently left at its ud_old value.
      assert(false);
    }
  }
}

// Take the solution of the problem assembled by setup() as the new ud and
// update u with the formula matching the step setup() prepared:
//   state == FIRST_SETUP:   u = u_old + tau * ud
//   state == SECOND_SETUP:  u = 2/3 * (tau * ud + 2 * u_old - 1/2 * u_old_old)
// Any other state (e.g. setup() was never called) trips an assert.
// Afterwards extractDisplacement() and extractVelocity() may be called.
template <class VectorType, class MatrixType, class FunctionType, int dim>
void ImplicitTwoStep<VectorType, MatrixType, FunctionType, dim>::postProcess(
    VectorType const &problem_iterate) {
  postProcessCalled = true;

  ud = problem_iterate;

  if (state == FIRST_SETUP) {
    // The very first step was set up as a one-step update.
    u = u_old;
    Arithmetic::addProduct(u, tau, ud);
  } else if (state == SECOND_SETUP) {
    // Accumulate the bracketed sum first, then scale it once at the end.
    u = 0.0;
    Arithmetic::addProduct(u, tau, ud);
    Arithmetic::addProduct(u, 2.0, u_old);
    Arithmetic::addProduct(u, -.5, u_old_old);
    u *= 2.0 / 3.0;
  } else {
    assert(false);
  }
}

// Copy the current u into `displacement`.
// Throws Dune::Exception while postProcessCalled is false, i.e. after a
// setup() that has not yet been followed by postProcess().
template <class VectorType, class MatrixType, class FunctionType, int dim>
void ImplicitTwoStep<VectorType, MatrixType, FunctionType,
                     dim>::extractDisplacement(VectorType &displacement) const {
  if (!this->postProcessCalled) {
    DUNE_THROW(Dune::Exception, "It seems you forgot to call postProcess!");
  }

  displacement = this->u;
}

// Copy the current ud into `velocity`.
// Throws Dune::Exception while postProcessCalled is false, i.e. after a
// setup() that has not yet been followed by postProcess().
template <class VectorType, class MatrixType, class FunctionType, int dim>
void ImplicitTwoStep<VectorType, MatrixType, FunctionType,
                     dim>::extractVelocity(VectorType &velocity) const {
  if (!this->postProcessCalled) {
    DUNE_THROW(Dune::Exception, "It seems you forgot to call postProcess!");
  }

  velocity = this->ud;
}

// Return a heap-allocated copy of this object, made with the copy
// constructor. The result is a raw pointer obtained from `new`; presumably
// the caller is responsible for deleting it -- confirm against the callers.
template <class VectorType, class MatrixType, class FunctionType, int dim>
TimeSteppingScheme<VectorType, MatrixType, FunctionType, dim> *
ImplicitTwoStep<VectorType, MatrixType, FunctionType, dim>::clone() {
  // Inside the member, the injected class name denotes this specialisation.
  return new ImplicitTwoStep(*this);
}
// Construct a Newmark time stepper.
//
// Each argument initialises the member of the same name (minus the leading
// underscore):
//   _A                  matrix applied to u_old and ud_old in setup()
//   _B                  matrix applied to ud_old and udd_old in setup()
//   _u_initial          initial value of u
//   _ud_initial         initial value of ud
//   _udd_initial        initial value of udd
//   _dirichletNodes     per-node bit flags consulted in setup()
//   _dirichletFunction  function evaluated at the current time in setup()
//
// NOTE(review): whether A, B, dirichletNodes and dirichletFunction are copied
// or held by reference depends on the member declarations in timestepping.hh,
// which are not visible here. If they are references, the arguments have to
// outlive this object -- confirm against the header.
template <class VectorType, class MatrixType, class FunctionType, int dim>
Newmark<VectorType, MatrixType, FunctionType, dim>::Newmark(
    MatrixType const &_A, MatrixType const &_B, VectorType const &_u_initial,
    VectorType const &_ud_initial, VectorType const &_udd_initial,
    Dune::BitSetVector<dim> const &_dirichletNodes,
    FunctionType const &_dirichletFunction)
    : A(_A),
      B(_B),
      u(_u_initial),
      ud(_ud_initial),
      udd(_udd_initial),
      dirichletNodes(_dirichletNodes),
      dirichletFunction(_dirichletFunction) {}

// Advance to the next time step: the current u, ud and udd become the "old"
// values that setup() and postProcess() read.
template <class VectorType, class MatrixType, class FunctionType, int dim>
void Newmark<VectorType, MatrixType, FunctionType, dim>::nextTimeStep() {
  // The three copies are independent of each other.
  this->u_old = this->u;
  this->ud_old = this->ud;
  this->udd_old = this->udd;
}

// Assemble the linear problem for one Newmark step.
//
// On return:
//   problem_rhs     = ell + 2/tau * B * ud_old + B * udd_old
//                         - A * u_old - tau/2 * A * ud_old
//   problem_A       = tau/2 * A + 2/tau * B
//   problem_iterate = ud_old, with the Dirichlet components overwritten
// postProcess() later takes the solution of this problem as the new ud.
//
// Parameters:
//   ell              vector the right-hand side is started from
//   _tau             time step size; stored in tau for postProcess(). It is
//                    used as a divisor here, so it must be non-zero.
//   time             time handed to dirichletFunction.evaluate()
//   problem_rhs      output, overwritten
//   problem_iterate  output, overwritten
//   problem_A        output, overwritten
//
// Also resets postProcessCalled, so the extract*() methods throw until
// postProcess() has been called again.
//
// The Dirichlet handling is an if/else chain rather than a switch with the
// labels `case dim:` and `case 1:`; those are duplicate labels, and therefore
// ill-formed, as soon as the template is instantiated with dim == 1.
template <class VectorType, class MatrixType, class FunctionType, int dim>
void Newmark<VectorType, MatrixType, FunctionType, dim>::setup(
    VectorType const &ell, double _tau, double time, VectorType &problem_rhs,
    VectorType &problem_iterate, MatrixType &problem_A) {
  postProcessCalled = false;

  tau = _tau;

  // usmv(alpha, x, y): y += alpha * M * x
  problem_rhs = ell;
  B.usmv(2.0 / tau, ud_old, problem_rhs);
  B.usmv(1.0, udd_old, problem_rhs);
  A.usmv(-1.0, u_old, problem_rhs);
  A.usmv(-tau / 2.0, ud_old, problem_rhs);

  // For fixed tau, we'd only really have to do this once
  problem_A = A;
  problem_A *= tau / 2.0;
  Arithmetic::addProduct(problem_A, 2.0 / tau, B);

  // ud_old makes a good initial iterate; we could use anything, though
  problem_iterate = ud_old;

  for (size_t i = 0; i < dirichletNodes.size(); ++i) {
    size_t const constrained = dirichletNodes[i].count();
    if (constrained == 0)
      continue; // nothing prescribed at this node

    if (constrained == static_cast<size_t>(dim)) {
      // Every component is flagged: zero the whole block, then let the
      // Dirichlet function fill in component 0.
      problem_iterate[i] = 0;
      dirichletFunction.evaluate(time, problem_iterate[i][0]);
    } else if (constrained == 1 && dirichletNodes[i][0]) {
      // Only component 0 is flagged: it gets the Dirichlet value.
      dirichletFunction.evaluate(time, problem_iterate[i][0]);
    } else if (constrained == 1 && dirichletNodes[i][1]) {
      // Only component 1 is flagged: it is fixed to zero.
      problem_iterate[i][1] = 0;
    } else {
      // Any other flag pattern is unsupported. NOTE: with NDEBUG defined the
      // assert vanishes and such a node is silently left at its ud_old value.
      assert(false);
    }
  }
}

// Take the solution of the problem assembled by setup() as the new ud and
// update u and udd accordingly:
//   ud  = problem_iterate
//   udd = 2/tau * ud - 2/tau * ud_old - udd_old
//   u   = u_old + tau/2 * ud + tau/2 * ud_old
// Afterwards extractDisplacement() and extractVelocity() may be called.
template <class VectorType, class MatrixType, class FunctionType, int dim>
void Newmark<VectorType, MatrixType, FunctionType, dim>::postProcess(
    VectorType const &problem_iterate) {
  postProcessCalled = true;

  ud = problem_iterate;

  // udd and u are computed from disjoint outputs, so the two groups below do
  // not depend on each other; within each group the summation order is fixed.
  udd = 0;
  Arithmetic::addProduct(udd, 2.0 / tau, ud);
  Arithmetic::addProduct(udd, -2.0 / tau, ud_old);
  Arithmetic::addProduct(udd, -1.0, udd_old);

  u = u_old;
  Arithmetic::addProduct(u, tau / 2.0, ud);
  Arithmetic::addProduct(u, tau / 2.0, ud_old);
}

// Copy the current u into `displacement`.
// Throws Dune::Exception while postProcessCalled is false, i.e. after a
// setup() that has not yet been followed by postProcess().
template <class VectorType, class MatrixType, class FunctionType, int dim>
void Newmark<VectorType, MatrixType, FunctionType, dim>::extractDisplacement(
    VectorType &displacement) const {
  if (!this->postProcessCalled) {
    DUNE_THROW(Dune::Exception, "It seems you forgot to call postProcess!");
  }

  displacement = this->u;
}

// Copy the current ud into `velocity`.
// Throws Dune::Exception while postProcessCalled is false, i.e. after a
// setup() that has not yet been followed by postProcess().
template <class VectorType, class MatrixType, class FunctionType, int dim>
void Newmark<VectorType, MatrixType, FunctionType, dim>::extractVelocity(
    VectorType &velocity) const {
  if (!this->postProcessCalled) {
    DUNE_THROW(Dune::Exception, "It seems you forgot to call postProcess!");
  }

  velocity = this->ud;
}

// Return a heap-allocated copy of this object, made with the copy
// constructor. The result is a raw pointer obtained from `new`; presumably
// the caller is responsible for deleting it -- confirm against the callers.
template <class VectorType, class MatrixType, class FunctionType, int dim>
TimeSteppingScheme<VectorType, MatrixType, FunctionType, dim> *
Newmark<VectorType, MatrixType, FunctionType, dim>::clone() {
  // Inside the member, the injected class name denotes this specialisation.
  return new Newmark(*this);
}

#include "timestepping_tmpl.cc"