#ifndef ROCK_FAULT_FACTORY_HH
#define ROCK_FAULT_FACTORY_HH

#include <math.h>

#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <numeric>
#include <queue>
#include <set>
#include <utility>
#include <vector>

#include <dune/faultnetworks/facehierarchy.hh>
#include <dune/faultnetworks/utils/debugutils.hh>
#include <dune/faultnetworks/faultfactories/oscunitcube.hh>
#include <dune/faultnetworks/levelinterfacenetwork.hh>
#include <dune/faultnetworks/interfacenetwork.hh>
#include <dune/faultnetworks/faultinterface.hh>
#include <dune/faultnetworks/hierarchicleveliterator.hh>

#include <dune/grid/common/mcmgmapper.hh>
#include <dune/grid/io/file/vtk/vtkwriter.hh>
#include <dune/grid/uggrid.hh>

#include <dune/fufem/boundarypatch.hh>
#include <dune/fufem/boundaryiterator.hh>

/**
 * A 2D "rock": a region of the fault network described by its center and by the four
 * vertices where the axis-parallel lines through the center first hit the enclosing
 * faults/boundary (as computed by LevelRockFaultFactory).
 *
 * All members have well-defined defaults so a default-constructed Rock never exposes
 * indeterminate values; -1 marks "not assigned yet".
 */
template<class ctype = double> class Rock {
public:
    //! grid level the rock was created on (-1 until assigned)
    int level = -1;

    // vertex indices (level index set) of the bounding vertices (-1 until set)
    int left = -1;
    int right = -1;
    int top = -1;
    int bottom = -1;

    //! center of the rock (zero until set)
    Dune::FieldVector<ctype, 2> center = Dune::FieldVector<ctype, 2>(ctype(0));

    //! Sets the four bounding vertex indices and the center; `level` is left untouched.
    void set(int _left, int _right, int _top, int _bottom, const Dune::FieldVector<ctype, 2>& _center) {
        left = _left;
        right = _right;
        top = _top;
        bottom = _bottom;

        center = _center;
    }
};


template <class GridType>
class LevelRockFaultFactory
{
    //! Parameter for mapper class
    template<int dim>
    struct FaceMapperLayout
    {
        bool contains (Dune::GeometryType gt)
        {
            return gt.dim() == dim-1;
        }
    };

protected:
    static const int dimworld = GridType::dimensionworld;
    static const int dim = GridType::dimension;
    using ctype = typename GridType::ctype;

    using Coords = typename Dune::FieldVector<ctype, dimworld>;
    using GV = typename GridType::LevelGridView;
    using Intersection = typename GridType::LevelIntersection;

    using Element = typename GridType::template Codim<0>::Entity;
    static const int dimElement = Element::dimension;

    using FaceMapper = typename Dune::MultipleCodimMultipleGeomTypeMapper<GV, FaceMapperLayout >;


    using MyRock = Rock<ctype>;

    const std::shared_ptr<GridType> grid_;
    const int level_;
    const ctype resolution_;
    const GV gridView_;

    const double splittingThreshold_;
    const double maxAngle_;

    const typename GV::IndexSet& indexSet_;

    FaceMapper faceMapper_;

    std::vector<Intersection> faces_;
    std::vector<Coords> vertexPositions_;

    using ID = std::array<size_t, 2>;
    std::vector<ID> vertexIDs_;
    std::map<ID, size_t> IDsToDof_;

    std::vector<std::vector<size_t>> vertexToFaces_;
    std::vector<int> coarseToLevelVertex_;

    const LevelRockFaultFactory& coarseLevelFactory_;
    std::vector<MyRock> rocks_;
    std::vector<std::shared_ptr<FaultInterface<GV>>>& faults_;

private:
    template <typename T> int sgn(T val) {
        return (T(0) < val) - (val < T(0));
    }

