#include "utils/almostequal.hh"

#include "nodalweights.hh"

template <class Basis0, class Basis1>
NodalWeights<Basis0, Basis1>::NodalWeights(const Basis0& basis0, const Basis1& basis1)
    : basis0_(basis0), basis1_(basis1) {
    // Nothing else to set up: the object only keeps the two bases it is
    // given; all work happens later in assemble().
}

template <class Basis0, class Basis1>
template <class Basis, class Element, class GlobalCoords>
auto NodalWeights<Basis0, Basis1>::basisDof(const Basis& basis, const Element& elem, const GlobalCoords& vertex) const {
    // Looks up the global index of the (first) local basis function of `elem`
    // that evaluates to 1 at the global point `vertex`.
    // Returns that index as int, or -1 when no local basis function matches.

    const typename Basis::LocalFiniteElement& localFE = basis.getLocalFiniteElement(elem);
    const auto& localBasis = localFE.localBasis();

    // Express the global point in the reference coordinates of `elem`.
    const auto& geo = elem.geometry();
    const auto& xLocal = geo.local(vertex);

    // Evaluate every local shape function at that reference point.
    const size_t nShape = localBasis.size();
    std::vector<Dune::FieldVector<double, 1>> shapeValues(nShape);
    localBasis.evaluateFunction(xLocal, shapeValues);

    // NOTE(review): a 2-ulp tolerance is very tight for a value obtained via the
    // inverse geometry map; confirm almost_equal's semantics if nodes go missing.
    for (size_t k = 0; k < nShape; ++k) {
        if (almost_equal(shapeValues[k][0], 1.0, 2)) {
            return static_cast<int>(basis.index(elem, k));
        }
    }

    return -1;
}

template <class Basis0, class Basis1>
template <class GridGlue, class ScalarVector>
void NodalWeights<Basis0, Basis1>::assemble(const GridGlue& glue, ScalarVector& weights0, ScalarVector& weights1, bool initializeVector) const {
    // Accumulates lumped nodal weights over all intersections of `glue`.
    //
    // For every intersection, each of its corners contributes
    //     volume(intersection) / (mydimension + 1)
    // to the degree of freedom of basis0_ on the inside element, and to the
    // degree of freedom of basis1_ on the outside element, whose basis function
    // equals 1 at that corner (as located by basisDof). Corners for which no
    // such dof is found are skipped.
    // NOTE(review): the per-corner share assumes simplex intersections
    // (corners == mydimension + 1) -- confirm for the grid-glue merger in use.
    //
    // glue              the coupling whose intersections are traversed
    // weights0/weights1 output weights indexed by dofs of basis0_/basis1_
    // initializeVector  if true, both vectors are resized to the basis sizes
    //                   and reset to zero; otherwise contributions are added
    //                   to the existing entries
    using ctype = typename ScalarVector::field_type;

    if (initializeVector) {
        weights0.resize(basis0_.size());
        weights1.resize(basis1_.size());
        // resize() does not guarantee zeroed (or fresh) entries, and the loop
        // below accumulates with +=, so reset explicitly.
        weights0 = ctype(0);
        weights1 = ctype(0);
    }

    for (const auto& rIs : intersections(glue)) {
        const auto& inside = rIs.inside();
        const auto& outside = rIs.outside();

        const auto& geometry = rIs.geometry();
        // Equal share of the intersection volume for each corner.
        const ctype val = 1.0/(geometry.mydimension + 1)*geometry.volume();

        for (int i=0; i<geometry.corners(); i++) {
            const auto& vertex = geometry.corner(i);

            const int inIdx = basisDof(basis0_, inside, vertex);
            if (inIdx >= 0) {
                weights0[inIdx] += val;
            }

            const int outIdx = basisDof(basis1_, outside, vertex);
            if (outIdx >= 0) {
                weights1[outIdx] += val;
            }
        }
    }
}