#ifndef LEVEL_PATCH_PRECONDITIONER_HH
#define LEVEL_PATCH_PRECONDITIONER_HH

#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include <dune/common/bitsetvector.hh>
#include <dune/common/exceptions.hh>
#include <dune/common/fvector.hh>
#include <dune/common/timer.hh>

#include <dune/geometry/referenceelements.hh>

#include <dune/localfunctions/lagrange/pqkfactory.hh>

#include <dune/solvers/common/numproc.hh>
#include <dune/solvers/iterationsteps/lineariterationstep.hh>

#include <dune/fufem/boundarypatch.hh>
#include <dune/fufem/functiontools/boundarydofs.hh>

#include "../../data-structures/levelcontactnetwork.hh"

#include "supportpatchfactory.hh"
#include "../tnnmg/localproblem.hh"

/// How the local patch corrections of LevelPatchPreconditioner are combined.
enum MPPMode {additive, multiplicative};

/**
 * \brief Read an MPPMode from the next whitespace-delimited token of \p lhs.
 *
 * Accepted spellings:
 *   - "additive"       or "1"  -> MPPMode::additive
 *   - "multiplicative" or "0"  -> MPPMode::multiplicative
 * Any other token (including a failed extraction) sets failbit on \p lhs and
 * leaves \p e unchanged.
 *
 * NOTE(review): the numeric spellings are the reverse of the enumerators'
 * underlying values (additive == 0, multiplicative == 1), so writing an
 * MPPMode out as an integer and reading it back swaps the mode. Kept as-is
 * for compatibility with existing parameter files — confirm this is intended.
 *
 * \return \p lhs, to allow chaining.
 */
inline
std::istream& operator>>(std::istream& lhs, MPPMode& e)
{
  std::string token;
  lhs >> token;

  const bool wantsAdditive = (token == "additive") || (token == "1");
  const bool wantsMultiplicative = (token == "multiplicative") || (token == "0");

  if (wantsAdditive) {
    e = MPPMode::additive;
  } else if (wantsMultiplicative) {
    e = MPPMode::multiplicative;
  } else {
    // Unknown token: signal a parse failure, leave e untouched.
    lhs.setstate(std::ios_base::failbit);
  }

  return lhs;
}

/**
 * \brief Patch preconditioner acting on one level of a contact network hierarchy.
 *
 * build() asks a SupportPatchFactory for one patch per vertex of
 * \p levelContactNetwork (patches are described on \p fineContactNetwork).
 * iterate() then runs the user-supplied PatchSolver once per patch and combines
 * the local corrections into *x_ according to the MPPMode given at construction.
 * Only the additive mode is currently implemented (see iterateMult()).
 *
 * Usage: construct, optionally setPatchDepth(), call build(), call
 * setPatchSolver(), then use as a LinearIterationStep.
 *
 * \tparam LevelContactNetwork provides dim, ctype, level(), nBodies(), body(i)
 * \tparam PatchSolver provides getIterationStep(), check(), preprocess(), solve()
 */
template <class LevelContactNetwork, class PatchSolver, class MatrixType, class VectorType>
class LevelPatchPreconditioner : public LinearIterationStep<MatrixType, VectorType> {

public:
    static const int dim = LevelContactNetwork::dim;

private:
    using Base = LinearIterationStep<MatrixType, VectorType>;

    using ctype = typename LevelContactNetwork::ctype;

    using PatchFactory = SupportPatchFactory<LevelContactNetwork>;
    using Patch = typename PatchFactory::Patch;

    const MPPMode mode_;

    // Both networks are held by reference and must outlive this object.
    const LevelContactNetwork& levelContactNetwork_;
    const LevelContactNetwork& fineContactNetwork_;

    const int level_;

    PatchFactory patchFactory_;
    // One patch per vertex, indexed by the global vertex index from patchFactory_.coarseIndices().
    std::vector<Patch> patches_;

    std::shared_ptr<PatchSolver> patchSolver_;
    size_t patchDepth_;

public:

    /**
     * \brief Set up the preconditioner for the patches given by \p levelContactNetwork.
     *
     * Patches are not built here; call build() before iterate().
     * The patch depth defaults to 0 and verbosity to NumProc::QUIET.
     */
    LevelPatchPreconditioner(const LevelContactNetwork& levelContactNetwork,
                             const LevelContactNetwork& fineContactNetwork,
                             const MPPMode mode = additive) :
            mode_(mode),
            levelContactNetwork_(levelContactNetwork),
            fineContactNetwork_(fineContactNetwork),
            level_(levelContactNetwork_.level()),
            patchFactory_(levelContactNetwork_, fineContactNetwork_) {

        setPatchDepth();
        this->verbosity_ = NumProc::QUIET;
    }