        bool intersectionAllowed(const size_t faceID, const size_t vertexID,
                                 const std::array<size_t, 2>& centerIDs,
                                 const std::set<size_t>& deadendDofs,
                                 const std::set<int>& separatingIDs, const std::set<size_t>& faultDofs,
                                 const Coords& direction, size_t dim) const {

            const auto& intersection = faces_[faceID];

            //check if "back" edge, cannot deviate from desiredOrientation more than maxAngle degrees (radian)
            auto orientation = intersection.geometry().center();
            orientation -= vertexPositions_[vertexID];
            orientation /= orientation.two_norm();

            if (std::acos(direction*orientation) > maxAngle_)
                return false;

            // check if intersection has separating dofs or other fault dofs
            std::set<size_t> intersectionDofs;
            computeIntersectionDofs(intersection, intersectionDofs);
            intersectionDofs.erase(vertexID);

            const size_t otherDim = (dim + 1) % 2;
            for (const auto& isDof : intersectionDofs) {
                const auto& vertex = vertexPositions_[isDof];
                std::array<size_t, 2> IDs = {computeID(vertex, 0), computeID(vertex, 1)};

                bool centerPassed = (sgn(direction[dim])>0) ? (centerIDs[dim]<IDs[dim]) : (centerIDs[dim]>IDs[dim]);
                if (faultDofs.count(isDof) or separatingIDs.count(IDs[otherDim]) or centerPassed or deadendDofs.count(isDof)) {
                    return false;
                }
            }

            return true;
        }

        /*
        ctype distance(const Intersection& isec, const Dune::FieldVector<ctype, dimworld> & vertex) const {
            Dune::FieldVector<ctype, dimworld> vec(vertex);
            vec += isec.center();

            return isec.unitOuterNormal()*vec;
        }*/

        //works only in 2D, vertex(1) = vertex(0) - yid, intersection of line (given by vertex and derivative 1) with y axis
        int computeID(const Coords& vertex, int dim) const {
            return (int) (vertex[dim]*resolution_);
        }

        bool isTargetBoundaryVertex(size_t vertexIdx, const std::array<size_t, 2>& targetIDs) const {
            const auto& vertex = vertexPositions_[vertexIdx];
            const std::array<size_t, 2> vertexIDs = {computeID(vertex, 0), computeID(vertex, 1)};
            return (vertexIDs[0] == targetIDs[0]) and (vertexIDs[1] == targetIDs[1]);
        }

        typedef std::vector<double>::const_iterator ConstVectorIterator;

        struct ordering {
            bool operator ()(std::pair<size_t, ConstVectorIterator> const& a, std::pair<size_t, ConstVectorIterator> const& b) {
                return *(a.second) < *(b.second);
            }
        };

        void distanceSort(std::vector<size_t>& admissibleFaces, std::vector<double>& distances) const {
            std::vector<std::pair<size_t, ConstVectorIterator>> order(distances.size());

            size_t j = 0;
            ConstVectorIterator it = distances.begin();
            ConstVectorIterator itEnd = distances.end();
            for (; it != itEnd; ++it, ++j)
                order[j] = std::make_pair(j, it);

            sort(order.begin(), order.end(), ordering());

            std::vector<size_t> initialAdmissibleFaces(admissibleFaces);
            for (size_t i=0; i<admissibleFaces.size(); i++)
                admissibleFaces[i] = initialAdmissibleFaces[order[i].first];
        }

       /* void generateFaultSeeds(std::vector<FaultCorridor>& faultCorridors, std::vector<std::vector<size_t>>& faultSeeds) {

            //std::cout << "LevelGeoFaultFactory::generateFaultSeeds() " << std::endl;
            //std::cout << "------------------------------------- " << std::endl << std::endl;

            faultSeeds.resize(faultCorridors.size());

            std::map<int, std::pair<size_t,size_t>> yIDtoFaultCorridor;
            for (size_t i=0; i<faultCorridors.size(); i++){
                FaultCorridor& faultCorridor = faultCorridors[i];
                const std::vector<int> & faultYs = faultCorridor.faultYs();

                faultSeeds[i].resize(faultYs.size());
                for (size_t j=0; j<faultYs.size(); j++) {
                    yIDtoFaultCorridor[faultYs[j]] = std::make_pair(i, j);
                }
            }

            BoundaryIterator<GV> bIt(gridView_, BoundaryIterator<GV>::begin);
            BoundaryIterator<GV> bEnd(gridView_, BoundaryIterator<GV>::end);
            for(; bIt!=bEnd; ++bIt) {

                const Element& insideElement = bIt->inside();
                const auto& geometry = insideElement.geometry();

                // neglect "upper" part of boundary, where (y==1 and x>0) or (y>0 and x==1)
                const auto& center = bIt->geometry().center();
                if (center[0] + center[1] >1)
                        continue;

                const auto& refElement = Dune::ReferenceElements<double,dimElement>::general(insideElement.type());

                for (int i=0; i<refElement.size(bIt->indexInInside(), 1, dimElement); i++) {
                    size_t idxInElement = refElement.subEntity(bIt->indexInInside(), 1, i, dimElement);
                    size_t globalIdx = indexSet_.subIndex(insideElement, idxInElement, dimElement);

                    const auto& vertex = geometry.corner(idxInElement);

                    const int yID = computeYID(vertex);

                    if (yIDtoFaultCorridor.count(yID)) {
                        const std::pair<size_t, size_t>& indices = yIDtoFaultCorridor[yID];
                        faultSeeds[indices.first][indices.second] = globalIdx;
                    }
                }
            }

            //std::cout << "------------------------------------- " << std::endl << std::endl;
        }*/


