#ifndef DUNE_TECTONIC_TIME_STEPPING_STEP_HH
#define DUNE_TECTONIC_TIME_STEPPING_STEP_HH


#include <future>
#include <thread>
#include <chrono>
#include <functional>

#include <dune/solvers/norms/energynorm.hh>
#include <dune/solvers/iterationsteps/multigridstep.hh>
#include <dune/solvers/iterationsteps/cgstep.hh>
#include <dune/solvers/solvers/loopsolver.hh>

#include "../spatial-solving/makelinearsolver.hh"

#include <dune/tectonic/utils/reductionfactors.hh>

#include "stepbase.hh"
#include "adaptivetimestepper.hh"

/**
 * \brief Build a solver criterion that records per-iteration reduction factors.
 *
 * On every evaluation the criterion measures, in the given norm, the distance
 * between the current iterate of \p iterationStep and the iterate seen at the
 * previous evaluation, divides it by the distance measured last time
 * (initially 1) and appends the quotient to \p reductionFactors.
 *
 * \param iterationStep    iteration step whose iterate is observed (held by reference)
 * \param norm             norm providing diff(a, b) (held by reference)
 * \param reductionFactors container receiving one factor per evaluation (held by reference)
 * \return criterion whose boolean result is `factor < 0` and whose text is the
 *         formatted factor; the column header is " reductionFactor"
 *
 * \note All three arguments are captured by reference and therefore have to
 *       outlive the returned criterion.
 */
template<class IterationStepType, class NormType, class ReductionFactorContainer>
Dune::Solvers::Criterion reductionFactorCriterion(
      IterationStepType& iterationStep,
      const NormType& norm,
      ReductionFactorContainer& reductionFactors)
{
  using Vector = typename IterationStepType::Vector;

  return Dune::Solvers::Criterion(
      [&iterationStep, &norm, &reductionFactors,
       previousIterate = std::make_shared<Vector>(*iterationStep.getIterate()),
       previousCorrectionNorm = 1.0] () mutable {
        // size of the most recent correction
        const double correctionNorm = norm.diff(*previousIterate, *iterationStep.getIterate());

        // ratio of successive correction sizes; defined as 0 if the previous one vanished
        double rate = 0.0;
        if (previousCorrectionNorm > 0)
          rate = correctionNorm / previousCorrectionNorm;

        if (rate > 1.0)
            std::cout << "Solver convergence rate of " << rate << std::endl;

        // remember the current state for the next evaluation
        previousCorrectionNorm = correctionNorm;
        *previousIterate = *iterationStep.getIterate();

        reductionFactors.push_back(rate);
        return std::make_tuple(rate < 0, Dune::formatString(" % '.5f", rate));
      },
      " reductionFactor");
}


/**
 * \brief Build a solver criterion that records reduction factors based on a functional.
 *
 * On every evaluation the criterion computes the absolute change of \p f
 * between the iterate seen at the previous evaluation and the current iterate
 * of \p iterationStep. With `ratio` being this change divided by the change
 * measured last time (initially 1), the recorded value is `1 - ratio`
 * (or 0 if the previous change was exactly zero).
 *
 * \param iterationStep    iteration step whose iterate is observed (held by reference)
 * \param f                functional evaluated on iterates (held by reference)
 * \param reductionFactors container receiving one value per evaluation (held by reference)
 * \return criterion whose boolean result is `value < 0` and whose text is the
 *         formatted value; the column header is " reductionFactor"
 *
 * \note All three arguments are captured by reference and therefore have to
 *       outlive the returned criterion.
 * \note NOTE(review): since the change is an absolute value, `1 - ratio` cannot
 *       exceed 1, so the "> 1" diagnostic below looks unreachable, while the
 *       "< 0" result becomes true whenever the change grows - confirm intent.
 */
