#ifndef LEVEL_PATCH_PRECONDITIONER_HH
#define LEVEL_PATCH_PRECONDITIONER_HH

#include <cstddef>
#include <iostream>
#include <istream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <dune/common/timer.hh>
#include <dune/common/fvector.hh>
#include <dune/common/bitsetvector.hh>

#include <dune/solvers/iterationsteps/lineariterationstep.hh>
#include <dune/solvers/common/numproc.hh>

#include "../../data-structures/network/levelcontactnetwork.hh"

#include "patchproblem.hh"
#include "supportpatchfactory.hh"

#include <dune/localfunctions/lagrange/pqkfactory.hh>

#include <dune/fufem/boundarypatch.hh>
#include <dune/fufem/functiontools/boundarydofs.hh>

// How LevelPatchPreconditioner combines the local patch corrections.
enum MPPMode {additive, multiplicative};

/**
 * \brief Read an MPPMode from one whitespace-delimited token of \p lhs.
 *
 * Accepted tokens:
 *   "additive"       or "1"  -> MPPMode::additive
 *   "multiplicative" or "0"  -> MPPMode::multiplicative
 * Any other token, or a failed read, sets failbit on \p lhs and leaves \p e unchanged.
 *
 * NOTE(review): the numeric aliases are the reverse of the enumerators' underlying
 * values (additive == 0, multiplicative == 1) -- confirm this is intended before
 * relying on "0"/"1" in parameter files.
 *
 * \return \p lhs, to allow chaining.
 */
inline
std::istream& operator>>(std::istream& lhs, MPPMode& e)
{
  std::string token;
  lhs >> token;

  const bool wantsAdditive = (token == "additive") || (token == "1");
  const bool wantsMultiplicative = (token == "multiplicative") || (token == "0");

  if (wantsAdditive) {
    e = MPPMode::additive;
  } else if (wantsMultiplicative) {
    e = MPPMode::multiplicative;
  } else {
    // unknown token: signal the error through the stream, keep e as it was
    lhs.setstate(std::ios_base::failbit);
  }

  return lhs;
}

/**
 * \brief Patch-based preconditioning step on one level of a contact network hierarchy.
 *
 * One support patch is built per vertex of \p levelContactNetwork (see build()).
 * One local PatchProblem is created per patch from the system matrix (see setMatrix()).
 * iterate() solves every local problem with the configured patch solver and
 * accumulates the prolonged local corrections in *x_.
 *
 * Required call order: build() -> setMatrix() (directly or via setProblem()) -> iterate(),
 * with setPatchSolver() called before iterate().
 *
 * \tparam PatchSolver must provide getIterationStep(), check(), preprocess() and solve();
 *         its iteration step must be a LinearIterationStep<MatrixType, VectorType>.
 */
template <class LevelContactNetwork, class PatchSolver, class MatrixType, class VectorType>
class LevelPatchPreconditioner : public LinearIterationStep<MatrixType, VectorType> {

public:
    static const int dim = LevelContactNetwork::dim;

private:
    using Base = LinearIterationStep<MatrixType, VectorType>;

    using ctype = typename LevelContactNetwork::ctype;

    using PatchFactory = SupportPatchFactory<LevelContactNetwork>;
    using Patch = typename PatchFactory::Patch;
    // Deliberately NOT named `PatchProblem`: redeclaring the class template's own name in
    // class scope makes that name refer to two different entities ([basic.scope.class],
    // ill-formed NDR; GCC rejects it as "changes meaning").
    using PatchProblemType = PatchProblem<MatrixType, VectorType>;

    const MPPMode mode_;

    const LevelContactNetwork& levelContactNetwork_;  // network whose vertices define the patches
    const LevelContactNetwork& fineContactNetwork_;   // finer network, passed to the patch factory

    const int level_;

    PatchFactory patchFactory_;
    std::vector<Patch> patches_;                                   // one per vertex, filled by build()
    std::vector<std::unique_ptr<PatchProblemType>> patchProblems_; // one per patch, filled by setMatrix()
    std::shared_ptr<PatchSolver> patchSolver_;
    size_t patchDepth_;

public:

    /**
     * \brief For each coarse patch given by levelContactNetwork: set up a local problem
     *        and compute a local correction.
     *
     * \param levelContactNetwork network whose vertices define the patches
     * \param fineContactNetwork  finer network, forwarded to the patch factory
     * \param mode                how local corrections are combined (default: additive)
     *
     * The patch depth starts at 0 and the verbosity at NumProc::QUIET.
     */
    LevelPatchPreconditioner(const LevelContactNetwork& levelContactNetwork,
                             const LevelContactNetwork& fineContactNetwork,
                             const MPPMode mode = additive) :
            mode_(mode),
            levelContactNetwork_(levelContactNetwork),
            fineContactNetwork_(fineContactNetwork),
            level_(levelContactNetwork_.level()),
            patchFactory_(levelContactNetwork_, fineContactNetwork_) {

        setPatchDepth();
        this->verbosity_ = NumProc::QUIET;
    }

