// -*- tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 4 -*-
// vi: set ts=8 sw=4 et sts=4:
#include <dune/common/exceptions.hh>

#include <dune/istl/matrixindexset.hh>
#include <dune/istl/io.hh>

#include <dune/geometry/quadraturerules.hh>
#include <dune/geometry/type.hh>
#include <dune/geometry/referenceelements.hh>

#include <dune/localfunctions/lagrange/pqkfactory.hh>


#include <dune/grid-glue/merging/contactmerge.hh>
//#include <dune/grid-glue/adapter/gridgluevtkwriter.hh>
//#include <dune/grid-glue/adapter/gridglueamirawriter.hh>

#include <dune/contact/common/dualbasisadapter.hh>
#include <dune/matrix-vector/addtodiagonal.hh>


template <class field_type, class GridView0, class GridView1>
void DualMortarCoupling<field_type, GridView0, GridView1>::assembleNonmortarLagrangeMatrix()
{
    // Assemble the block-diagonal matrix coupling the nonmortar boundary dofs
    // with the Lagrange multipliers on that boundary.  Row/column `row` is the
    // boundary-local index (from makeGlobalToLocal) of a nonmortar vertex; each
    // vertex of each boundary face gets a scalar weight added to its diagonal block.

    // Translation table: gridView0_ vertex index -> nonmortar boundary index
    std::vector<int> globalToLocal;
    nonmortarBoundary_.makeGlobalToLocal(globalToLocal);

    // Start from an all-zero matrix with one diagonal block per boundary vertex
    nonmortarLagrangeMatrix_ = Dune::BDMatrix<MatrixBlock>(nonmortarBoundary_.numVertices());
    nonmortarLagrangeMatrix_ = 0;

    const auto& indexSet = gridView0_.indexSet();

    for (const auto& face : nonmortarBoundary_) {

        const auto& faceGeometry = face.geometry();
        const field_type cornerCount = faceGeometry.corners();

        // Integration element evaluated at the origin of the face reference element
        const ctype jacobianDet = faceGeometry.integrationElement(Dune::FieldVector<ctype,dim-1>(0));

        // Per-vertex weight: intElem/6 for triangles, intElem/#corners otherwise.
        // Presumably this is the integral of a nodal basis function over an
        // affine face (area/3 for triangles, area/#corners for segments and
        // parallelograms) -- NOTE(review): confirm for non-affine quadrilaterals.
        field_type weight;
        if (cornerCount == 3)
            weight = jacobianDet/6.0;
        else
            weight = jacobianDet/cornerCount;

        const auto& element = face.inside();
        const int faceInElement = face.indexInInside();
        const auto& refElement = Dune::ReferenceElements<ctype, dim>::general(element.type());
        const int verticesOnFace = refElement.size(faceInElement, 1, dim);

        // Every face vertex lies on the nonmortar boundary, so `row` is a valid index
        for (int corner = 0; corner < verticesOnFace; ++corner) {
            const int vertexInElement = refElement.subEntity(faceInElement, 1, corner, dim);
            const int row = globalToLocal[indexSet.subIndex(element, vertexInElement, dim)];
            Dune::MatrixVector::addToDiagonal(nonmortarLagrangeMatrix_[row][row], weight);
        }
    }
}