        void generateFault(std::shared_ptr<FaultInterface<GV>> fault, const size_t faultSeedIdx, const Coords& center,
                           const size_t corridor,
                           const std::set<size_t>& separatingDofs) {
            //std::cout << "LevelGeoFaultFactory::generateFault() " << std::endl;
            //std::cout << "------------------------------------- " << std::endl << std::endl;

            bool success = false;

            std::set<size_t> deadendDofs = separatingDofs;

            auto direction = center;
            direction -= vertexPositions_[faultSeedIdx];
            direction /= direction.two_norm();

            size_t dim = 0;
            if (std::abs(direction[0]) < std::abs(direction[1])) {
                dim = 1;
            }

            const ID centerIDs = {computeID(center, 0), computeID(center, 1)};

            std::set<size_t> separatingIDs;

            auto faultSeedID = vertexIDs_[faultSeedIdx][dim];
            separatingIDs.insert(faultSeedID + corridor);
            separatingIDs.insert(faultSeedID - corridor);

            std::vector<size_t> faultDofs(0);
            faultDofs.push_back(faultSeedIdx);

            std::vector<size_t> faultFaces;

            std::map<size_t, std::vector<size_t>> vertexToAdmissibleFaces;

            std::queue<size_t> vertexQueue;
            vertexQueue.push(faultSeedIdx);

            while (!vertexQueue.empty()) {
                const size_t vertexID = vertexQueue.front();
                vertexQueue.pop();

                if (isTargetBoundaryVertex(vertexID, centerIDs)) {
                    success = true;
                    break;
                }

                if (vertexToAdmissibleFaces.count(vertexID)==0) {
                    const std::vector<size_t>& faces = vertexToFaces_[vertexID];
                    std::vector<size_t> admissibleFaces;

                    for (size_t i=0; i<faces.size(); i++) {
                        if (intersectionAllowed(faces[i], vertexID, centerIDs, deadendDofs, separatingIDs, faultDofs, direction, dim)) {
                            admissibleFaces.push_back(faces[i]);
                        }
                    }

                    std::vector<double> distances(admissibleFaces.size(), 0);
                    for (size_t i=0; i<admissibleFaces.size(); i++) {
                        Coords vec(center);
                        vec -= faces_[admissibleFaces[i]].geometry().center();
                        distances[i] = vec.two_norm();
                    }
                    distanceSort(admissibleFaces, distances);

                    vertexToAdmissibleFaces[vertexID] = admissibleFaces;
                }

                std::vector<size_t>& suitableFaces = vertexToAdmissibleFaces[vertexID];

                if (suitableFaces.size()==0) {
                    faultDofs.pop_back();
                    faultFaces.pop_back();
                    vertexToAdmissibleFaces.erase(vertexID);

                    deadendDofs.insert(vertexID);
                    vertexQueue.push(faultDofs.back());
                } else {

                    // generate random number from (0,1)
                    double randomNumber = ((double) std::rand() / (RAND_MAX));

                    size_t randSelector = (size_t) std::abs(std::log2(randomNumber));
                    if (randSelector >= suitableFaces.size()) {
                        randSelector = suitableFaces.size() -1;
                    }

                    size_t nextFaceID = suitableFaces[randSelector];

                    suitableFaces.erase(suitableFaces.begin()+randSelector);

                    std::set<size_t> nextFaceDofs;
                    computeIntersectionDofs(faces_[nextFaceID], nextFaceDofs);
                    nextFaceDofs.erase(vertexID);

                    faultDofs.push_back(*nextFaceDofs.begin());
                    faultFaces.push_back(nextFaceID);

                    size_t nextVertexID = *(nextFaceDofs.begin());
                    vertexQueue.push(nextVertexID);
                }
            }

            if (!success) {
                std::cout << "Generating a fault failed: Unable to reach target boundary! This should not happend!" << std::endl;
                DUNE_THROW(Dune::Exception, "Generating a fault failed: Unable to reach target boundary!");
            }

            for (size_t i=0; i<faultFaces.size(); i++) {
                fault->addFace(faces_[faultFaces[i]]);
            }

            //std::cout << "------------------------------------- " << std::endl << std::endl;
        }

