#ifndef LEVEL_GLOBAL_PRECONDITIONER_HH
#define LEVEL_GLOBAL_PRECONDITIONER_HH

#include <cassert>
#include <memory>
#include <string>
#include <vector>

#include <dune/common/fvector.hh>

#include <dune/solvers/iterationsteps/lineariterationstep.hh>

#include <dune/istl/umfpack.hh>

#include <dune/fufem/boundarypatch.hh>
#include <dune/fufem/functiontools/boundarydofs.hh>

#include <dune/faultnetworks/assemblers/globalfaultassembler.hh>
#include <dune/faultnetworks/levelinterfacenetwork.hh>
#include <dune/faultnetworks/utils/debugutils.hh>


/**
 * Linear iteration step that solves the complete level problem directly with UMFPack.
 *
 * build() assembles the global operator for the given level interface network
 * (via GlobalFaultAssembler, using the bulk assembler and the interface assemblers).
 * It then replaces the rows of all DOFs found by constructBoundaryDofs() on a
 * BoundaryPatch(gridView, true) by unit rows, and factorizes the resulting matrix once.
 * Each call to iterate() overwrites the current iterate with the solution of that
 * system for the right-hand side passed to setProblem(). The right-hand side is zero
 * at boundary DOFs, i.e. homogeneous Dirichlet values.
 *
 * The matrix passed to setProblem() is stored in the base class but not used; the
 * system solved is always the one assembled in build().
 */
template <class BasisType, class LocalAssembler, class LocalInterfaceAssembler, class MatrixType, class VectorType>
class LevelGlobalPreconditioner : public LinearIterationStep<MatrixType, VectorType> {

public:
    enum BoundaryMode {homogeneous, fromIterate};

private:
    typedef typename BasisType::GridView GridView;
    typedef typename GridView::Grid GridType;

    // References only: the referenced objects must outlive this preconditioner.
    const LevelInterfaceNetwork<GridView>& levelInterfaceNetwork_;
    const LocalAssembler& localAssembler_;
    const std::vector<std::shared_ptr<LocalInterfaceAssembler>>& localInterfaceAssemblers_;

    const GridType& grid_;
    const int level_;
    const BasisType basis_;

    Dune::BitSetVector<1> boundaryDofs_;
    MatrixType matrix_;
    VectorType rhs_;

    // True until build() has completed; guards setProblem() and iterate().
    bool requireBuild_;

#if !HAVE_UMFPACK
#error No UMFPack!
#endif
    // Factorization of matrix_, computed once in build() and reused by every iterate().
    // Held by shared_ptr so that copying the preconditioner never copies (and double-frees)
    // the UMFPack object.
    std::shared_ptr<Dune::UMFPack<MatrixType> > solver_;

public:

    /**
     * Store references to the problem description; no assembly happens here.
     *
     * @param levelInterfaceNetwork    network providing the grid, level, level grid view and basis
     * @param localAssembler           bulk local assembler handed to GlobalFaultAssembler
     * @param localInterfaceAssemblers interface assemblers handed to GlobalFaultAssembler;
     *                                 must have levelInterfaceNetwork.size() entries
     *
     * Call build() before setProblem() / iterate().
     */
    LevelGlobalPreconditioner(const LevelInterfaceNetwork<GridView>& levelInterfaceNetwork,
                             const LocalAssembler& localAssembler,
                             const std::vector<std::shared_ptr<LocalInterfaceAssembler>>& localInterfaceAssemblers) :
          levelInterfaceNetwork_(levelInterfaceNetwork),
          localAssembler_(localAssembler),
          localInterfaceAssemblers_(localInterfaceAssemblers),
          grid_(levelInterfaceNetwork_.grid()),
          level_(levelInterfaceNetwork_.level()),
          basis_(levelInterfaceNetwork_),
          requireBuild_(true)
    {
        assert(localInterfaceAssemblers_.size() == levelInterfaceNetwork_.size());
    }