template <class field_type, class GridView0, class GridView1>
void DualMortarCoupling<field_type, GridView0, GridView1>::setup()
{
    // Build the dual mortar coupling between the nonmortar boundary (on gridView0_)
    // and the mortar boundary (on gridView1_):
    //  1. merge both boundaries with grid-glue (the backend is created lazily with overlap_);
    //  2. shrink nonmortarBoundary_ to the faces whose covered-area fraction is >= coveredArea_,
    //     and rebuild mortarBoundary_ from the faces hit by the remaining intersections;
    //  3. assemble the diagonal nonmortar matrix D, the mortar matrix \hat{M} and the
    //     weak obstacle vector by quadrature over the remote intersections;
    //  4. replace \hat{M} by D^{-1}\hat{M}, scale the weak obstacle by D^{-1}, and clear the backend.
    typedef Dune::PQkLocalFiniteElementCache<typename GridType1::ctype, field_type, GridType1::dimension,1> FiniteElementCache1;

    // cache for the dual functions on the boundary
    using DualCache = Dune::Contact::DualBasisAdapter<GridView0, field_type>;

    using Element0 = typename GridView0::template Codim<0>::Entity;
    using Element1 = typename GridView1::template Codim<0>::Entity;

    // Face descriptors for the extractors: select the faces of the current boundary patches
    auto desc0 = [&] (const Element0& e, unsigned int face) {
        return nonmortarBoundary_.contains(e,face);
    };

    auto desc1 = [&] (const Element1& e, unsigned int face) {
        return mortarBoundary_.contains(e,face);
    };

    auto extract0 = std::make_shared<Extractor0>(gridView0_,desc0);
    auto extract1 = std::make_shared<Extractor1>(gridView1_,desc1);

    if (!gridGlueBackend_)
        gridGlueBackend_ = std::make_shared< Dune::GridGlue::ContactMerge<dimworld, ctype> >(overlap_);
    glue_ = std::make_shared<Glue>(extract0, extract1, gridGlueBackend_);
    auto& glue = *glue_;
    glue.build();

    std::cout << glue.size() << " remote intersections found." << std::endl;

    // Restrict the hasObstacle fields to the part of the nonmortar boundary
    // where we could actually create a contact mapping.
    // (Uses the member gridView0_, like every other access to the nonmortar grid view.)
    BoundaryPatch0 boundaryWithMapping(gridView0_);

    const auto& indexSet0 = gridView0_.indexSet();
    const auto& indexSet1 = gridView1_.indexSet();

    ///////////////////////////////////
    //  reducing nonmortar boundary
    /////////////////////////////////

    // Faces are identified by (element index, local face number)
    typedef std::pair<int,int> Pair;
    std::map<Pair,ctype> coveredArea, fullArea;

    // initialize with area of boundary faces
    for (const auto& bIt : nonmortarBoundary_) {
        const Pair p(indexSet0.index(bIt.inside()),bIt.indexInInside());
        fullArea[p] = bIt.geometry().volume();
        coveredArea[p] = 0;
    }

    // sum up the remote intersection areas to find out which faces are (almost) totally covered
    for (const auto& rIs : intersections(glue))
        coveredArea[Pair(indexSet0.index(rIs.inside()),rIs.indexInInside())] += rIs.geometry().volume();

    // keep the faces whose covered fraction reaches the threshold coveredArea_
    for (const auto& bIt : nonmortarBoundary_) {
        const auto& inside = bIt.inside();
        const Pair p(indexSet0.index(inside),bIt.indexInInside());
        if (coveredArea[p]/fullArea[p] >= coveredArea_)
            boundaryWithMapping.addFace(inside, bIt.indexInInside());
    }

    // Stream output: avoids a printf format/type mismatch for whatever numFaces() returns
    std::cout << "contact mapping could be built for " << boundaryWithMapping.numFaces()
              << " of " << nonmortarBoundary_.numFaces() << " boundary segments." << std::endl;

    // Replace the nonmortar boundary by its reduced version and rebuild the
    // mortar boundary from the faces touched by the remaining intersections
    nonmortarBoundary_ = boundaryWithMapping;
    mortarBoundary_.setup(gridView1_);
    for (const auto& rIs : intersections(glue))
        if (nonmortarBoundary_.contains(rIs.inside(),rIs.indexInInside()))
            mortarBoundary_.addFace(rIs.outside(),rIs.indexInOutside());

    // Assemble the diagonal matrix coupling the nonmortar side with the Lagrange multipliers there
    assembleNonmortarLagrangeMatrix();

    // The weak obstacle vector
    weakObstacle_.resize(nonmortarBoundary_.numVertices());
    weakObstacle_ = 0;

    // ///////////////////////////////////////////////////////////
    //   Get the occupation structure for the mortar matrix
    // ///////////////////////////////////////////////////////////

    /** \todo Also restrict the mortar column indices to the mortar boundary
     *        instead of using all vertices of gridView1_. */
    // Columns are indexed with indexSet1 (the index set of gridView1_), so the
    // column count must be the vertex count of that same grid view.
    Dune::MatrixIndexSet mortarIndices(nonmortarBoundary_.numVertices(), gridView1_.size(dim));

    // Create mapping from the global set of block dofs to the ones on the contact boundary
    std::vector<int> globalToLocal;
    nonmortarBoundary_.makeGlobalToLocal(globalToLocal);

    // loop over all intersections
    for (const auto& rIs : intersections(glue)) {

        if (!nonmortarBoundary_.contains(rIs.inside(),rIs.indexInInside()))
            continue;

        const auto& inside = rIs.inside();
        const auto& outside = rIs.outside();

        const auto& domainRefElement = Dune::ReferenceElements<ctype, dim>::general(inside.type());
        const auto& targetRefElement = Dune::ReferenceElements<ctype, dim>::general(outside.type());

        const int nDomainVertices = domainRefElement.size(dim);
        const int nTargetVertices = targetRefElement.size(dim);

        // All element vertices are visited (not only those of the face), so the
        // pattern may be a superset of what the assembly below fills.
        for (int j=0; j<nDomainVertices; j++) {

            const int localDomainIdx = globalToLocal[indexSet0.subIndex(inside,j,dim)];

            // if the vertex is not contained in the restricted contact boundary then dismiss it
            if (localDomainIdx == -1)
                continue;

            for (int k=0; k<nTargetVertices; k++) {
                const int globalTargetIdx = indexSet1.subIndex(outside,k,dim);
                if (!mortarBoundary_.containsVertex(globalTargetIdx))
                    continue;

                mortarIndices.add(localDomainIdx, globalTargetIdx);
            }
        }
    }

    mortarIndices.exportIdx(mortarLagrangeMatrix_);

    // Clear it
    mortarLagrangeMatrix_ = 0;

    //cache of local bases
    FiniteElementCache1 cache1;

    std::unique_ptr<DualCache> dualCache
        = std::make_unique< Dune::Contact::DualBasisAdapterGlobal<GridView0, field_type> >();

    std::vector<Dune::FieldVector<ctype,dim> > avNormals;
    avNormals = nonmortarBoundary_.getNormals();

    // loop over all intersections and compute the matrix entries
    for (const auto& rIs : intersections(glue)) {

        const auto& inside = rIs.inside();

        if (!nonmortarBoundary_.contains(inside,rIs.indexInInside()))
            continue;

        const auto& outside = rIs.outside();

        // types of the elements supporting the boundary segments in question
        const Dune::GeometryType nonmortarEType = inside.type();
        const Dune::GeometryType mortarEType    = outside.type();

        const auto& domainRefElement = Dune::ReferenceElements<ctype, dim>::general(nonmortarEType);
        const auto& targetRefElement = Dune::ReferenceElements<ctype, dim>::general(mortarEType);

        const int noOfMortarVec = targetRefElement.size(dim);

        const Dune::GeometryType nmFaceType = domainRefElement.type(rIs.indexInInside(),1);
        const Dune::GeometryType mFaceType  = targetRefElement.type(rIs.indexInOutside(),1);

        // Select a quadrature rule:
        // order 2 in 2d and for integration over triangles in 3d.  If one (or both) of the two
        // faces involved are quadrilaterals, the order is raised to 3 (4).
        const int quadOrder = 2 + (!nmFaceType.isSimplex()) + (!mFaceType.isSimplex());
        const auto& quadRule = Dune::QuadratureRules<ctype, dim-1>::rule(rIs.type(), quadOrder);

        const auto& mortarFiniteElement = cache1.get(mortarEType);
        dualCache->bind(inside, rIs.indexInInside());

        std::vector<Dune::FieldVector<field_type,1> > mortarQuadValues, dualQuadValues;

        const auto& rGeom = rIs.geometry();
        const auto& rGeomOutside = rIs.geometryOutside();
        const auto& rGeomInInside = rIs.geometryInInside();
        const auto& rGeomInOutside = rIs.geometryInOutside();

        // element-local indices of the vertices of the nonmortar face
        const int nNonmortarFaceNodes = domainRefElement.size(rIs.indexInInside(),1,dim);
        std::vector<int> nonmortarFaceNodes;
        nonmortarFaceNodes.reserve(nNonmortarFaceNodes);
        for (int i=0; i<nNonmortarFaceNodes; i++)
            nonmortarFaceNodes.push_back(domainRefElement.subEntity(rIs.indexInInside(), 1, i, dim));

        for (const auto& quadPt : quadRule) {

            // compute integration element of overlap
            const ctype integrationElement = rGeom.integrationElement(quadPt.position());

            // quadrature point positions on the reference element
            const Dune::FieldVector<ctype,dim> nonmortarQuadPos = rGeomInInside.global(quadPt.position());
            const Dune::FieldVector<ctype,dim> mortarQuadPos    = rGeomInOutside.global(quadPt.position());

            // The current quadrature point in world coordinates
            const Dune::FieldVector<field_type,dim> nonmortarQpWorld = rGeom.global(quadPt.position());
            const Dune::FieldVector<field_type,dim> mortarQpWorld    = rGeomOutside.global(quadPt.position());

            // the gap direction (normal * gapValue)
            const Dune::FieldVector<field_type,dim> gapVector = mortarQpWorld  - nonmortarQpWorld;

            //evaluate all shapefunctions at the quadrature point
            mortarFiniteElement.localBasis().evaluateFunction(mortarQuadPos,mortarQuadValues);
            dualCache->evaluateFunction(nonmortarQuadPos,dualQuadValues);

            // loop over all Lagrange multiplier shape functions
            for (int j=0; j<nNonmortarFaceNodes; j++) {

                const int globalDomainIdx = indexSet0.subIndex(inside,nonmortarFaceNodes[j],dim);
                const int rowIdx = globalToLocal[globalDomainIdx];

                weakObstacle_[rowIdx][0] += integrationElement * quadPt.weight()
                    * dualQuadValues[nonmortarFaceNodes[j]] * (gapVector*avNormals[globalDomainIdx]);

                // loop over all mortar shape functions
                for (int k=0; k<noOfMortarVec; k++) {

                    const int colIdx  = indexSet1.subIndex(outside, k, dim);
                    if (!mortarBoundary_.containsVertex(colIdx))
                        continue;

                    // Integrate over the product of two shape functions
                    field_type mortarEntry =  integrationElement* quadPt.weight()* dualQuadValues[nonmortarFaceNodes[j]]* mortarQuadValues[k];

                    Dune::MatrixVector::addToDiagonal(mortarLagrangeMatrix_[rowIdx][colIdx], mortarEntry);
                }
            }
        }
    }

    // ///////////////////////////////////////
    //    Compute M = D^{-1} \hat{M}
    // ///////////////////////////////////////

    Dune::BCRSMatrix<MatrixBlock>& M  = mortarLagrangeMatrix_;
    Dune::BDMatrix<MatrixBlock>& D    = nonmortarLagrangeMatrix_;

    // First compute D^{-1} (in place)
    D.invert();

    // Then the matrix product D^{-1} \hat{M}
    for (auto rowIt = M.begin(); rowIt != M.end(); ++rowIt) {
        const auto rowIndex = rowIt.index();
        for (auto& entry : *rowIt)
            entry.leftmultiply(D[rowIndex][rowIndex]);
    }

    // weakObstacles in transformed basis = D^{-1}*weakObstacle_
    // (the D blocks only have diagonal contributions, so the [0][0] entry is used)
    for(size_t rowIdx=0; rowIdx<weakObstacle_.size(); rowIdx++)
        weakObstacle_[rowIdx] *= D[rowIdx][rowIdx][0][0];

    gridGlueBackend_->clear();
}