#ifndef DUNE_TECTONIC_SPATIAL_SOLVING_SOLVENONLINEARPROBLEM_HH
#define DUNE_TECTONIC_SPATIAL_SOLVING_SOLVENONLINEARPROBLEM_HH

#include <dune/common/exceptions.hh>

#include <dune/matrix-vector/axpy.hh>

#include <dune/solvers/norms/energynorm.hh>
#include <dune/solvers/solvers/loopsolver.hh>

#include <dune/contact/assemblers/nbodyassembler.hh>
#include <dune/contact/common/dualbasisadapter.hh>

#include <dune/localfunctions/lagrange/pqkfactory.hh>

#include <dune/functions/gridfunctions/gridfunction.hh>

#include <dune/geometry/quadraturerules.hh>
#include <dune/geometry/type.hh>
#include <dune/geometry/referenceelements.hh>

#include <dune/fufem/functions/basisgridfunction.hh>

#include "../data-structures/enums.hh"
#include "../data-structures/enumparser.hh"



#include "../utils/tobool.hh"
#include "../utils/debugutils.hh"

#include <dune/solvers/solvers/loopsolver.hh>
#include <dune/solvers/iterationsteps/truncatedblockgsstep.hh>
#include <dune/solvers/iterationsteps/multigridstep.hh>

#include "../spatial-solving/contact/nbodycontacttransfer.hh"

#include "solverfactory.hh"
#include "solverfactory.cc"

#include <dune/tectonic/utils/reductionfactors.hh>

#include "fixedpointiterator.hh"


/**
 * \brief Drives a TNNMG step (built by SolverFactory) inside a LoopSolver.
 *
 * \tparam Functional  functional type providing the typedefs Matrix and Vector and
 *                     the accessors quadraticPart(), lowerObstacle() and upperObstacle()
 * \tparam BitVector   bit vector type passed to the solver factory to mark Dirichlet nodes
 */
template <class Functional, class BitVector>
class NonlinearSolver {
protected:
    using Matrix = typename Functional::Matrix;
    using Vector = typename Functional::Vector;
    using Factory = SolverFactory<Functional, BitVector>;

    using Norm = EnergyNorm<Matrix, Vector>;
    using SolverType = Dune::Solvers::LoopSolver<Vector>;

    // number of components per block of Vector
    const static int dims = Vector::block_type::dimension;

public:
    /**
     * \brief Creates the solver factory and builds its TNNMG step.
     *
     * \param tnnmgParset    parameters forwarded to the SolverFactory constructor
     * \param linearSolver   linear solver forwarded to SolverFactory::build()
     * \param functional     functional to minimize; its quadratic part defines the error norm
     * \param dirichletNodes Dirichlet node markers forwarded to the SolverFactory constructor
     */
    template <class LinearSolver>
    NonlinearSolver(const Dune::ParameterTree& tnnmgParset, std::shared_ptr<LinearSolver> linearSolver, std::shared_ptr<Functional> functional, const BitVector& dirichletNodes) :
        functional_(functional),
        norm_(functional_->quadraticPart()) {

        // set up the TNNMG step
        solverFactory_ = std::make_shared<Factory>(tnnmgParset, functional_, dirichletNodes);
        solverFactory_->build(linearSolver);
    }

    /**
     * \brief Projects x onto the admissible set, then iterates the TNNMG step on it.
     *
     * \param solverParset  reads "maximumIterations" (size_t), "tolerance" (double)
     *                      and "verbosity" (Solver::VerbosityMode)
     * \param x             initial iterate; clamped in place component-wise to
     *                      [lowerObstacle, upperObstacle], registered with the solver
     *                      factory via setProblem(), and then iterated by the solver
     * \return whatever LoopSolver::getResult() returns after solving
     */
    auto solve(const Dune::ParameterTree& solverParset, Vector& x) {
        auto tnnmgStep = solverFactory_->step();

        SolverType solver(*tnnmgStep.get(), solverParset.get<size_t>("maximumIterations"),
            solverParset.get<double>("tolerance"), // absolute error, measured with norm_
            norm_,
            solverParset.get<Solver::VerbosityMode>("verbosity"));

        const auto& lower = functional_->lowerObstacle();
        const auto& upper = functional_->upperObstacle();

        // Project x onto the admissible set [lower, upper], component by component.
        // The upper bound is applied after the lower one, so it takes precedence
        // should lower ever exceed upper.
        const auto blockSize = static_cast<size_t>(dims); // avoids signed/unsigned comparison
        for (size_t i = 0; i < x.size(); i++) {
            for (size_t j = 0; j < blockSize; j++) {
                if (x[i][j] < lower[i][j]) {
                    x[i][j] = lower[i][j];
                }

                if (x[i][j] > upper[i][j]) {
                    x[i][j] = upper[i][j];
                }
            }
        }

        solverFactory_->setProblem(x);

        solver.preprocess();
        solver.solve();

        return solver.getResult();
    }

private:
    // functional being minimized; declared before norm_ so it is initialized first
    std::shared_ptr<Functional> functional_;
    // factory owning the TNNMG step; created in the constructor body
    std::shared_ptr<Factory> solverFactory_;

    // error norm built from functional_->quadraticPart()
    Norm norm_;
};

#endif