    /**
     * \brief Build one support patch per vertex of the coarse network.
     *
     * Vertices are numbered globally across all bodies by patchFactory_.coarseIndices().
     * Each vertex is visited once, and its patch is built by the factory with the
     * current patch depth.
     * Call setMatrix() afterwards so that the local problems match the patches.
     */
    void build() {
        size_t totalNVertices = 0;
        for (size_t i=0; i<levelContactNetwork_.nBodies(); i++) {
            totalNVertices += levelContactNetwork_.body(i)->nVertices();
        }

        patches_.resize(totalNVertices);

        // init local fine level corrections
        Dune::Timer timer;
        if (this->verbosity_ == NumProc::FULL) {
            std::cout << std::endl;
            std::cout << "---------------------------------------------" << std::endl;
            std::cout << "Initializing local fine grid corrections! Level: " << level_ << std::endl;

            timer.reset();
            timer.start();
        }

        // NOTE(review): assumes vertexIndex() yields indices in [0, totalNVertices)
        Dune::BitSetVector<1> vertexVisited(totalNVertices);
        vertexVisited.unsetAll();

        const auto& levelIndices = patchFactory_.coarseIndices();

        for (size_t bodyIdx=0; bodyIdx<levelContactNetwork_.nBodies(); bodyIdx++) {
            const auto& gridView = levelContactNetwork_.body(bodyIdx)->gridView();

            for (const auto& e : elements(gridView)) {
                const auto& refElement = Dune::ReferenceElements<double, dim>::general(e.type());

                for (int i=0; i<refElement.size(dim); i++) {
                    auto globalIdx = levelIndices.vertexIndex(bodyIdx, e, i);

                    // vertices are shared between elements: build each patch only once
                    if (!vertexVisited[globalIdx][0]) {
                        vertexVisited[globalIdx][0] = true;
                        patchFactory_.build(bodyIdx, e, i, patches_[globalIdx], patchDepth_);
                    }
                }
            }
        }

        if (this->verbosity_ == NumProc::FULL) {
            std::cout << std::endl;
            std::cout << "Total setup time: " << timer.elapsed() << " seconds." << std::endl;
            std::cout << "---------------------------------------------" << std::endl;
            timer.stop();
        }

        if (this->verbosity_ != NumProc::QUIET) {
            std::cout << "Level: " << level_ << " LevelPatchPreconditioner::build() - SUCCESS" << std::endl;
        }
    }

    /// Set the solver used for every local patch problem (required before iterate()).
    void setPatchSolver(std::shared_ptr<PatchSolver> patchSolver) {
        patchSolver_ = std::move(patchSolver);
    }

    /// Set the depth passed to the patch factory; takes effect on the next build().
    void setPatchDepth(const size_t patchDepth = 0) {
        patchDepth_ = patchDepth;
    }

    void setVerbosity(NumProc::VerbosityMode verbosity) {
        this->verbosity_ = verbosity;
    }

    /**
     * \brief Store the global matrix and create one local problem per patch.
     *
     * Must be called after build(). The local problems are built from the patches that
     * exist at the time of the call.
     */
    void setMatrix(const MatrixType& mat) override {
        Base::setMatrix(mat);

        patchProblems_.resize(patches_.size());
        for (size_t i=0; i<patches_.size(); i++) {
            patchProblems_[i] = std::make_unique<PatchProblemType>(mat, patches_[i]);
        }
    }

    /// Dispatch to iterateAdd() or iterateMult() according to the mode given at construction.
    void iterate() override {
        if (mode_ == additive)
            iterateAdd();
        else
            iterateMult();
    }

    /**
     * \brief Additive variant: *x_ = sum over all patches of the prolonged local corrections.
     *
     * Every local problem uses the unmodified global right-hand side *rhs_; nothing
     * computed for one patch feeds into another.
     *
     * \throws Dune::Exception if there are patches but no patch solver is set, or if the
     *         local problems are out of sync with the patches (setMatrix() not called
     *         after the last build()).
     */
    void iterateAdd() {
        *(this->x_) = 0;

        if (patches_.empty())
            return;

        if (!patchSolver_)
            DUNE_THROW(Dune::Exception, "LevelPatchPreconditioner: no patch solver set, call setPatchSolver() first!");
        if (patchProblems_.size() != patches_.size())
            DUNE_THROW(Dune::Exception, "LevelPatchPreconditioner: patch problems out of date, call setMatrix() after build()!");

        // work vector with the layout of *x_, reused for each prolonged local correction
        VectorType x = *(this->x_);

        for (size_t i=0; i<patches_.size(); i++) {
            auto& patchProblem = *patchProblems_[i];

            const auto& patchMat = patchProblem.mat();

            patchProblem.setRhs(*this->rhs_);
            const auto& patchRhs = patchProblem.rhs();

            VectorType patchX(patchMat.M());
            patchX = 0;

            auto& step = patchSolver_->getIterationStep();
            dynamic_cast<LinearIterationStep<MatrixType, VectorType>&>(step).setProblem(patchMat, patchX, patchRhs);

            // no degrees of freedom are ignored in the local problem
            Dune::Solvers::DefaultBitVector_t<VectorType> ignore(patchX.size());
            ignore.unsetAll();
            step.setIgnore(ignore);

            patchSolver_->check();
            patchSolver_->preprocess();
            patchSolver_->solve();

            // prolong the local solution into a zeroed global vector and accumulate it
            x = 0;
            patchProblem.prolong(patchX, x);

            *(this->x_) += x;
        }
    }

    /**
     * \brief Multiplicative variant -- not implemented.
     *
     * Zeroes *x_ and then always throws Dune::Exception.
     * TODO: solve the patches sequentially, updating the right-hand side to
     * rhs - A * x_ before each patch solve.
     */
    void iterateMult() {
        *(this->x_) = 0;

        DUNE_THROW(Dune::Exception, "Not implemented!");
    }

    /// Number of patches (one per vertex of the coarse network after build()).
    size_t size() const {
        return patches_.size();
    }

    /// Level of the coarse network the patches are built on.
    size_t level() const {
        return level_;
    }

    /// Level of the fine network passed to the constructor.
    size_t fineLevel() const {
        return fineContactNetwork_.level();
    }
};

#endif