template<class IterationStepType, class Functional, class ReductionFactorContainer>
Dune::Solvers::Criterion energyCriterion(
      const IterationStepType& iterationStep,
      const Functional& f,
      ReductionFactorContainer& reductionFactors)
{
  using Vector = typename IterationStepType::Vector;

  return Dune::Solvers::Criterion(
      [&iterationStep, &f, &reductionFactors,
       previousIterate = std::make_shared<Vector>(*iterationStep.getIterate()),
       previousChange = 1.0] () mutable {
        // absolute change of the functional value since the last evaluation
        const double change = std::abs(f(*previousIterate) - f(*iterationStep.getIterate()));

        double rate = 0.0;
        if (previousChange != 0.0)
          rate = 1.0 - (change / previousChange);

        if (rate > 1.0)
            std::cout << "Solver convergence rate of " << rate << std::endl;

        // remember the current state for the next evaluation
        previousChange = change;
        *previousIterate = *iterationStep.getIterate();

        reductionFactors.push_back(rate);
        return std::make_tuple(rate < 0, Dune::formatString(" % '.5f", rate));
      },
      " reductionFactor");
}

/**
 * \brief Move the factors of one solve into the global `allReductionFactors` record.
 *
 * The i-th entry of \p reductionFactors is appended to `allReductionFactors[i]`;
 * `allReductionFactors` is enlarged first if it has fewer entries than
 * \p reductionFactors. Afterwards \p reductionFactors is emptied.
 *
 * \param reductionFactors factors to transfer; cleared on return
 */
template <class ReductionFactorContainer>
void updateReductionFactors(ReductionFactorContainer& reductionFactors) {
    const size_t count = reductionFactors.size();

    // make sure there is one slot per iteration index
    if (allReductionFactors.size() < count)
        allReductionFactors.resize(count);

    for (size_t idx = 0; idx < count; ++idx)
        allReductionFactors[idx].push_back(reductionFactors[idx]);

    reductionFactors.clear();
}


/**
 * \brief One coupled time step that can be computed either in the calling
 *        thread or in a worker thread.
 *
 * Typical use: construct, call run(), then call get() to obtain the result.
 *
 * The task started by run() captures `this`, so the object must stay at a
 * fixed address while a worker is active. The class is therefore non-copyable
 * (and, because it declares a destructor, non-movable); the destructor joins
 * a worker that is still running.
 */
template <class Factory, class ContactNetwork, class Updaters, class ErrorNorms>
class Step : protected StepBase<Factory, ContactNetwork, Updaters, ErrorNorms> {
public:
    /// Where run() executes the step.
    enum Mode { sameThread, newThread };

private:
    using Base = StepBase<Factory, ContactNetwork, Updaters, ErrorNorms>;
    using NBodyAssembler = typename ContactNetwork::NBodyAssembler;
    using UpdatersWithCount = UpdatersWithCount<Updaters>;

    const Updaters& oldUpdaters_;
    const NBodyAssembler& nBodyAssembler_;

    double relativeTime_;
    double relativeTau_;

    IterationRegister& iterationRegister_;

    std::packaged_task<UpdatersWithCount()> task_;
    std::future<UpdatersWithCount> future_;
    std::thread thread_;

    // Mode of the most recent run(); initialised so it never holds an
    // indeterminate value before the first run().
    Mode mode_ = sameThread;

public:
    /**
     * \brief Set up a step from the shared data in \p stepFactory.
     *
     * \param stepFactory       base object whose members are used to construct this step's base
     * \param oldUpdaters       updaters to start from; stored by reference and cloned in doStep()
     * \param nBodyAssembler    assembler; stored by reference
     * \param rTime             relative time passed to the coupled time stepper
     * \param rTau              relative step size passed to the coupled time stepper
     * \param iterationRegister receives the iteration count of the step; stored by reference
     */
    Step(const Base& stepFactory, const Updaters& oldUpdaters, const NBodyAssembler& nBodyAssembler, double rTime, double rTau, IterationRegister& iterationRegister) :
        Base(stepFactory.parset_, stepFactory.contactNetwork_, stepFactory.ignoreNodes_, stepFactory.globalFriction_,
             stepFactory.bodywiseNonmortarBoundaries_, stepFactory.externalForces_, stepFactory.errorNorms_),
        oldUpdaters_(oldUpdaters),
        nBodyAssembler_(nBodyAssembler),
        relativeTime_(rTime),
        relativeTau_(rTau),
        iterationRegister_(iterationRegister) {}

    // The running task refers to *this, so copies would be meaningless.
    Step(const Step&) = delete;
    Step& operator=(const Step&) = delete;

