#ifndef FAULT_P1_NODALBASIS_HH #define FAULT_P1_NODALBASIS_HH /** @file @brief @author */ #include <queue> #include <dune/localfunctions/lagrange/pqkfactory.hh> #include <dune/fufem/functionspacebases/functionspacebasis.hh> #include <dune/fufem/functionspacebases/p1nodalbasis.hh> #include <dune/istl/matrixindexset.hh> #include <dune/faultnetworks/levelinterfacenetwork.hh> #include <dune/faultnetworks/utils/debugutils.hh> template <class GV, class RT=double> class FaultP1NodalBasis : public P1NodalBasis< GV, RT> { protected: typedef typename GV::Grid GridType; typedef typename GridType::ctype ctype; static const int dimGrid = GridType::dimension; typedef typename GridType::LevelIntersection Intersection; typedef typename GV::IntersectionIterator IntersectionIterator; typedef P1NodalBasis<GV, RT> Base; typedef typename Base::Element Element; static const int dimElement = Element::dimension; using Base::dim; using Base::gridview_; using Base::cache_; typedef typename Dune::BCRSMatrix<Dune::FieldMatrix<RT, 1, 1>> MatrixType; MatrixType prolongationMatrix_; MatrixType restrictionMatrix_; const LevelInterfaceNetwork<GV>& faultNetwork_; // (elemID, vertexID) -> index std::map<std::pair<size_t, size_t>, int> indexMapper_; size_t maxDofIdx_; private: size_t elementIndex(const Element& elem) const { return gridview_.indexSet().index(elem); } void computeIntersectionDofs(const typename GV::IndexSet& idxSet, const Intersection& intersection, std::set<size_t>& intersectionDofs) { intersectionDofs.clear(); // loop over all vertices of the intersection const Element& insideElement = intersection.inside(); const auto& refElement = Dune::ReferenceElements<double,dimElement>::general(insideElement.type()); for (int i=0; i<refElement.size(intersection.indexInInside(), 1, dimElement); i++) { size_t idxInElement = refElement.subEntity(intersection.indexInInside(), 1, i, dimElement); size_t globalIdx = idxSet.subIndex(insideElement, idxInElement, dimElement); intersectionDofs.insert(globalIdx); } } void buildVertexDofs(const typename GV::IndexSet& idxSet, const Intersection& intersection, size_t vertex) { const auto& insideElem = intersection.inside(); std::set<size_t> visited; std::queue<Element> subpatchSeeds; subpatchSeeds.push(insideElem); subpatchSeeds.push(intersection.outside()); size_t dofIdx = vertex; bool firstSubpatch = true; while (!subpatchSeeds.empty()) { const Element& subpatchSeed = subpatchSeeds.front(); subpatchSeeds.pop(); if (visited.count(elementIndex(subpatchSeed))) continue; if (firstSubpatch) { firstSubpatch = false; } else { dofIdx = ++maxDofIdx_; } std::queue<Element> elemQueue; elemQueue.push(subpatchSeed); while (!elemQueue.empty()) { const Element& elem = elemQueue.front(); elemQueue.pop(); size_t elemIdx = elementIndex(elem); if (visited.count(elemIdx)) continue; visited.insert(elemIdx); indexMapper_[std::make_pair(elemIdx, vertex)] = dofIdx; // iterate over elem intersections IntersectionIterator it = gridview_.ibegin(elem); IntersectionIterator endIt = gridview_.iend(elem); for (; it != endIt; ++it){ if (it->neighbor()) { const Element& neighbor = it->outside(); std::set<size_t> intersectionDofs; computeIntersectionDofs(idxSet, *it, intersectionDofs); if (intersectionDofs.count(vertex)) { if (faultNetwork_.isInterfaceIntersection(*it)) { subpatchSeeds.push(neighbor); } else { elemQueue.push(neighbor); } } } } } } } public: typedef typename Base::GridView GridView; typedef typename Base::ReturnType ReturnType; typedef typename Dune::PQkLocalFiniteElementCache<ctype, RT, GV::dimension, 1> FiniteElementCache; typedef typename FiniteElementCache::FiniteElementType LFE; typedef typename Base::LocalFiniteElement LocalFiniteElement; typedef typename Base::LinearCombination LinearCombination; FaultP1NodalBasis(const LevelInterfaceNetwork<GV>& faultNetwork) : Base(faultNetwork.levelGridView()), faultNetwork_(faultNetwork), maxDofIdx_(gridview_.indexSet().size(dimGrid)-1) { // compute dof indices for vertices located on the fault network const typename GV::IndexSet& idxSet = gridview_.indexSet(); std::set<size_t> visited; const std::vector<Intersection>& faultIntersections = faultNetwork_.getIntersections(); for (size_t i=0; i<faultIntersections.size(); i++) { const Intersection& faultIntersection = faultIntersections[i]; std::set<size_t> intersectionDofs; computeIntersectionDofs(idxSet, faultIntersection, intersectionDofs); //print(intersectionDofs, "intersectionDofs: "); std::set<size_t>::const_iterator it = intersectionDofs.begin(); std::set<size_t>::const_iterator endIt = intersectionDofs.end(); for (; it!=endIt; ++it){ size_t vertex = *it; if (!visited.count(vertex)) { buildVertexDofs(idxSet, faultIntersection, vertex); visited.insert(vertex); } } } // TODO: remove after debugging /*std::cout << "FaultP1Nodalbasis: indexMapper, size: " << indexMapper_.size() << std::endl; std::map<std::pair<size_t, size_t>, int>::const_iterator mapIt = indexMapper_.begin(); std::map<std::pair<size_t, size_t>, int>::const_iterator endMapIt = indexMapper_.end(); for (; mapIt!=endMapIt; ++mapIt) { std::cout << "elemIdx: " << mapIt->first.first << " , vertexIdx: " << mapIt->first.second << " : " << mapIt->second << std::endl; } std::cout << "FaultP1Nodalbasis: size " << this->size() << std::endl; */ // ---- end remove ------------ // compute prolongation and restriction matrices const std::set<size_t>& interfaceNetworkDofs = faultNetwork_.getInterfaceNetworkDofs(); const size_t conDim = gridview_.indexSet().size(dimGrid); const size_t disconDim = conDim + faultNetwork_.dofCount(); Dune::MatrixIndexSet prolongIdxSet(disconDim, conDim); Dune::MatrixIndexSet restrictIdxSet(conDim, disconDim); // compute respective index sets for (size_t i=0; i<conDim; i++) { prolongIdxSet.add(i, i); restrictIdxSet.add(i, i); } std::set<size_t>::iterator beginIt = interfaceNetworkDofs.begin(); for (size_t i=conDim; i<disconDim; i++) { std::set<size_t>::iterator it = beginIt; std::advance(it, i-conDim); const size_t idx = *it; prolongIdxSet.add(i, idx); restrictIdxSet.add(idx, i); } // set entries of prolongation prolongIdxSet.exportIdx(prolongationMatrix_); prolongationMatrix_ = 1; // set entries of restriction restrictIdxSet.exportIdx(restrictionMatrix_); typedef typename MatrixType::row_type RowType; typedef typename RowType::ConstIterator ColumnIterator; for(size_t rowIdx = 0; rowIdx<conDim; rowIdx++) { RowType& row = restrictionMatrix_[rowIdx]; RT entry = 1.0/row.size(); ColumnIterator colIt = row.begin(); ColumnIterator colEndIt = row.end(); for(; colIt!=colEndIt; ++colIt) { row[colIt.index()] = entry; } } } const LevelInterfaceNetwork<GV>& faultNetwork() const { return faultNetwork_; } size_t size() const { return maxDofIdx_+1; } const LocalFiniteElement& getLocalFiniteElement(const Element& e) const { return cache_.get(e.type()); } int index(const Element& e, const int i) const { const size_t globalIdx = indexInGridView(e, i); const size_t elemIdx = elementIndex(e); const std::pair<size_t, size_t> idxPair = std::make_pair(elemIdx, globalIdx); if (indexMapper_.count(idxPair)) return indexMapper_.at(idxPair); else return globalIdx; } int indexInGridView(const Element& e, const int i) const { return gridview_.indexSet().subIndex(e, getLocalFiniteElement(e).localCoefficients().localKey(i).subEntity(), dimGrid); } // prolong continuous p1 vector to discontinuous p1 vector template<class VectorType> void prolong(const VectorType& x, VectorType& res) { res.resize(prolongationMatrix_.N()); prolongationMatrix_.mv(x, res); } // restrict discontinuous p1 vector to continuous p1 vector template<class VectorType> void restrict(const VectorType& x, VectorType& res) { res.resize(restrictionMatrix_.N()); restrictionMatrix_.mv(x, res); } }; #endif