Skip to content
Snippets Groups Projects
faultp1nodalbasis.hh 10.4 KiB
Newer Older
podlesny's avatar
podlesny committed
#ifndef FAULT_P1_NODALBASIS_HH
#define FAULT_P1_NODALBASIS_HH

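/**
   @file
   @brief P1 nodal basis whose degrees of freedom are duplicated across the
          intersections of a fault (interface) network

   @author
 */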

#include <map>
#include <queue>
#include <set>
#include <utility>
#include <vector>

#include <dune/localfunctions/lagrange/pqkfactory.hh>
#include <dune/fufem/functionspacebases/functionspacebasis.hh>
#include <dune/fufem/functionspacebases/p1nodalbasis.hh>
#include <dune/istl/matrixindexset.hh>

#include <dune/faultnetworks/levelinterfacenetwork.hh>

podlesny's avatar
.  
podlesny committed
#include <dune/faultnetworks/utils/debugutils.hh>

podlesny's avatar
podlesny committed
/**
 * \brief P1 nodal basis that is discontinuous across the intersections of a fault network.
 *
 * Every grid vertex keeps its usual P1 index. For each vertex lying on an
 * intersection of the LevelInterfaceNetwork, the elements around that vertex
 * are grouped into subpatches: sets of elements reachable from one another
 * across intersections that contain the vertex but are NOT interface
 * intersections. The first subpatch reuses the vertex index; every further
 * subpatch receives a fresh index appended after the last grid vertex index.
 *
 * The basis also assembles a prolongation matrix (plain P1 coefficients ->
 * coefficients of this basis) and a restriction matrix (averaging all dofs
 * attached to a vertex back onto that vertex).
 *
 * \tparam GV grid view type (the basis is built on faultNetwork.levelGridView())
 * \tparam RT scalar type forwarded to P1NodalBasis and used for the transfer matrices
 */
template <class GV, class RT=double>
class FaultP1NodalBasis :
    public P1NodalBasis<
        GV,
        RT>
{
    protected:
        typedef typename GV::Grid GridType;
        typedef typename GridType::ctype ctype;
        static const int dimGrid = GridType::dimension;
        typedef typename GridType::LevelIntersection Intersection;
        typedef typename GV::IntersectionIterator IntersectionIterator;

        typedef P1NodalBasis<GV, RT> Base;
        typedef typename Base::Element Element;
        static const int dimElement = Element::dimension;

        using Base::dim;
        using Base::gridview_;
        using Base::cache_;


        typedef typename Dune::BCRSMatrix<Dune::FieldMatrix<RT, 1, 1>> MatrixType;

        // size() x #vertices: copies each vertex value to all dofs attached to that vertex
        MatrixType prolongationMatrix_;
        // #vertices x size(): averages all dofs attached to a vertex
        MatrixType restrictionMatrix_;

        const LevelInterfaceNetwork<GV>& faultNetwork_;

        // (elemIdx, vertexIdx) -> dof index; only filled for vertices on the fault network
        std::map<std::pair<size_t, size_t>, int> indexMapper_;
        // largest dof index handed out so far
        size_t maxDofIdx_;

    private:
        size_t elementIndex(const Element& elem) const {
            return gridview_.indexSet().index(elem);
        }

        // Collects the grid-view indices of all vertices of `intersection`,
        // looked up through its inside element and the reference element.
        void computeIntersectionDofs(const typename GV::IndexSet& idxSet, const Intersection& intersection, std::set<size_t>& intersectionDofs) const {
            intersectionDofs.clear();

            const Element& insideElement = intersection.inside();
            const int faceIdx = intersection.indexInInside();

            const auto& refElement = Dune::ReferenceElements<ctype, dimElement>::general(insideElement.type());
            for (int i=0; i<refElement.size(faceIdx, 1, dimElement); i++) {
                const int idxInElement = refElement.subEntity(faceIdx, 1, i, dimElement);
                intersectionDofs.insert(idxSet.subIndex(insideElement, idxInElement, dimElement));
            }
        }

        // Splits the element patch around `vertex` into subpatches separated by
        // interface intersections and records a dof index for every
        // (element, vertex) pair of the patch in indexMapper_.
        void buildVertexDofs(const typename GV::IndexSet& idxSet, const Intersection& intersection, size_t vertex) {
            std::set<size_t> visited;

            // seeds of candidate subpatches: the two elements adjacent to the fault intersection
            std::queue<Element> subpatchSeeds;
            subpatchSeeds.push(intersection.inside());
            subpatchSeeds.push(intersection.outside());

            size_t dofIdx = vertex;
            bool firstSubpatch = true;

            std::set<size_t> intersectionDofs;

            while (!subpatchSeeds.empty()) {
                // copy, not reference: pop() destroys the element held by the queue
                const Element subpatchSeed = subpatchSeeds.front();
                subpatchSeeds.pop();

                if (visited.count(elementIndex(subpatchSeed)))
                    continue;

                // the first subpatch keeps the vertex' own index, every further one gets a new dof
                if (firstSubpatch)
                    firstSubpatch = false;
                else
                    dofIdx = ++maxDofIdx_;

                std::queue<Element> elemQueue;
                elemQueue.push(subpatchSeed);

                while (!elemQueue.empty()) {
                    // copy, not reference: pop() destroys the element held by the queue
                    const Element elem = elemQueue.front();
                    elemQueue.pop();

                    const size_t elemIdx = elementIndex(elem);
                    if (!visited.insert(elemIdx).second)
                        continue;

                    indexMapper_[std::make_pair(elemIdx, vertex)] = dofIdx;

                    const IntersectionIterator endIt = gridview_.iend(elem);
                    for (IntersectionIterator it = gridview_.ibegin(elem); it != endIt; ++it) {
                        if (!it->neighbor())
                            continue;

                        computeIntersectionDofs(idxSet, *it, intersectionDofs);
                        if (!intersectionDofs.count(vertex))
                            continue;

                        // crossing an interface intersection starts a new subpatch,
                        // any other intersection keeps us inside the current one
                        if (faultNetwork_.isInterfaceIntersection(*it))
                            subpatchSeeds.push(it->outside());
                        else
                            elemQueue.push(it->outside());
                    }
                }
            }
        }

        // Assigns dof indices to all vertices located on the fault network.
        void buildFaultDofs() {
            const typename GV::IndexSet& idxSet = gridview_.indexSet();

            std::set<size_t> visited;
            std::set<size_t> intersectionDofs;

            const std::vector<Intersection>& faultIntersections = faultNetwork_.getIntersections();
            for (const Intersection& faultIntersection : faultIntersections) {
                computeIntersectionDofs(idxSet, faultIntersection, intersectionDofs);

                for (const size_t vertex : intersectionDofs) {
                    // each fault vertex is processed once, from the first fault intersection containing it
                    if (visited.insert(vertex).second)
                        buildVertexDofs(idxSet, faultIntersection, vertex);
                }
            }
        }

        // Builds prolongationMatrix_ and restrictionMatrix_ from indexMapper_,
        // i.e. from exactly the dof/vertex relation that index() reports, so that
        // every additional dof is coupled to the vertex it was split off from.
        void assembleTransferMatrices() {
            const size_t conDim = gridview_.indexSet().size(dimGrid);
            const size_t disconDim = size();

            // additional dof -> grid vertex it belongs to
            std::map<size_t, size_t> dofToVertex;
            for (const auto& entry : indexMapper_) {
                const size_t dofIdx = static_cast<size_t>(entry.second);
                if (dofIdx >= conDim)
                    dofToVertex[dofIdx] = entry.first.second;
            }

            Dune::MatrixIndexSet prolongIdxSet(disconDim, conDim);
            Dune::MatrixIndexSet restrictIdxSet(conDim, disconDim);

            // every vertex keeps its own (continuous) dof
            for (size_t i=0; i<conDim; i++) {
                prolongIdxSet.add(i, i);
                restrictIdxSet.add(i, i);
            }

            // additional dofs are coupled to their vertex
            for (const auto& entry : dofToVertex) {
                prolongIdxSet.add(entry.first, entry.second);
                restrictIdxSet.add(entry.second, entry.first);
            }

            // prolongation: copy the vertex value to all of its dofs
            prolongIdxSet.exportIdx(prolongationMatrix_);
            prolongationMatrix_ = 1;

            // restriction: average over all dofs attached to a vertex
            restrictIdxSet.exportIdx(restrictionMatrix_);
            for (size_t rowIdx = 0; rowIdx < conDim; rowIdx++) {
                auto& row = restrictionMatrix_[rowIdx];
                const RT weight = 1.0/row.size();

                for (auto colIt = row.begin(); colIt != row.end(); ++colIt)
                    *colIt = weight;
            }
        }

    public:
        typedef typename Base::GridView GridView;
        typedef typename Base::ReturnType ReturnType;
        typedef typename Dune::PQkLocalFiniteElementCache<ctype, RT, GV::dimension, 1> FiniteElementCache;
        typedef typename FiniteElementCache::FiniteElementType LFE;
        typedef typename Base::LocalFiniteElement LocalFiniteElement;
        typedef typename Base::LinearCombination LinearCombination;

        /**
         * \brief Builds the basis on faultNetwork.levelGridView().
         *
         * Assigns the duplicated dofs around fault vertices and assembles the
         * prolongation/restriction matrices. The network is held by reference
         * and must outlive this object.
         */
        FaultP1NodalBasis(const LevelInterfaceNetwork<GV>& faultNetwork) :
            Base(faultNetwork.levelGridView()),
            faultNetwork_(faultNetwork),
            maxDofIdx_(gridview_.indexSet().size(dimGrid)-1)
        {
            buildFaultDofs();
            assembleTransferMatrices();
        }

        const LevelInterfaceNetwork<GV>& faultNetwork() const {
            return faultNetwork_;
        }

        //! number of dofs: grid vertices plus the additional dofs on the fault network
        size_t size() const
        {
            return maxDofIdx_+1;
        }

        const LocalFiniteElement& getLocalFiniteElement(const Element& e) const
        {
            return cache_.get(e.type());
        }

        //! global dof index of local basis function i on element e
        int index(const Element& e, const int i) const
        {
            const size_t globalIdx = indexInGridView(e, i);

            const auto it = indexMapper_.find(std::make_pair(elementIndex(e), globalIdx));
            return (it != indexMapper_.end()) ? it->second : static_cast<int>(globalIdx);
        }

        //! grid-view vertex index of local basis function i on element e (ignores fault duplication)
        int indexInGridView(const Element& e, const int i) const
        {
            return gridview_.indexSet().subIndex(e, getLocalFiniteElement(e).localCoefficients().localKey(i).subEntity(), dimGrid);
        }

        // prolong continuous p1 vector to discontinuous p1 vector
        template<class VectorType>
        void prolong(const VectorType& x, VectorType& res) const {
            res.resize(prolongationMatrix_.N());

            prolongationMatrix_.mv(x, res);
        }

        // restrict discontinuous p1 vector to continuous p1 vector
        template<class VectorType>
        void restrict(const VectorType& x, VectorType& res) const {
            res.resize(restrictionMatrix_.N());

            restrictionMatrix_.mv(x, res);
        }
};

#endif