    /**
     * \brief Perform the time step synchronously.
     *
     * Creates a linear solver, clones the old updaters, advances them with a
     * coupled time stepper and registers the resulting iteration count in the
     * iteration register.
     *
     * \return the advanced updaters together with the iteration count
     */
    UpdatersWithCount doStep() {
        // make linear solver for linear correction in TNNMGStep
        using Vector = typename Factory::Vector;

        auto linearSolver = makeLinearSolver<ContactNetwork, Vector>(this->parset_, this->contactNetwork_);

        // zero vector sized like the assembler's obstacle flags, handed to the
        // iteration step via setProblem()
        Vector x;
        x.resize(nBodyAssembler_.totalHasObstacle_.size());
        x = 0;

        linearSolver->getIterationStep().setProblem(x);

        // Reduction-factor tracking is currently disabled; to re-enable it:
        //   std::vector<double> reductionFactors;
        //   linearSolver->addCriterion(reductionFactorCriterion(
        //       linearSolver->getIterationStep(), linearSolver->getErrorNorm(), reductionFactors));
        //   ... and after the step: updateReductionFactors(reductionFactors);

        UpdatersWithCount newUpdatersAndCount = {oldUpdaters_.clone(), {}};

        typename Base::MyCoupledTimeStepper coupledTimeStepper(
            this->finalTime_, this->parset_, nBodyAssembler_,
            this->ignoreNodes_, this->globalFriction_, this->bodywiseNonmortarBoundaries_,
            newUpdatersAndCount.updaters, this->errorNorms_, this->externalForces_);

        newUpdatersAndCount.count = coupledTimeStepper.step(linearSolver, relativeTime_, relativeTau_);
        iterationRegister_.registerCount(newUpdatersAndCount.count);

        return newUpdatersAndCount;
    }

    /**
     * \brief Start computing the step.
     *
     * \param mode `sameThread` runs doStep() before returning; `newThread`
     *             launches it in a worker thread and returns immediately.
     *
     * If a worker from an earlier run() is still joinable it is joined first.
     */
    void run(Mode mode = sameThread) {
        // Assigning to a joinable std::thread calls std::terminate, so finish
        // any previous worker before starting over.
        if (thread_.joinable()) {
            thread_.join();
        }

        mode_ = mode;
        task_ = std::packaged_task<UpdatersWithCount()>( [this]{ return doStep(); });
        future_ = task_.get_future();

        if (mode == Mode::sameThread) {
            task_();
        } else {
            thread_ = std::thread(std::move(task_));
        }
    }

    /**
     * \brief Wait for the step started by run() and return its result.
     *
     * Joins the worker thread if there is one. An exception thrown inside
     * doStep() is stored by the packaged task and rethrown here.
     * Must be preceded by run() and called at most once per run().
     */
    auto get() {
        if (thread_.joinable()) {
            thread_.join();
        }
        return future_.get();
    }

    /// Joins a worker thread whose result was never fetched via get().
    ~Step() {
        if (thread_.joinable()) {
            thread_.join();
            std::cout << "Destructor:: Thread joined " << std::endl;
        }
    }
};

/**
 * \brief Convenience factory that deduces the template arguments of Step.
 *
 * Step's constructor requires an IterationRegister in addition to the other
 * arguments, so it has to be supplied here and is forwarded unchanged.
 * The template arguments are spelled out explicitly instead of relying on
 * class template argument deduction.
 *
 * \param base              step base providing the shared problem data
 * \param oldUpdaters       updaters to start from
 * \param nBodyAssembler    assembler of the contact network
 * \param rTime             relative time of the step
 * \param rTau              relative step size
 * \param iterationRegister receives the iteration count of the step
 * \return the constructed Step (returned as a prvalue; Step is neither
 *         copyable nor movable)
 */
template <class Factory, class ContactNetwork, class Updaters, class ErrorNorms>
Step<Factory, ContactNetwork, Updaters, ErrorNorms>
buildStep(const StepBase<Factory, ContactNetwork, Updaters, ErrorNorms>& base,
          const Updaters& oldUpdaters,
          const typename ContactNetwork::NBodyAssembler& nBodyAssembler,
          double rTime, double rTau,
          IterationRegister& iterationRegister) {

    return Step<Factory, ContactNetwork, Updaters, ErrorNorms>(
        base, oldUpdaters, nBodyAssembler, rTime, rTau, iterationRegister);
}

#endif