    /**
     * Assemble the level operator, turn boundary rows into unit rows and factorize it.
     *
     * Must be called before setProblem() and iterate().
     */
    void build() {
        GlobalFaultAssembler<BasisType, BasisType> globalFaultAssembler(basis_, basis_, levelInterfaceNetwork_);
        globalFaultAssembler.assembleOperator(localAssembler_, localInterfaceAssemblers_, matrix_);

        typedef typename MatrixType::row_type RowType;
        typedef typename RowType::Iterator ColumnIterator;

        const GridView& gridView = levelInterfaceNetwork_.levelGridView();

        // Collect boundary DOFs; the `true` flag presumably selects the whole grid boundary.
        BoundaryPatch<GridView> boundaryPatch(gridView, true);
        constructBoundaryDofs(boundaryPatch, basis_, boundaryDofs_);

        // Replace each boundary row by a unit row. Together with the zeroed right-hand
        // side entries (see updateRhs()) this pins the solution to 0 there.
        for (size_t i = 0; i < boundaryDofs_.size(); ++i) {
            if (!boundaryDofs_[i][0])
                continue;

            RowType& row = matrix_[i];

            // Write through the iterator instead of row[j]: no column search per entry.
            const ColumnIterator cEndIt = row.end();
            for (ColumnIterator cIt = row.begin(); cIt != cEndIt; ++cIt)
                *cIt = 0;

            row[i] = 1;
        }

        // Factorize once; iterate() then only performs the triangular solves.
        solver_ = std::make_shared<Dune::UMFPack<MatrixType> >(matrix_);

        requireBuild_ = false;
    }

    /**
     * Register the iterate and copy the right-hand side.
     *
     * @param mat  stored in the base class only; iterate() uses the matrix assembled in build()
     * @param x    iterate that iterate() overwrites with the direct solution
     * @param rhs  right-hand side; its entries at boundary DOFs are replaced by zero
     * @throws Dune::Exception if build() has not been called, or rhs is larger than the system
     */
    virtual void setProblem(const MatrixType& mat, VectorType& x, const VectorType& rhs) {
        // Validate and copy the rhs first so a failing call leaves the step untouched.
        updateRhs(rhs);
        this->x_ = &x;
        this->mat_ = Dune::stackobject_to_shared_ptr(mat); // mat will be set but not used
    }

    /**
     * Overwrite the iterate with the direct solution of the system assembled in build()
     * for the right-hand side given to setProblem().
     *
     * @throws Dune::Exception if build() has not been called yet
     */
    virtual void iterate() {
        if (requireBuild_)
            DUNE_THROW(Dune::Exception, "LevelGlobalPreconditioner::iterate() Call build() before solving the global problem!");

        VectorType& x = *(this->x_);
        x.resize(matrix_.M());
        x = 0;

        // InverseOperator::apply may overwrite its right-hand side argument, so solve with
        // a copy to keep rhs_ intact for repeated calls.
        VectorType rhsCopy(rhs_);
        Dune::InverseOperatorResult res;
        solver_->apply(x, rhsCopy, res);
    }

    /** Basis the level problem is discretized with. */
    const BasisType& basis() const {
        return basis_;
    }

    /** Grid view of the basis. */
    const GridView& gridView() const {
        return basis_.getGridView();
    }

    /** Interface network this preconditioner was constructed for. */
    const LevelInterfaceNetwork<GridView>& levelInterfaceNetwork() const {
        return levelInterfaceNetwork_;
    }

private:
    /**
     * Copy rhs into rhs_ (sized to the assembled matrix) and zero it at boundary DOFs,
     * matching the unit rows set up in build(). Entries of rhs_ beyond rhs.size() stay zero.
     *
     * @throws Dune::Exception if build() has not been called, or rhs does not fit the system
     */
    void updateRhs(const VectorType& rhs){
        // Before build(), boundaryDofs_ and rhs_ are empty and the loop below would index
        // out of bounds.
        if (requireBuild_)
            DUNE_THROW(Dune::Exception, "LevelGlobalPreconditioner::setProblem() Call build() before setting the problem!");

        rhs_.resize(matrix_.M());
        rhs_ = 0;

        if (rhs.size() > rhs_.size() || rhs.size() > boundaryDofs_.size())
            DUNE_THROW(Dune::Exception, "LevelGlobalPreconditioner::setProblem() Right-hand side is larger than the assembled system!");

        for (size_t i = 0; i < rhs.size(); ++i) {
            if (!boundaryDofs_[i][0])
                rhs_[i] = rhs[i];
        }
    }
};

#endif