#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include <iostream>
#include <fstream>
#include <vector>
#include <exception>

#include <dune/common/exceptions.hh>
#include <dune/common/parallel/mpihelper.hh>
#include <dune/common/stdstreams.hh>
#include <dune/common/fvector.hh>
#include <dune/common/function.hh>
#include <dune/common/bitsetvector.hh>

//#include <dune/fufem/assemblers/assembler.hh>
#include <dune/fufem/functiontools/basisinterpolator.hh>
#include <dune/fufem/boundarypatch.hh>
#include <dune/fufem/functionspacebases/p1nodalbasis.hh>
#include <dune/fufem/assemblers/boundaryfunctionalassembler.hh>
#include <dune/fufem/assemblers/localassemblers/l2functionalassembler.hh>
#include <dune/fufem/functions/constantfunction.hh>

#include <dune/fufem/functions/basisgridfunction.hh>

#include <dune/grid/uggrid.hh>

#include <dune/grid-glue/adapter/gridgluevtkwriter.hh>
#include <dune/grid-glue/merging/merger.hh>

#include <dune/contact/projections/normalprojection.hh>
#include <dune/contact/common/couplingpair.hh>
#include <dune/contact/assemblers/dualmortarcoupling.hh>
#include <dune/contact/assemblers/nbodyassembler.hh>
#include <dune/contact/common/dualbasisadapter.hh>

#include "../utils/debugutils.hh"

//#include <dune/tectonic/transformedglobalratestatefriction.hh>

#include "common.hh"

// Spatial dimension of both bodies.
constexpr int dim = 2;
// Resolution argument passed to buildGrid() for each body.
constexpr int n = 5;
// Simplex-grid flag; not referenced anywhere in this file.
constexpr bool simplexGrid = true;

// Output directory prefix (empty = current working directory) and the log
// file that std::cout is redirected to by main().
const std::string path{};
const std::string outputFile{"gridgluefrictiontest.log"};

#if HAVE_UG
    using GridType = typename Dune::UGGrid<dim>;
#else
#error No UG!
#endif
    using LevelView = typename GridType::LevelGridView;
    using LeafView = typename GridType::LeafGridView;
    using LevelBoundaryPatch = BoundaryPatch<LevelView>;
    using LeafBoundaryPatch = BoundaryPatch<LeafView>;

    using c_type = double;
    using GlobalCoords = Dune::FieldVector<c_type, dim>;
    using Vector = Dune::FieldVector<c_type, 1>;
    using BlockVector = Dune::BlockVector<Vector>;

    using CouplingPair = Dune::Contact::CouplingPair<GridType, GridType, c_type>;
    using CouplingType = Dune::Contact::DualMortarCoupling<c_type, GridType>;
    using NBodyAssembler = Dune::Contact::NBodyAssembler<GridType, BlockVector>;



// Scalar test function (interpolated onto body 0 in main): evaluates to 1.0
// everywhere. The test for points on the line y = 1.0 is kept so that a
// perturbation there (originally "+ std::sin(x[0])", currently disabled)
// can be switched back on.
template <class VectorType>
class BoundaryFunction0 : public Dune::VirtualFunction<VectorType, c_type> {
public:
    void evaluate(const VectorType& x, c_type& res) const override {
        const bool onLineYEqualsOne = isClose(x[1], 1.0);
        res = onLineYEqualsOne ? 1.0 /* + std::sin(x[0]) */ : 1.0;
    }
};