    auto searchDof(const ID& IDs, const std::set<size_t>& separatingDofs, size_t dim, int dir) {
        auto candidatIDs = IDs;
        int lastDof = -1;

        if (dir>0) {
            while (lastDof<0) {
                candidatIDs[dim]++;
                auto dof = IDsToDof_[candidatIDs];
                if (separatingDofs.count(dof)) {
                    lastDof = dof;
                }
            }
        } else {
            while (lastDof<0) {
                candidatIDs[dim]--;
                auto dof = IDsToDof_[candidatIDs];
                if (separatingDofs.count(dof)) {
                    lastDof = dof;
                }
            }
        }

        return lastDof;
    }

    void createRock(MyRock& rock, const Coords& center, const std::set<size_t>& separatingDofs,
                    const std::set<size_t>& xFaultDofs, bool xFaultBottom,
                    const std::set<size_t>& yFaultDofs, bool yFaultRight) {

        rock.level = level_;
        rock.center = center;

        const ID centerIDs = {computeID(center, 0), computeID(center, 1)};

        size_t left, right, top, bottom = 0;

        if (xFaultBottom) {
            top = searchDof(centerIDs, separatingDofs, 1, 1);
            bottom = searchDof(centerIDs, xFaultDofs, 1, -1);
        } else {
            top = searchDof(centerIDs, xFaultDofs, 1, 1);
            bottom = searchDof(centerIDs, separatingDofs, 1, -1);
        }

        if (yFaultRight) {
            left = searchDof(centerIDs, separatingDofs, 0, -1);
            right = searchDof(centerIDs, yFaultDofs, 0, 1);
        } else {
            left = searchDof(centerIDs, yFaultDofs, 0, -1);
            right = searchDof(centerIDs, separatingDofs, 0, 1);
        }

        rock.set(left, right, top, bottom, center);
    }

    void prolong(const MyRock& rock, MyRock& newRock) {
        newRock.level = rock.level;

        auto newLeft = coarseToLevelVertex_[rock.left()];
        auto newRight = coarseToLevelVertex_[rock.right()];
        auto newTop = coarseToLevelVertex_[rock.top()];
        auto newBottom = coarseToLevelVertex_[rock.bottom()];

        newRock.set(newLeft, newRight, newTop, newBottom, rock.center());
    }

    bool randomSplit(const MyRock& rock) {
        const auto& center = rock.center();

        double res = 0.0;
        if (center[0] < center[1]) {
            res = 1.0 - center[0];
        } else {
            res = 1.0 - center[1];
        }

        if (res > splittingThreshold_)
            return true;

        double prob = (100.0 * std::rand() / (RAND_MAX + 1.0)) + 1;
        return  100.0*res > prob;
    }

