#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include <stdexcept>

#include "state.hh"
#include "state/ageinglawstateupdater.cc"
#include "state/sliplawstateupdater.cc"

namespace {

/**
 * Creates one local state updater of type LocalUpdater per friction coupling
 * and registers it with the global state updater.
 *
 * couplings[i] and contactCouplings[i] are expected to refer to the same
 * interface (the caller asserts both vectors have equal length).
 *
 * @param stateUpdater     global updater that receives the local updaters
 * @param alpha_initial    initial state per body, indexed by body index
 * @param contactCouplings supplies the nonmortar boundary vertices for coupling i
 * @param couplings        supplies the nonmortar body index (gridIdx_[0]) and
 *                         the friction parameters L() and V0() for coupling i
 */
template <class LocalUpdater, class GlobalUpdater, class ScalarVector,
          class ContactCoupling, class FrictionCouplingPair>
void addLocalStateUpdaters(
        GlobalUpdater& stateUpdater,
        const std::vector<ScalarVector>& alpha_initial,
        const std::vector<std::shared_ptr<ContactCoupling>>& contactCouplings,
        const std::vector<std::shared_ptr<FrictionCouplingPair>>& couplings) {
  for (size_t i = 0; i < couplings.size(); i++) {
    const auto& coupling = couplings[i];
    // gridIdx_[0] identifies the nonmortar body of this coupling.
    const auto nonmortarIdx = coupling->gridIdx_[0];

    auto localUpdater = std::make_shared<LocalUpdater>(
        alpha_initial[nonmortarIdx],
        *contactCouplings[i]->nonmortarBoundary().getVertices(),
        coupling->frictionData().L(),
        coupling->frictionData().V0());
    localUpdater->setBodyIndex(nonmortarIdx);

    stateUpdater.addLocalUpdater(localUpdater);
  }
}

} // namespace

/**
 * Builds the global StateUpdater for the selected state law.
 *
 * The global updater is sized with the number of entries of each body's
 * initial state. One local updater (AgeingLawStateUpdater or
 * SlipLawStateUpdater, depending on model) is then added per friction coupling.
 *
 * @param model            state law to use (Config::AgeingLaw or Config::SlipLaw)
 * @param alpha_initial    initial state per body
 * @param contactCouplings contact couplings; element i provides the nonmortar
 *                         boundary for couplings[i]
 * @param couplings        friction coupling pairs; provide gridIdx_ and frictionData
 * @return the assembled global state updater
 * @throws std::invalid_argument if model is not a supported state law
 */
template <class ScalarVector, class Vector, class ContactCoupling, class FrictionCouplingPair>
auto initStateUpdater(
        Config::stateModel model,
        const std::vector<ScalarVector>& alpha_initial,
        const std::vector<std::shared_ptr<ContactCoupling>>& contactCouplings, // contains nonmortarBoundary
        const std::vector<std::shared_ptr<FrictionCouplingPair>>& couplings) // contains frictionInfo
-> std::shared_ptr<StateUpdater<ScalarVector, Vector>> {

  assert(contactCouplings.size() == couplings.size());

  std::vector<size_t> leafVertexCounts;
  leafVertexCounts.reserve(alpha_initial.size());
  for (const auto& alpha : alpha_initial)
    leafVertexCounts.push_back(alpha.size());

  auto stateUpdater = std::make_shared<StateUpdater<ScalarVector, Vector>>(leafVertexCounts);

  switch (model) {
    case Config::AgeingLaw:
      addLocalStateUpdaters<AgeingLawStateUpdater<ScalarVector, Vector>>(
          *stateUpdater, alpha_initial, contactCouplings, couplings);
      break;
    case Config::SlipLaw:
      addLocalStateUpdaters<SlipLawStateUpdater<ScalarVector, Vector>>(
          *stateUpdater, alpha_initial, contactCouplings, couplings);
      break;
    default:
      // Fail loudly in every build: assert(false) alone is compiled out under
      // NDEBUG and would return an updater with no local updaters.
      throw std::invalid_argument("initStateUpdater: unsupported Config::stateModel");
  }

  return stateUpdater;
}

#include "state_tmpl.cc"