// Scalar test function (interpolated onto body 1 in main): evaluates to 2.0
// everywhere. The test for points on the line y = 1.0 is kept so that a
// perturbation there (originally "+ std::cos(x[0])", currently disabled)
// can be switched back on.
template <class VectorType>
class BoundaryFunction1 : public Dune::VirtualFunction<VectorType, c_type> {
public:
    void evaluate(const VectorType& x, c_type& res) const override {
        const bool onLineYEqualsOne = isClose(x[1], 1.0);
        res = onLineYEqualsOne ? 2.0 /* + std::cos(x[0]) */ : 2.0;
    }
};
/*
template <class GridView>
void dualWeights(const BoundaryPatch<GridView>& boundaryPatch, BlockVector& weights, bool initializeVector = true) {

    typedef typename BoundaryPatch<GridView>::iterator BoundaryIterator;

    /*typedef typename LocalBoundaryFunctionalAssemblerType::LocalVector LocalVector;
    typedef typename TrialBasis::LinearCombination LinearCombination;

    int rows = tBasis_.size();


        const auto inside = it->inside();

        // get shape functions
        DualP1LocalFiniteElement<class D, class R, dim>
  class
        const typename TrialBasis::LocalFiniteElement& lFE = tBasis_.getLocalFiniteElement(inside);

        LocalVector localB(tFE.localBasis().size());

    }*/
/*
    // cache for the dual functions on the boundary
    using DualCache = Dune::Contact::DualBasisAdapter<GridView, field_type>;
    std::unique_ptr<DualCache> dualCache;
    dualCache = std::make_unique< Dune::Contact::DualBasisAdapterGlobal<GridView, field_type> >();

    if (initializeVector) {
        weights.resize(rows);
        std::fill(weights.begin(), weights.end(), 0.0);
    }

    // loop over all boundary intersections
    BoundaryIterator it = boundaryPatch.begin();
    BoundaryIterator end = boundaryPatch.end();

    for (; it != end; ++it) {
        const auto& inside = it->inside();

        if (!boundaryPatch.contains(inside, it->indexInInside()))
            continue;

        // types of the elements supporting the boundary segments in question
        Dune::GeometryType nonmortarEType = inside.type();

        const auto& domainRefElement = Dune::ReferenceElements<ctype, dim>::general(nonmortarEType);

        int noOfMortarVec = targetRefElement.size(dim);

        Dune::GeometryType nmFaceType = domainRefElement.type(rIs.indexInInside(),1);


        // Select a quadrature rule
        // 2 in 2d and for integration over triangles in 3d.  If one (or both) of the two faces involved
        // are quadrilaterals, then the quad order has to be risen to 3 (4).
        int quadOrder = 2 + (!nmFaceType.isSimplex());
        const auto& quadRule = Dune::QuadratureRules<ctype, dim-1>::rule(it->type(), quadOrder);

        dualCache->bind(inside, it->indexInInside());



        const auto& rGeomInInside = it->geometryInInside();

        int nNonmortarFaceNodes = domainRefElement.size(it->indexInInside(),1,dim);
        std::vector<int> nonmortarFaceNodes;
        for (int i=0; i<nNonmortarFaceNodes; i++) {
          int faceIdxi = domainRefElement.subEntity(it->indexInInside(), 1, i, dim);
          nonmortarFaceNodes.push_back(faceIdxi);
        }

        // store values of shape functions
        std::vector<Dune::FieldVector<field_type,1>> values(tFE.localBasis().size());

        // get geometry and store it
        const auto& rGeom = it->geometry();

        localVector = 0.0;

        // get quadrature rule
        QuadratureRuleKey tFEquad(tFE);
        QuadratureRuleKey quadKey = tFEquad.product(functionQuadKey_);
        const auto& quad = QuadratureRuleCache<double, dim>::rule(quadKey);

        // loop over quadrature points
        for (const auto& pt : quadRule) {

            // get quadrature point
            const Dune::FieldVector<ctype,dim>& quadPos = quad[pt].position();

            // get integration factor
            const ctype integrationElement = rGeom.integrationElement(quadPos);

            // evaluate basis functions
            dualCache->evaluateFunction(quadPos, values);

            // compute values of function
            T f_pos;
            const GridFunction* gf = dynamic_cast<const GridFunction*>(&f_);
            if (gf and gf->isDefinedOn(element))
                gf->evaluateLocal(element, quadPos, f_pos);
            else
                f_.evaluate(geometry.global(quadPos), f_pos);

            // and vector entries
            for(size_t i=0; i<values.size(); ++i)
            {
                localVector[i].axpy(values[i]*quad[pt].weight()*integrationElement, f_pos);
            }
        }

        for (size_t i=0; i<tFE.localBasis().size(); ++i) {
            int idx = tBasis_.index(inside, i);
            b[idx] += localB[i];
        }
}*/