    /**
     * \brief Build one support patch per vertex of the level network.
     *
     * Vertices shared by several elements are processed only once. The current
     * patch depth (setPatchDepth()) is read here, so changing it afterwards has
     * no effect until build() is called again.
     */
    void build() {
        size_t totalNVertices = 0;
        for (size_t i=0; i<levelContactNetwork_.nBodies(); i++) {
            totalNVertices += levelContactNetwork_.body(i)->nVertices();
        }

        patches_.resize(totalNVertices);

        // init local fine level corrections
        Dune::Timer timer;
        if (this->verbosity_ == NumProc::FULL) {
            std::cout << std::endl;
            std::cout << "---------------------------------------------" << std::endl;
            std::cout << "Initializing local fine grid corrections! Level: " << level_ << std::endl;

            timer.reset();
            timer.start();
        }

        Dune::BitSetVector<1> vertexVisited(totalNVertices);
        vertexVisited.unsetAll();

        const auto& levelIndices = patchFactory_.coarseIndices();

        for (size_t bodyIdx=0; bodyIdx<levelContactNetwork_.nBodies(); bodyIdx++) {
            const auto& gridView = levelContactNetwork_.body(bodyIdx)->gridView();

            for (const auto& e : elements(gridView)) {
                const auto& refElement = Dune::ReferenceElements<ctype, dim>::general(e.type());
                const size_t nElementVertices = refElement.size(dim);

                for (size_t i=0; i<nElementVertices; i++) {
                    const auto globalIdx = levelIndices.vertexIndex(bodyIdx, e, i);

                    // A vertex shared by several elements gets exactly one patch.
                    if (vertexVisited[globalIdx][0])
                        continue;

                    vertexVisited[globalIdx][0] = true;
                    patchFactory_.build(bodyIdx, e, i, patches_[globalIdx], patchDepth_);
                }
            }
        }

        if (this->verbosity_ == NumProc::FULL) {
            std::cout << std::endl;
            std::cout << "Total setup time: " << timer.elapsed() << " seconds." << std::endl;
            std::cout << "---------------------------------------------" << std::endl;
            timer.stop();
        }

        if (this->verbosity_ != NumProc::QUIET) {
            std::cout << "Level: " << level_ << " LevelPatchPreconditioner::build() - SUCCESS" << std::endl;
        }
    }

    /// Set the solver used for every local patch problem (required before iterate()).
    void setPatchSolver(std::shared_ptr<PatchSolver> patchSolver) {
        patchSolver_ = std::move(patchSolver);
    }

    /// Set the depth passed to the patch factory; only takes effect on the next build().
    void setPatchDepth(const size_t patchDepth = 0) {
        patchDepth_ = patchDepth;
    }

    void setVerbosity(NumProc::VerbosityMode verbosity) {
        this->verbosity_ = verbosity;
    }

    /// Compute the preconditioned correction into *x_ using the configured MPPMode.
    void iterate() override {
        if (mode_ == additive)
            iterateAdd();
        else
            iterateMult();
    }

    /**
     * \brief Additive variant: *x_ becomes the sum of all local patch corrections.
     *
     * For each patch, a copy of this step's ignore mask is extended by the dofs
     * flagged in the patch, and the patch solver is run on the full system
     * (*mat_, *rhs_) starting from a zero iterate. The solutions are summed into *x_.
     *
     * \throws Dune::InvalidStateException if there are patches but no patch solver was set.
     */
    void iterateAdd() {
        *(this->x_) = 0;
        if (patches_.empty())
            return;

        if (!patchSolver_)
            DUNE_THROW(Dune::InvalidStateException,
                       "LevelPatchPreconditioner: no patch solver set, call setPatchSolver() first");

        // The patch solver's iteration step is the same object for every patch,
        // so fetch it (and check that it is linear) once, outside the loop.
        auto& step = patchSolver_->getIterationStep();
        auto& linearStep = dynamic_cast<Base&>(step);

        const auto& baseIgnore = this->ignore();

        VectorType x = *(this->x_);
        for (const auto& p : patches_) {
            x = 0;

            // NOTE(review): assumes each Patch block has at least dim bits; if Patch
            // is a BitSetVector<1>, p[i][d] for d >= 1 would be out of range — confirm
            // against SupportPatchFactory::Patch.
            auto ignore = baseIgnore;
            for (size_t i=0; i<ignore.size(); i++) {
                for (int d=0; d<dim; d++) {
                    if (p[i][d])
                        ignore[i][d] = true;
                }
            }

            // NOTE(review): the step may keep references to the local x / ignore;
            // they are only valid during this call.
            linearStep.setProblem(*this->mat_, x, *this->rhs_);
            step.setIgnore(ignore);

            patchSolver_->check();
            patchSolver_->preprocess();
            patchSolver_->solve();

            *(this->x_) += x;
        }
    }

    /**
     * \brief Multiplicative variant — NOT IMPLEMENTED.
     *
     * Currently only zeroes *x_, i.e. returns a zero correction; an outer
     * iterative solver presumably makes no progress with it.
     */
    void iterateMult() {
        *(this->x_) = 0;
        // TODO: sweep over patches_; for each patch, solve with residual
        // rhs - mat * x_ restricted to the patch and add the local correction to *x_
        // before moving on to the next patch.
    }

    /// Number of patches (the total number of vertices at the time of build()).
    size_t size() const {
        return patches_.size();
    }

    /// Level of the network the patches are associated with.
    size_t level() const {
        return level_;
    }

    /// Level of the fine network passed to the constructor.
    size_t fineLevel() const {
        return fineContactNetwork_.level();
    }
};

#endif