#ifndef SRC_SPATIAL_SOLVING_PRECONDITIONERS_PATCH_PROBLEM_HH
#define SRC_SPATIAL_SOLVING_PRECONDITIONERS_PATCH_PROBLEM_HH

#include <math.h>

#include <algorithm>
#include <cstddef>
#include <vector>

#include <dune/common/bitsetvector.hh>
#include <dune/common/fmatrix.hh>
#include <dune/common/function.hh>
#include <dune/common/timer.hh>

#include <dune/istl/matrixindexset.hh>
//#include <dune/istl/superlu.hh>
#include <dune/istl/umfpack.hh>

#include <dune/fufem/assemblers/localoperatorassembler.hh>

#include "../../utils/debugutils.hh"

/**
 * \brief Restriction of a global linear system to a subset ("patch") of its indices.
 *
 * The patch is described by a bit field over the global indices: every index i
 * whose bit patch[i][0] is NOT set belongs to the local problem. On construction
 * the sub-matrix of the global matrix with rows and columns in the patch is
 * extracted; the local right-hand side starts out as zero and is filled by setRhs().
 *
 * \tparam MatrixType  sparse matrix type (row access via operator[], row iterators
 *                     providing index(), fillable from a Dune::MatrixIndexSet)
 * \tparam DomainType  vector type of solutions
 * \tparam RangeType   vector type of right-hand sides
 */
template <class MatrixType, class DomainType, class RangeType = DomainType>
class PatchProblem {

private:
    // Global matrix. Only a reference is stored, so the matrix passed to the
    // constructor has to outlive this object (prolong() reads its size).
    const MatrixType& mat_;

    // localToGlobal_[i] is the global index of local index i (ascending order).
    std::vector<std::size_t> localToGlobal_;

    MatrixType localMat_;
    RangeType localRhs_;


public:
    /**
     * \brief Extract the local problem for the given patch.
     *
     * \param mat    global matrix; must stay alive as long as this object is used
     * \param patch  one bit per global index; indices with an unset bit form the patch
     */
    PatchProblem(const MatrixType& mat, const Dune::BitSetVector<1>& patch) :
        mat_(mat) {

        // construct localToGlobal map
        for (std::size_t i=0; i<patch.size(); ++i) {
            if (!patch[i][0]) {
                localToGlobal_.push_back(i);
            }
        }

        const std::size_t localDim = localToGlobal_.size();

        // Inverse map for O(1) column lookup. localDim is not a valid local
        // index and therefore serves as the "not in patch" marker.
        std::vector<std::size_t> globalToLocal(patch.size(), localDim);
        for (std::size_t i=0; i<localDim; ++i) {
            globalToLocal[localToGlobal_[i]] = i;
        }

        // Column indices outside the range of the patch are never part of it.
        const auto toLocal = [&globalToLocal, localDim](std::size_t globalIdx) {
            return (globalIdx < globalToLocal.size()) ? globalToLocal[globalIdx] : localDim;
        };

        // build sparsity pattern of the local matrix
        Dune::MatrixIndexSet localIdxSet(localDim, localDim);

        for (std::size_t rowIdx=0; rowIdx<localDim; ++rowIdx) {
            const auto& globalRow = mat_[localToGlobal_[rowIdx]];

            const auto cEndIt = globalRow.end();
            for (auto cIt=globalRow.begin(); cIt!=cEndIt; ++cIt) {
                const std::size_t localColIdx = toLocal(cIt.index());
                if (localColIdx != localDim) {
                    localIdxSet.add(rowIdx, localColIdx);
                }
            }
        }

        localIdxSet.exportIdx(localMat_);

        // copy the matrix entries; every local entry exists in the global row
        // by construction of the pattern above
        for (std::size_t rowIdx=0; rowIdx<localMat_.N(); ++rowIdx) {
            auto& row = localMat_[rowIdx];
            const auto& globalRow = mat_[localToGlobal_[rowIdx]];

            const auto cEndIt = row.end();
            for (auto cIt=row.begin(); cIt!=cEndIt; ++cIt) {
                *cIt = globalRow[localToGlobal_[cIt.index()]];
            }
        }

        // init local rhs
        localRhs_.resize(localDim);
        localRhs_ = 0;
    }

    /// \brief The local matrix (global matrix restricted to the patch).
    const MatrixType& mat() const {
        return localMat_;
    }

    /// \brief The local right-hand side (zero until setRhs() has been called).
    const RangeType& rhs() const {
        return localRhs_;
    }

    /**
     * \brief Set the local right-hand side from a global one.
     *
     * \param rhs  global right-hand side; must provide an entry for every
     *             global index contained in the patch
     */
    void setRhs(const RangeType& rhs){
        for (std::size_t i=0; i<localRhs_.size(); i++) {
            localRhs_[i] = rhs[localToGlobal_[i]];
        }
    }

    /**
     * \brief Extend a local vector to a global one, padding with zeros.
     *
     * \param x    local vector, indexed by local patch indices
     * \param res  output; resized to the number of rows of the global matrix,
     *             zero outside the patch
     *
     * Entries of x beyond the number of patch indices have no global
     * counterpart and are ignored.
     */
    void prolong(const DomainType& x, DomainType& res) const {
        res.resize(mat_.N());
        res = 0;

        const std::size_t count = std::min<std::size_t>(x.size(), localToGlobal_.size());
        for (std::size_t i=0; i<count; i++) {
            res[localToGlobal_[i]] = x[i];
        }
    }

    /**
     * \brief Restrict a global vector to the patch.
     *
     * \param x    global vector; must provide an entry for every global index
     *             contained in the patch
     * \param res  output; resized to the number of patch indices
     */
    void restrict(const RangeType& x, RangeType& res) const {
        res.resize(localToGlobal_.size());

        // every entry is assigned below, so no zero-initialisation is needed
        for (std::size_t i=0; i<res.size(); i++) {
            res[i] = x[localToGlobal_[i]];
        }
    }
};

#endif