// Grid-glue friction test driver.
//
// Builds two stacked 2D grids ([0,2]x[0,1] and [0,2]x[1,2]) that touch along
// y = 1, couples them with dune-contact's NBodyAssembler, interpolates the
// scalar functions BoundaryFunction0/1 onto P1 bases, and writes the
// transformed/postprocessed vectors, mortar matrices, glue intersections and
// some element/neighbor geometry to the log file `path + outputFile`.
//
// Returns 0 on normal completion and 1 if the log file cannot be opened or if
// a Dune::Exception or std::exception escapes.
int main(int argc, char *argv[]) { try {
    Dune::MPIHelper::instance(argc, argv);

    // All std::cout output of this test is written to the log file.
    std::ofstream out(path + outputFile);
    if (!out) {
        std::cerr << "Could not open output file " << path + outputFile << std::endl;
        return 1;
    }

    // Redirects std::cout into `out` for its lifetime and restores the
    // original buffer in its destructor. Declared after `out`, so it is
    // destroyed first -- also during stack unwinding if an exception escapes --
    // and std::cout never keeps pointing at the destroyed file buffer.
    struct CoutRedirect {
        std::streambuf* saved;
        explicit CoutRedirect(std::streambuf* target) : saved(std::cout.rdbuf(target)) {}
        ~CoutRedirect() { std::cout.rdbuf(saved); }
        CoutRedirect(const CoutRedirect&) = delete;
        CoutRedirect& operator=(const CoutRedirect&) = delete;
    };
    const CoutRedirect coutRedirect(out.rdbuf());

    std::cout << "------------------------------" << std::endl;
    std::cout << "-- Grid-Glue Friction Test: --" << std::endl;
    std::cout << "------------------------------" << std::endl << std::endl;

    // building grids
    std::vector<std::shared_ptr<GridType>> grids(2);

    GlobalCoords lowerLeft0({0, 0});
    GlobalCoords upperRight0({2, 1});
    buildGrid(lowerLeft0, upperRight0, n, grids[0]);

    GlobalCoords lowerLeft1({0, 1});
    GlobalCoords upperRight1({2, 2});
    buildGrid(lowerLeft1, upperRight1, n, grids[1]);

    // writing grids
    for (size_t i=0; i<grids.size(); i++) {
        const auto& levelView = grids[i]->levelGridView(0);
        writeToVTK(levelView, path, "body_" + std::to_string(i) + "_level0");
    }

    // compute coupling boundaries: the top face of body 0 and the bottom face
    // of body 1, both lying on y = 1
    LevelView gridView0 = grids[0]->levelGridView(0);
    LevelView gridView1 = grids[1]->levelGridView(0);
    LevelBoundaryPatch upper(gridView0);
    LevelBoundaryPatch lower(gridView1);

    lower.insertFacesByProperty([&](typename LevelView::Intersection const &in) {
        return xyBetween(lowerLeft1, {upperRight1[0], lowerLeft1[1]}, in.geometry().center());
    });

    upper.insertFacesByProperty([&](typename LevelView::Intersection const &in) {
        return xyBetween({lowerLeft0[0], upperRight0[1]}, upperRight0, in.geometry().center());
    });

    // set contact coupling
    Dune::Contact::NormalProjection<LeafBoundaryPatch> contactProjection;

    CouplingPair coupling;
    coupling.set(0, 1, upper, lower, 0.1, CouplingPair::CouplingType::STICK_SLIP, contactProjection, nullptr);

   /* double coveredArea_ = 0.8;
    CouplingType contactCoupling;
    contactCoupling.setGrids(*grids[0], *grids[1]);
    contactCoupling.setupContactPatch(*coupling.patch0(),*coupling.patch1());
    contactCoupling.gridGlueBackend_ = coupling.backend();
    contactCoupling.setCoveredArea(coveredArea_);
    contactCoupling.setup();*/

    // set nBodyAssembler (uses the file-level NBodyAssembler alias)
    NBodyAssembler nBodyAssembler(grids.size(), 1);

    // NBodyAssembler::setGrids takes non-owning raw pointers; `grids` keeps ownership.
    std::vector<const GridType*> grids_ptr(grids.size());
    for (size_t i=0; i<grids_ptr.size(); i++) {
        grids_ptr[i] = grids[i].get();
    }

    nBodyAssembler.setGrids(grids_ptr);
    nBodyAssembler.setCoupling(coupling, 0);

    nBodyAssembler.assembleTransferOperator();
    nBodyAssembler.assembleObstacle();

    // define basis
    using Basis = P1NodalBasis<LevelView, double>;
    Basis basis0(grids[0]->levelGridView(0));
    Basis basis1(grids[1]->levelGridView(0));

    std::cout << "--------------" << std::endl;
    std::cout << "--- Basis0 ---" << std::endl;
    std::cout << "--------------" << std::endl;
    printBasisDofLocation(basis0);

    std::cout << "--------------" << std::endl;
    std::cout << "--- Basis1 ---" << std::endl;
    std::cout << "--------------" << std::endl;
    printBasisDofLocation(basis1);

    // set grid functions on coupling boundary
    std::vector<BlockVector> f(2);

    BasisInterpolator<Basis, BlockVector, BoundaryFunction0<GlobalCoords>> interpolator0;
    interpolator0.interpolate(basis0, f[0], BoundaryFunction0<GlobalCoords>());

    BasisInterpolator<Basis, BlockVector, BoundaryFunction1<GlobalCoords>> interpolator1;
    interpolator1.interpolate(basis1, f[1], BoundaryFunction1<GlobalCoords>());

    BlockVector transformedF;
    nBodyAssembler.nodalToTransformed(f, transformedF);

    std::vector<Dune::BitSetVector<1>> boundaryVertices(2);
    const auto& mortarCoupling = nBodyAssembler.getContactCouplings()[0];
    const auto& nonmortarBoundary = mortarCoupling->nonmortarBoundary();
    const auto& mortarBoundary = mortarCoupling->mortarBoundary();

    nonmortarBoundary.getVertices(boundaryVertices[0]);
    mortarBoundary.getVertices(boundaryVertices[1]);

    print(f[0], "f0: ");
    print(boundaryVertices[0], "nonmortarBoundary: ");
    print(f[1], "f1: ");
    print(boundaryVertices[1], "mortarBoundary: ");
    print(transformedF, "transformedF: ");
    writeToVTK(basis0, f[0], path, "body_0_level0");
    writeToVTK(basis1, f[1], path, "body_1_level0");

    print(mortarCoupling->mortarLagrangeMatrix(), "M: ");
    print(mortarCoupling->nonmortarLagrangeMatrix(), "D: ");

    std::vector<BlockVector> postprocessed(2);
    nBodyAssembler.postprocess(transformedF, postprocessed);

    print(postprocessed, "postprocessed: ");

    // grid-glue object of the (only) contact coupling
    const auto& glue = *mortarCoupling->getGlue();

    // list the glue intersections once via range iteration ...
    size_t isCount = 0;
    for (const auto& rIs : intersections(glue)) {
        std::cout << "intersection id: " << isCount
                  << " insideElement: " << gridView0.indexSet().index(rIs.inside())
                  << " outsideElement: " << gridView1.indexSet().index(rIs.outside()) << std::endl;
        isCount++;
    }

    // ... and once via index-based access, for comparison
    for (size_t i=0; i<isCount; i++) {
        const auto& is = glue.getIntersection(i);
        std::cout << "intersection id: " << i
                  << " insideElement: " << gridView0.indexSet().index(is.inside())
                  << " outsideElement: " << gridView1.indexSet().index(is.outside()) << std::endl;
    }

   /* using DualBasis = ;
    DualBasis dualBasis;
    BoundaryFunctionalAssembler<DualBasis> bAssembler(dualBasis, nonmortarBoundary);

    ConstantFunction<GlobalCoords, Vector> oneFunction(1);

    BlockVector b;
    L2FunctionalAssembler localAssembler(oneFunction);
    bAssembler.assemble(localAssembler, b);

    print(b, "b: ");*/

    /*
    std::vector<BlockVector> g(2);
    g[0].resize(f[0].size());
    g[1].resize(f[1].size());

    g[1][6] = 2;
    BlockVector transformedG;
    nBodyAssembler.nodalToTransformed(g, transformedG);

    print(g[1], "g1: ");
    print(transformedG, "transformedG: ");

    // merged gridGlue coupling boundary
    auto& gridGlue = *contactCoupling.getGlue();

    // make basis grid functions
    auto&& gridFunction0 = Functions::makeFunction(basis0, f[0]); */

   /* for (const auto& bIs : intersections(upper)) {
        const auto& inside = bIs.inside();
        const auto& bGeometry = bIs.geometry();

        for (size_t i=0; i<bGeometry.corners(); i++) {


            typename BasisGridFunction<Basis, BlockVector>::RangeType y;
            gridFunction1.evaluateLocal(outside, outGeometry.local(rGeometry.corner(i)), y);
            print(rGeometry.corner(i), "corner " + std::to_string(i));
            std::cout << "f1(corner) = " << y << std::endl;
        }
    }*/

    /*
    auto&& gridFunction1 = Functions::makeFunction(basis1, f[1]);

    for (const auto& rIs : intersections(gridGlue)) {
        const auto& inside = rIs.inside();
        const auto& outside = rIs.outside();

        const auto& rGeometry = rIs.geometry();
        const auto& outGeometry = outside.geometry();

        for (size_t i=0; i<rGeometry.corners(); i++) {
            typename BasisGridFunction<Basis, BlockVector>::RangeType y;
            gridFunction1.evaluateLocal(outside, outGeometry.local(rGeometry.corner(i)), y);
            print(rGeometry.corner(i), "corner " + std::to_string(i));
            std::cout << "f1(corner) = " << y << std::endl;

            std::cout << std::endl;
        }
        std::cout << "---------"<< std::endl;
    }*/

    // For the first element of body 0 only (see the `break`): print its
    // corners, then for each neighbor print the neighbor's corners in global
    // coordinates and mapped into the seed element's local coordinates.
    for (const auto& elem : elements(gridView0)) {
        std::cout << "seed element corners:" << std::endl;

        const auto& eRefElement = Dune::ReferenceElements<double, dim>::general(elem.type());
        size_t vertices = eRefElement.size(dim);

        for (size_t j=0; j<vertices; j++) {
            print(elem.geometry().corner(j), "corner: ");
        }

        for (auto&& is : intersections(gridView0, elem)) {
            if (!is.neighbor())
                continue;

            std::cout << "neighbor corners:" << std::endl;

            const auto& outside = is.outside();

            const auto& refElement = Dune::ReferenceElements<double, dim>::general(outside.type());
            size_t nVertices = refElement.size(dim);

            for (size_t j=0; j<nVertices; j++) {
                print(outside.geometry().corner(j), "corner: ");
                const auto local = elem.geometry().local(outside.geometry().corner(j));

                print(local, "local:");
            }
        }

        break;
    }

    // NOTE(review): nothing above is actually verified; the test only writes
    // diagnostic output to the log file and therefore always "passes".
    const bool passed = true;

    std::cout << "Overall, the test " << (passed ? "was successful!" : "failed!") << std::endl;

    // coutRedirect restores std::cout's original buffer when leaving scope.
    return passed ? 0 : 1;

} catch (const Dune::Exception &e) {
    Dune::derr << "Dune reported error: " << e << std::endl;
    return 1;
} catch (const std::exception &e) {
    std::cerr << "Standard exception: " << e.what() << std::endl;
    return 1;
} // end try
} // end main