    void split(const MyRock& rock, const std::set<size_t>& separatingDofs) {
        Rock newRock;
        prolong(rock, newRock);

        bool toBeSplit = (rock.level() == coarseLevelFactory_.level()) and randomSplit(rock);
        if (!toBeSplit) {
            rocks_.push_back(newRock);
        } else {
            const ID centerIDs = {computeID(newRock.center, 0), computeID(newRock.center, 1)};

            size_t xCorridor = 1.0/2 * std::min(centerIDs[0] - vertexIDs_[newRock.left][0],
                    vertexIDs_[newRock.right][0] - centerIDs[0]) + 1;
            size_t yCorridor = 1.0/2 * std::min(centerIDs[1] - vertexIDs_[newRock.bottom][1],
                    vertexIDs_[newRock.top][1] - centerIDs[1]) + 1;

            // split rock into 4 subparts by 4 new faults intersecting at center of rock
            std::shared_ptr<FaultInterface<GV>> x1Fault = std::make_shared<FaultInterface<GV>>(gridView_, level_);
            generateFault(x1Fault, newRock.left, newRock.center, yCorridor, separatingDofs);
            faults_.push_back(x1Fault);

            std::shared_ptr<FaultInterface<GV>> x2Fault = std::make_shared<FaultInterface<GV>>(gridView_, level_);
            generateFault(x2Fault, newRock.right, newRock.center, yCorridor, separatingDofs);
            faults_.push_back(x2Fault);

            std::shared_ptr<FaultInterface<GV>> y1Fault = std::make_shared<FaultInterface<GV>>(gridView_, level_);
            generateFault(y1Fault, newRock.bottom, newRock.center, xCorridor, separatingDofs);
            faults_.push_back(y1Fault);

            std::shared_ptr<FaultInterface<GV>> y2Fault = std::make_shared<FaultInterface<GV>>(gridView_, level_);
            generateFault(y2Fault, newRock.top, newRock.center, xCorridor, separatingDofs);
            faults_.push_back(y2Fault);

            auto left = vertexPositions_[newRock.left];
            auto right = vertexPositions_[newRock.right];
            auto top = vertexPositions_[newRock.top];
            auto bottom = vertexPositions_[newRock.bottom];

            MyRock rock00;
            auto center00 = 1.0/2*(right + bottom);
            createRock(rock00, center00, separatingDofs, x2Fault.getInterfaceDofs(), 0, y1Fault.getInterfaceDofs(), 0);
            rocks_.push_back(rock00);

            MyRock rock01;
            auto center01 = 1.0/2*(left + bottom);
            createRock(rock01, center01, separatingDofs, x1Fault.getInterfaceDofs(), 0, y1Fault.getInterfaceDofs(), 1);
            rocks_.push_back(rock01);

            MyRock rock10;
            auto center10 = 1.0/2*(right + top);
            createRock(rock10, center10, separatingDofs, x2Fault.getInterfaceDofs(), 1, y2Fault.getInterfaceDofs(), 0);
            rocks_.push_back(rock10);

            MyRock rock11;
            auto center11 = 1.0/2*(left + top);
            createRock(rock11, center11, separatingDofs, x1Fault.getInterfaceDofs(), 1, y2Fault.getInterfaceDofs(), 1);
            rocks_.push_back(rock11);
        }
    }

    public:
        LevelRockFaultFactory(const std::shared_ptr<GridType> grid, const int level, const ctype resolution,
                              const LevelRockFaultFactory& coarseLevelFactory,
                              const double splittingThreshold = 0.0, const double maxAngle = 2) :
            grid_(grid),
            level_(level),
            resolution_(resolution),
            gridView_(grid_->levelGridView(level_)),
            splittingThreshold_(splittingThreshold),
            maxAngle_(maxAngle),
            indexSet_(gridView_.indexSet()),
            faceMapper_(gridView_),
            coarseLevelFactory_(coarseLevelFactory)
            {
                // init faces_, vertexPositions_, vertexToFaces_
                faces_.resize(faceMapper_.size());
                vertexPositions_.resize(gridView_.size(dim));
                vertexToFaces_.resize(gridView_.size(dim));
                vertexIDs_.resize(gridView_.size(dim));
                coarseToLevelVertex_.resize(coarseLevelFactory_.gridView().size(dim));

                std::vector<bool> faceHandled(faceMapper_.size(), false);

                for (const auto& elem:elements(gridView_)) {
                    for (const auto& isect:intersections(gridView_, elem)) {
                        //const size_t faceID = faceMapper_.index(isect);

                        const int faceID = faceMapper_.subIndex(elem, isect.indexInInside(), 1);

                        if (isect.boundary())
                            continue;

                        if (faceHandled[faceID])
                            continue;

                        faceHandled[faceID] = true;
                        faces_[faceID] = isect;

                        const auto& refElement = Dune::ReferenceElements<double,dimElement>::general(elem.type());

                        for (int i=0; i<refElement.size(isect.indexInInside(), 1, dimElement); i++) {
                            //size_t vertexID = idxSet.subIndex(elem.subEntity<1>(isect), i, dimElement);

                            size_t idxInElement = refElement.subEntity(isect.indexInInside(), 1, i, dimElement);
                            const size_t vertexID = indexSet_.subIndex(elem, idxInElement, dimElement);
                            const auto& vertex = elem.geometry().corner(idxInElement);

                            vertexPositions_[vertexID] = vertex;
                            vertexToFaces_[vertexID].push_back(faceID);

                            ID id = {computeID(vertex, 0), computeID(vertex, 1)};
                            vertexIDs_[vertexID] = id;
                            IDsToDof_[id] = vertexID;
                        }
                    }
                }

                for (const auto& vertex : vertices(gridView_)) {
                    size_t coarseVertexID;
                    bool isCoarseVertex = coarseLevelFactory_.indexSet().index(vertex, coarseVertexID);

                    if (isCoarseVertex) {
                        coarseToLevelVertex_[coarseVertexID] = indexSet_.index(vertex);
                    }
                }
            }

