#ifndef DUNE_TECTONIC_GLOBALRATESTATEFRICTION_HH
#define DUNE_TECTONIC_GLOBALRATESTATEFRICTION_HH

#include <algorithm>
#include <cassert>
#include <memory>
#include <vector>

#include <dune/common/bitsetvector.hh>
#include <dune/common/fmatrix.hh>
#include <dune/common/fvector.hh>
#include <dune/grid/common/mcmgmapper.hh>
#include <dune/istl/bcrsmatrix.hh>
#include <dune/istl/bvector.hh>

#include "../../spatial-solving/contact/dualmortarcoupling.hh"

#include "globalfrictiondata.hh"
#include "globalfriction.hh"
#include "frictioncouplingpair.hh"

#include "../../utils/geocoordinate.hh"
#include "../../utils/index-in-sorted-range.hh"

/*
  Global friction nonlinearity for a multi-body contact problem.

  Holds one local nonlinearity (WrappedScalarFriction<block_size, ScalarFriction>)
  per vertex that lies on the nonmortar boundary of at least one contact
  coupling. Vertices of all bodies are addressed by a single global index,
  obtained by concatenating the bodies' vertex numberings: vertex i of body b
  has global index offSet_[b] + i.
*/
template <class Matrix, class Vector, class ScalarFriction, class GridType>
class GlobalRateStateFriction : public GlobalFriction<Matrix, Vector> {
public:
  using GlobalFriction<Matrix, Vector>::block_size;
  using typename GlobalFriction<Matrix, Vector>::LocalNonlinearity;

private:
  using Base = GlobalFriction<Matrix, Vector>;

  using field_type = typename Vector::field_type;
  using typename Base::ScalarVector;
  using typename Base::LocalVectorType;

  using FrictionCoupling = FrictionCouplingPair<GridType, LocalVectorType, field_type>;
  using ContactCoupling =  DualMortarCoupling<field_type, GridType>;

  /*
    Return the body that owns the global dof `globalIdx`, i.e. the largest
    body index b with offSet_[b] <= globalIdx.
  */
  size_t bodyIndex(const size_t globalIdx) const {
    assert(not offSet_.empty());
    // offSet_ is a non-decreasing exclusive prefix sum starting at 0, so
    // upper_bound never returns begin(); its predecessor is the owning body.
    const auto next = std::upper_bound(offSet_.begin(), offSet_.end(), globalIdx);
    return static_cast<size_t>(next - offSet_.begin()) - 1;
  }

public:
  /*
    Build one local nonlinearity per nonmortar boundary vertex.

    contactCouplings[k] : provides the nonmortar boundary (and its grid view)
                          of coupling k.
    couplings[k]        : provides gridIdx_[0], the nonmortar body of
                          coupling k, and that coupling's friction data.
    weights[b], weightedNormalStress[b] : per-vertex values of body b.

    A vertex is associated with the first coupling (in coupling order) whose
    nonmortar boundary contains it, and it receives that coupling's friction
    data evaluated at the vertex position.
  */
  GlobalRateStateFriction(const std::vector<std::shared_ptr<ContactCoupling>>& contactCouplings, // contains nonmortarBoundary
                          const std::vector<std::shared_ptr<FrictionCoupling>>& couplings, // contains frictionInfo
                          const std::vector<ScalarVector>& weights,
                          const std::vector<ScalarVector>& weightedNormalStress)
      : restrictions_(), localToGlobal_(), zeroFriction_() {

    assert(contactCouplings.size() == couplings.size());
    assert(weights.size() == weightedNormalStress.size());

    const auto nBodies = weights.size();

    // offSet_[b] = total number of vertices of bodies 0..b-1
    offSet_.resize(nBodies, 0);
    for (size_t i=1; i<nBodies; i++) {
        offSet_[i] = offSet_[i-1] + weights[i-1].size();
    }

    // Group coupling indices by nonmortar body: nonmortarBodies[b] lists the
    // couplings whose nonmortar side is body b, in increasing order.
    std::vector<std::vector<size_t>> nonmortarBodies(nBodies);
    for (size_t i=0; i<contactCouplings.size(); i++) {
        const auto nonmortarIdx = couplings[i]->gridIdx_[0];
        assert(static_cast<size_t>(nonmortarIdx) < nBodies);
        nonmortarBodies[nonmortarIdx].emplace_back(i);
    }

    for (size_t bodyIdx=0; bodyIdx<nBodies; bodyIdx++) {
        const auto& couplingIndices = nonmortarBodies[bodyIdx];

        if (couplingIndices.empty())
            continue;

        // NOTE(review): the grid view of the first coupling is used for all
        // couplings of this body; assumes they share the same grid view.
        const auto gridView = contactCouplings[couplingIndices[0]]->nonmortarBoundary().gridView();

        Dune::MultipleCodimMultipleGeomTypeMapper<
            decltype(gridView), Dune::MCMGVertexLayout> const vertexMapper(gridView, Dune::mcmgVertexLayout());

        for (auto it = gridView.template begin<block_size>(); it != gridView.template end<block_size>(); ++it) {
            const auto i = vertexMapper.index(*it);

            for (const auto couplingIdx : couplingIndices) {
                if (not contactCouplings[couplingIdx]->nonmortarBoundary().containsVertex(i))
                  continue;

                localToGlobal_.emplace_back(offSet_[bodyIdx] + i);
                // Friction data must come from the same coupling whose
                // boundary contains the vertex (indexed by couplingIdx, not by
                // the position inside couplingIndices).
                restrictions_.emplace_back(weights[bodyIdx][i], weightedNormalStress[bodyIdx][i],
                                          couplings[couplingIdx]->frictionData()(geoToPoint(it->geometry())));
                break;
            }
        }
    }

    // restriction() searches localToGlobal_ with indexInSortedRange, which
    // needs a sorted range. This holds only if each grid view visits its
    // vertices in increasing mapper-index order; catch violations in debug builds.
    assert(std::is_sorted(localToGlobal_.begin(), localToGlobal_.end()));
  }

  /*
    Forward alpha to every local nonlinearity. alpha[b][v] is the value for
    vertex v of body b; each stored global dof is split back into
    (body, body-local vertex) via offSet_.
  */
  void updateAlpha(const std::vector<ScalarVector>& alpha) override {
    for (size_t j = 0; j < restrictions_.size(); ++j) {
      const auto globalDof = localToGlobal_[j];
      const auto bodyIdx = bodyIndex(globalDof);
      const size_t bodyDof = globalDof - offSet_[bodyIdx];

      restrictions_[j].updateAlpha(alpha[bodyIdx][bodyDof]);
    }
  }

  /*
    Return the restriction of the outer function to the global dof i.
    Dofs that are not on any nonmortar boundary get zeroFriction_.
  */
  LocalNonlinearity const &restriction(size_t i) const override {
    auto const index = indexInSortedRange(localToGlobal_, i);
    if (index == localToGlobal_.size())
      return zeroFriction_;
    return restrictions_[index];
  }

private:
  // One local nonlinearity per nonmortar boundary vertex; parallel to localToGlobal_.
  std::vector<WrappedScalarFriction<block_size, ScalarFriction>> restrictions_;
  std::vector<size_t> offSet_; // index off-set by body id
  // Global dof of each entry of restrictions_ (expected to be sorted).
  std::vector<size_t> localToGlobal_;
  WrappedScalarFriction<block_size, ZeroFunction> const zeroFriction_;
};
#endif