#ifndef DUNE_TECTONIC_GLOBALRATESTATEFRICTION_HH
#define DUNE_TECTONIC_GLOBALRATESTATEFRICTION_HH

#include <cstddef>
#include <memory>
#include <vector>

#include <dune/common/bitsetvector.hh>
#include <dune/common/fmatrix.hh>
#include <dune/common/fvector.hh>
#include <dune/grid/common/mcmgmapper.hh>
#include <dune/istl/bcrsmatrix.hh>
#include <dune/istl/bvector.hh>

#include <dune/tectonic/globalfrictiondata.hh>
#include <dune/tectonic/globalfriction.hh>

/**
 * \brief Global friction nonlinearity assembled from one restriction per
 *        vertex of a grid view.
 *
 * Vertices contained in the frictional boundary patch each receive their own
 * Friction object wrapping a freshly constructed ScalarFriction. All other
 * vertices share a single Friction object wrapping a ZeroFunction.
 *
 * \tparam Matrix          forwarded to the GlobalFriction base class
 * \tparam Vector          forwarded to the GlobalFriction base class
 * \tparam ScalarFriction  per-vertex friction law; constructed from
 *                         (weight, weighted normal stress, friction data
 *                         evaluated at the vertex position)
 * \tparam GridView        grid view whose vertices are traversed
 */
template <class Matrix, class Vector, class ScalarFriction, class GridView>
class GlobalRateStateFriction : public GlobalFriction<Matrix, Vector> {
public:
  using GlobalFriction<Matrix, Vector>::block_size;
  using typename GlobalFriction<Matrix, Vector>::Friction;

private:
  using typename GlobalFriction<Matrix, Vector>::ScalarVector;

public:
  /**
   * \brief Build the per-vertex restrictions.
   *
   * \param frictionalBoundary   patch that supplies the grid view and decides,
   *                             via containsVertex(), which vertices carry a
   *                             ScalarFriction
   * \param frictionInfo         evaluated at the position of each frictional
   *                             vertex; the result is handed to ScalarFriction
   * \param weights              indexed by vertex index; read only for
   *                             frictional vertices
   * \param weightedNormalStress indexed by vertex index; read only for
   *                             frictional vertices. Its size also fixes the
   *                             number of restrictions stored.
   *
   * \note Every vertex index produced by the mapper is used to index the
   *       restriction storage, so weightedNormalStress must have an entry for
   *       each vertex of the grid view. This is not checked here.
   */
  GlobalRateStateFriction(BoundaryPatch<GridView> const &frictionalBoundary,
                          GlobalFrictionData<block_size> const &frictionInfo,
                          ScalarVector const &weights,
                          ScalarVector const &weightedNormalStress)
      : restrictions(weightedNormalStress.size()) {
    // One shared instance serves every vertex off the frictional boundary.
    auto const zeroNonlinearity =
        std::make_shared<Friction>(std::make_shared<ZeroFunction>());
    auto const &gridView = frictionalBoundary.gridView();

    // Vertices are the entities of codimension `dimension`. Use the grid's
    // own dimension here rather than the vector block size, which merely
    // happens to coincide with it in the usual setup.
    constexpr int vertexCodim = GridView::dimension;

    Dune::MultipleCodimMultipleGeomTypeMapper<
        GridView, Dune::MCMGVertexLayout> const vertexMapper(gridView);
    for (auto it = gridView.template begin<vertexCodim>();
         it != gridView.template end<vertexCodim>(); ++it) {
      auto const i = vertexMapper.index(*it);
      if (not frictionalBoundary.containsVertex(i)) {
        restrictions[i] = zeroNonlinearity;
        continue;
      }
      // The vertex position is only needed for frictional vertices, so it is
      // evaluated after the guard above.
      auto const coordinate = it->geometry().corner(0);
      auto const fp = std::make_shared<ScalarFriction>(
          weights[i], weightedNormalStress[i], frictionInfo(coordinate));
      restrictions[i] = std::make_shared<Friction>(fp);
    }
  }

  /**
   * \brief Forward alpha[i] to the i'th restriction, for every stored
   *        restriction.
   *
   * The shared zero restriction is reached once per non-frictional vertex.
   */
  void updateAlpha(ScalarVector const &alpha) override {
    for (std::size_t i = 0; i < restrictions.size(); ++i)
      restrictions[i]->updateAlpha(alpha[i]);
  }

  /**
   * \brief Return the restriction of the outer function to the i'th node.
   *
   * No bounds check is performed on \p i.
   */
  std::shared_ptr<Friction> restriction(std::size_t i) const override {
    return restrictions[i];
  }

private:
  // One entry per vertex index; non-frictional vertices alias the same object.
  std::vector<std::shared_ptr<Friction>> restrictions;
};
#endif