        void build(const std::set<size_t>& separatingDofs) {
            faults_.resize(0);
            const auto& coarseRocks = coarseLevelFactory_.rocks();

            for (size_t i=0; i<coarseRocks.size(); i++) {
                split(coarseRocks[i], separatingDofs);
            }
        }

        void computeIntersectionDofs(const Intersection& intersection, std::set<size_t>& intersectionDofs) const {
            intersectionDofs.clear();

            // loop over all vertices of the intersection
            const auto& insideElement = intersection.inside();
            const auto& refElement = Dune::ReferenceElements<double,dimElement>::general(insideElement.type());

            for (int i=0; i<refElement.size(intersection.indexInInside(), 1, dimElement); i++) {
                size_t idxInElement = refElement.subEntity(intersection.indexInInside(), 1, i, dimElement);
                size_t globalIdx = indexSet_.subIndex(insideElement, idxInElement, dimElement);

                intersectionDofs.insert(globalIdx);
            }
        }

        const auto& gridView() const {
            return gridView_;
        }

        const auto& indexSet() const {
            return indexSet_;
        }

        const auto& rocks() const {
            return rocks_;
        }

        auto level() const {
            return level_;
        }

        const auto& faults() const {
            return faults_;
        }
};

template <class GridType>
class InitLevelRockFaultFactory : public LevelRockFaultFactory<GridType> {
private:
    using Base = LevelRockFaultFactory<GridType>;

public:
    /**
     * Seed factory for the coarsest level: it has no coarse factory and build() creates a
     * single rock centered at (0.5, 0.5) that is bounded by the given boundary dofs.
     */
    InitLevelRockFaultFactory(const std::shared_ptr<GridType> grid, const int level, const typename Base::ctype resolution) :
        Base(grid, level, resolution, nullptr) {}

    //! Appends the initial rock, whose bounding vertices are the first boundary dofs hit from the center.
    void build(const std::set<size_t>& boundaryDofs) {
        typename Base::MyRock initialRock;
        initialRock.level = this->level_;
        initialRock.center[0] = 0.5;
        initialRock.center[1] = 0.5;

        const auto& c = initialRock.center;
        const typename Base::ID centerIDs = {this->computeID(c, 0), this->computeID(c, 1)};

        // vertical search first, then horizontal
        const size_t topDof    = this->searchDof(centerIDs, boundaryDofs, 1, 1);
        const size_t bottomDof = this->searchDof(centerIDs, boundaryDofs, 1, -1);
        const size_t leftDof   = this->searchDof(centerIDs, boundaryDofs, 0, -1);
        const size_t rightDof  = this->searchDof(centerIDs, boundaryDofs, 0, 1);

        initialRock.set(leftDof, rightDof, topDof, bottomDof, initialRock.center);
        this->rocks_.push_back(initialRock);
    }

};

template <class GridType>
class RockFaultFactory {

private:
    const int coarseResolution_;
    const size_t maxLevel_;
    const int coarseGridN_;

    std::vector<double> levelResolutions_;

    std::shared_ptr<GridType> grid_;
    std::shared_ptr<InterfaceNetwork<GridType>> interfaceNetwork_;

    std::vector<std::shared_ptr<LevelRockFaultFactory<GridType>>> levelRockFaultFactories_;

/*typename std::enable_if<!std::numeric_limits<ctype>::is_integer, bool>::type
            almost_equal(ctype x, ctype y, int ulp) const
        {
            return std::abs(x-y) < std::numeric_limits<ctype>::epsilon() * std::abs(x+y) * ulp
                   || std::abs(x-y) < std::numeric_limits<ctype>::min();
        }



bool isSeparatingIntersection(const Intersection& intersection, const std::set<int>& separatingYIDs) const {
    Dune::FieldVector<ctype, dimworld> faceCenter = intersection.geometry().center();

    ctype yID = faceCenter[1]*coarseGridN_;
    ctype intpart;

    if (almost_equal(std::modf(yID, &intpart), 0.0, 2)) {
        std::set<int>::iterator it = separatingYIDs.find((int) intpart);
        if (it!=separatingYIDs.end())
            return true;
        else
            return false;
    } else
        return false;
}*/


public:
        //setup 
    RockFaultFactory(const int coarseResolution, const size_t maxLevel, const double maxAngle = 2) :
        coarseResolution_(coarseResolution),
        maxLevel_(maxLevel),
        coarseGridN_(std::pow(2, coarseResolution_)),
        interfaceNetwork_(nullptr)
        {   
            using GridOb = OscUnitCube<GridType, 2>;
            using GV = typename GridType::LevelGridView;

            Dune::UGGrid<GridType::dimension>::setDefaultHeapSize(4000);
            GridOb unitCube(coarseGridN_);
            grid_ = unitCube.grid();
            grid_->globalRefine(maxLevel_);

            levelResolutions_.resize(maxLevel_+1);
            levelRockFaultFactories_.resize(maxLevel_+1);

            // init interface network
            interfaceNetwork_ = std::make_shared<InterfaceNetwork<GridType>>(*grid_);

            // init level 0 rockFaultFactory
            levelResolutions_[0] = std::pow(2, coarseResolution_);

            std::set<size_t> boundaryDofs;
            BoundaryIterator<GV> bIt(levelRockFaultFactories_[0].gridView(), BoundaryIterator<GV>::begin);
            BoundaryIterator<GV> bEnd(levelRockFaultFactories_[0].gridView(), BoundaryIterator<GV>::end);
            for(; bIt!=bEnd; ++bIt) {
                std::set<size_t> intersectionDofs;
                levelRockFaultFactories_[0].computeIntersectionDofs(*bIt, intersectionDofs);
                boundaryDofs.insert(intersectionDofs.begin(), intersectionDofs.end());
            }

            InitLevelRockFaultFactory initFactory(grid_, 0, levelResolutions_[0]);
            initFactory.build(boundaryDofs);

            levelRockFaultFactories_[0] = std::make_shared<LevelRockFaultFactory<GridType>>(grid_, 0, levelResolutions_[0], initFactory, 1.0);
            levelRockFaultFactories_[0]->build(boundaryDofs);

            const auto& faults = levelRockFaultFactories_[0]->faults();
            for (size_t j=0; j<faults.size(); j++) {
                interfaceNetwork_->addInterface(faults[j]);
            }
            interfaceNetwork_->prolongLevel(0, 1);

            for (size_t i=1; i<=maxLevel_; i++) {
                levelResolutions_[i] = std::pow(2, coarseResolution_+i);

                //generate faults on level
                levelRockFaultFactories_[i] = std::make_shared<LevelRockFaultFactory<GridType>>(grid_, i, levelResolutions_[i], *levelRockFaultFactories_[i-1], (i==1)*0.5);
                levelRockFaultFactories_[i]->build(interfaceNetwork_->getInterfaceNetworkDofs(i));

                faults = levelRockFaultFactories_[i]->faults();
                for (size_t j=0; j<faults.size(); j++) {
                    interfaceNetwork_->addInterface(faults[j]);
                }

                if (i==maxLevel_)
                    continue;

                interfaceNetwork_->prolongLevel(i, i+1);
            }
        }

    /*
    void prolongToAll() {
        // prolong all faults to all subsequent levels
        for (int i=maxLevel_-1; i>=0; i--) {
            if (interfaceNetwork_->size(i)>0) {
                std::set<int> toLevels;
                for (size_t j=i+1; j<=maxLevel_; j++) {
                    toLevels.insert(j);
                }

                interfaceNetwork_->prolongLevelInterfaces(i, toLevels);
            }
        }
        interfaceNetwork_->build();
    }*/

    /*void prolongToAll() {
        // prolong all faults to all subsequent levels
        for (int i=interfaceNetwork_->size()-1; i>=0; i--) {
            interfaceNetwork_->prolongLevelInterfaces(i, maxLevel_);
        }
        interfaceNetwork_->build();
    }*/

    const GridType& grid() const {
        return *grid_;
	}
	
    /*const InterfaceNetwork<GridType>& interfaceNetwork() {
        return *interfaceNetwork_;
    }*/

    InterfaceNetwork<GridType>& interfaceNetwork() {
        return *interfaceNetwork_;
    }
};
#endif