#include <array>
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

#include <dune/common/fvector.hh>
#include <dune/common/exceptions.hh>
#include <dune/istl/bdmatrix.hh>
#include <dune/istl/matrixindexset.hh>

#include <dune/fufem/assemblers/transferoperatorassembler.hh>
#include <dune/matrix-vector/blockmatrixview.hh>

#include <dune/matrix-vector/addtodiagonal.hh>
#include <dune/matrix-vector/axpy.hh>
#include <dune/solvers/common/numproc.hh>
#include <dune/solvers/transferoperators/densemultigridtransfer.hh>

/**
 * \brief Assemble the transfer operator between two levels of a contact network.
 *
 * For every body a prolongation matrix from the body's grid level on the coarse
 * network level to its grid level on the fine network level is created. If both
 * levels coincide for a body, an identity block-diagonal matrix is used instead.
 *
 * If \p fineLevel is the last network level (nLevels()-1), the per-body matrices are
 * combined with the mortar coupling data of the network's n-body assembler via
 * combineSubMatrices(). Otherwise they are arranged block-diagonally into this->matrix_.
 *
 * \param contactNetwork  network providing bodies, levels and contact couplings
 * \param coarseLevel     index of the coarse network level
 * \param fineLevel       index of the fine network level
 */
template<class ContactNetwork, class VectorType>
void NBodyContactTransfer<ContactNetwork, VectorType>::setup(const ContactNetwork& contactNetwork, const size_t coarseLevel, const size_t fineLevel) {
    const size_t nBodies = contactNetwork.nBodies();

    const auto& coarseContactNetwork = *contactNetwork.level(coarseLevel);
    const auto& fineContactNetwork = *contactNetwork.level(fineLevel);

    // ////////////////////////////////////////////////////////////
    //   Create the standard prolongation matrices for each grid
    // ////////////////////////////////////////////////////////////

    // Owning storage. Objects are released on scope exit, also when an exception
    // propagates from the combination step below. The identity matrices are held
    // with their most-derived type so that the correct destructor runs
    // (the matrix base class has no virtual destructor).
    std::vector<std::unique_ptr<TruncatedDenseMGTransfer<VectorType> > > gridTransfer(nBodies);
    std::vector<std::unique_ptr<Dune::BDMatrix<MatrixBlock> > > identityMatrices(nBodies);

    // Non-owning views of the per-body matrices handed on to the combination routines
    std::vector<const MatrixType*> submat(nBodies);

    for (size_t i=0; i<nBodies; i++) {
        const size_t coarseBodyLevel = coarseContactNetwork.body(i)->level();
        const size_t fineBodyLevel = fineContactNetwork.body(i)->level();

        if (fineBodyLevel > coarseBodyLevel) {
            gridTransfer[i] = std::make_unique<TruncatedDenseMGTransfer<VectorType> >();
            gridTransfer[i]->setup(*contactNetwork.body(i)->grid(), coarseBodyLevel, fineBodyLevel);
            submat[i] = &gridTransfer[i]->getMatrix();
        } else {
            // gridTransfer is identity if coarse and fine level coincide for body
            identityMatrices[i] = std::make_unique<Dune::BDMatrix<MatrixBlock> >(coarseContactNetwork.body(i)->nVertices());
            auto& identity = *identityMatrices[i];

            for (size_t j=0; j<identity.N(); j++)
                for (int k=0; k<blocksize; k++)
                    for (int l=0; l<blocksize; l++)
                        identity[j][j][k][l] = (k==l);

            submat[i] = identityMatrices[i].get();
        }
    }

    // ///////////////////////////////////
    //   Combine the submatrices
    // ///////////////////////////////////

    if (fineLevel == contactNetwork.nLevels()-1) {
        // Finest level: add the mortar coupling entries and local coordinate systems
        const size_t nCouplings = contactNetwork.nCouplings();

        const auto& nBodyAssembler = contactNetwork.nBodyAssembler();
        const auto& contactCouplings = nBodyAssembler.getContactCouplings();

        std::vector<const MatrixType*> mortarTransferOperators(nCouplings);
        std::vector<const Dune::BitSetVector<1>*> fineHasObstacle(nCouplings);
        std::vector<std::array<int,2> > gridIdx(nCouplings);

        for (size_t i=0; i<nCouplings; i++) {
            mortarTransferOperators[i] = &contactCouplings[i]->mortarLagrangeMatrix();
            fineHasObstacle[i] = contactCouplings[i]->hasObstacle();
            gridIdx[i] = nBodyAssembler.getCoupling(i).gridIdx_;
        }

        combineSubMatrices(submat, mortarTransferOperators, nBodyAssembler.getLocalCoordSystems(), fineHasObstacle, gridIdx);
    } else {
        // Coarser levels: plain block-diagonal arrangement of the per-body matrices
        Dune::MatrixVector::BlockMatrixView<MatrixType>::setupBlockMatrix(submat, this->matrix_);
    }
}

/*
template<class VectorType>
template<class GridType>
void NBodyContactTransfer<VectorType>::setupHierarchy(std::vector<std::shared_ptr<NonSmoothNewtonContactTransfer<VectorType> > >& mgTransfers,
                                    const std::vector<const GridType*> grids,
                                    const std::vector<const MatrixType*> mortarTransferOperator,
                                    const CoordSystemVector& fineLocalCoordSystems,
                                    const std::vector<const BitSetVector<1>*>& fineHasObstacle,
                                    const std::vector<std::array<int,2> >& gridIdx)
{
    const size_t nGrids     = grids.size();

    std::vector<std::vector<TruncatedDenseMGTransfer<VectorType>* > > gridTransfer(nGrids);
    std::vector<const MatrixType*> submat(nGrids);

    // ////////////////////////////////////////////////////////////
    //   Create the standard prolongation matrices for each grid
    // ////////////////////////////////////////////////////////////

    for (size_t i=0; i<nGrids; i++) {

        gridTransfer[i].resize(grids[i]->maxLevel());
        for (size_t j=0; j<gridTransfer[i].size(); j++)
            gridTransfer[i][j] = new TruncatedDenseMGTransfer<VectorType>;

        // Assemble standard transfer operator
        TransferOperatorAssembler<GridType> transferOperatorAssembler(*grids[i]);
        transferOperatorAssembler.assembleOperatorPointerHierarchy(gridTransfer[i]);
    }

    // ////////////////////////////////////////////////////////////////////////////
    //      Combine matrices in one matrix and add mortar entries on the fine level
    // ///////////////////////////////////////////////////////////////////////////
    int toplevel = mgTransfers.size();

    for (size_t colevel=0; colevel<mgTransfers.size(); colevel++) {

        std::vector<int> fL(nGrids);

        for (size_t i=0; i<nGrids; i++) {
            fL[i] = std::max(size_t(0), grids[i]->maxLevel() - colevel);

            // If the prolongation matrix exists, take it
            if (fL[i] > 0) {

                submat[i] =&gridTransfer[i][fL[i]-1]->getMatrix();
            } else {
                // when the maxLevels of the grids differ then we add copys of the coarsest level to the "smaller" grid
                BDMatrix<MatrixBlock>* newMatrix = new BDMatrix<MatrixBlock>(grids[i]->size(0,GridType::dimension));
                *newMatrix = 0;

                for (size_t j=0; j<newMatrix->N(); j++)
                    for (int k=0; k<blocksize; k++)
                            (*newMatrix)[j][j][k][k] = 1.0;

                submat[i] = newMatrix;
            }
        }

        if (colevel == 0)
            mgTransfers[toplevel-colevel-1]->combineSubMatrices(submat, mortarTransferOperator, fineLocalCoordSystems, fineHasObstacle, gridIdx);
        else
          Dune::MatrixVector::BlockMatrixView<MatrixType>::setupBlockMatrix(submat, mgTransfers[toplevel-colevel-1]->matrix_);

        for (size_t i=0; i<nGrids; i++)
            if (fL[i]==0)
                delete(submat[i]);
    }

    // Delete separate transfer objects
    for (size_t i=0; i<nGrids; i++)
        for (size_t j=0; j<gridTransfer[i].size(); j++)
            delete(gridTransfer[i][j]);
}
*/

/**
 * \brief Combine per-body prolongation matrices with mortar coupling data into this->matrix_.
 *
 * The assembled operator consists of
 *  - the block-diagonal arrangement of the matrices in \p submat,
 *  - for each coupling i, additional entries  $ -D^{-1}M I_{mm} $: for the j-th row of
 *    mortarTransferOperator[i], the product of that row with the submatrix of body
 *    gridIdx[i][1] (the mortar grid) is subtracted from the row of body gridIdx[i][0]
 *    that belongs to the j-th vertex flagged in fineHasObstacle[i],
 *  - a final left-multiplication of the first fineLocalCoordSystems.size() rows by the
 *    corresponding local coordinate systems.
 *
 * \param submat                  per-body matrices (non-owning)
 * \param mortarTransferOperator  per-coupling mortar matrices; its column indices address
 *                                rows of submat[gridIdx[i][1]]
 * \param fineLocalCoordSystems   per-row coordinate systems applied from the left
 * \param fineHasObstacle         per-coupling vertex flags; must have as many entries as
 *                                submat[gridIdx[i][0]] has rows, and as many set flags as
 *                                mortarTransferOperator[i] has rows
 * \param gridIdx                 per-coupling pair of body indices (row side, mortar side)
 *
 * \throws Dune::Exception if the per-coupling inputs have inconsistent sizes
 */
template<class ContactNetwork, class VectorType>
void NBodyContactTransfer<ContactNetwork, VectorType>::combineSubMatrices(const std::vector<const MatrixType*>& submat,
                                                                      const std::vector<const MatrixType*>& mortarTransferOperator,
                                                                      const CoordSystemVector& fineLocalCoordSystems,
                                                                      const std::vector<const Dune::BitSetVector<1>*>& fineHasObstacle,
                                                                      const std::vector<std::array<int,2> >& gridIdx)
{
    const size_t nGrids     = submat.size();
    const size_t nCouplings = mortarTransferOperator.size();

    // All per-coupling inputs are indexed by the coupling number below
    if (fineHasObstacle.size() < nCouplings || gridIdx.size() < nCouplings)
        DUNE_THROW(Dune::Exception,
                   "fineHasObstacle and gridIdx must provide an entry for every mortar transfer operator!");

    // ///////////////////////////////////
    //   Combine the submatrices
    // ///////////////////////////////////

    Dune::MatrixVector::BlockMatrixView<MatrixType> view(submat);

    Dune::MatrixIndexSet totalIndexSet(view.nRows(), view.nCols());

    // import indices of canonical transfer operator
    for (size_t i=0; i<nGrids; i++)
        totalIndexSet.import(*submat[i], view.row(i, 0), view.col(i, 0));

    // ///////////////////////////////////////////////////////////////
    //   Add sparsity pattern of the additional entries  $ -D^{-1}M I_{mm} $
    // ///////////////////////////////////////////////////////////////

    // localToGlobal[i][j]: row index within submat[gridIdx[i][0]] of the j-th vertex
    // flagged in fineHasObstacle[i], i.e. the target of row j of mortarTransferOperator[i]
    std::vector<std::vector<size_t> > localToGlobal(nCouplings);

    for (size_t i=0; i<nCouplings; i++) {

        const auto& hasObstacle  = *fineHasObstacle[i];
        const auto& mortarMatrix = *mortarTransferOperator[i];
        const auto& mortarSubmat = *submat[gridIdx[i][1]];

        if (hasObstacle.size() != submat[gridIdx[i][0]]->N())
            DUNE_THROW(Dune::Exception,
                       "fineHasObstacle[" << i << "] doesn't have the proper length!");

        localToGlobal[i].reserve(mortarMatrix.N());
        for (size_t j=0; j<hasObstacle.size(); j++)
            if (hasObstacle[j][0])
                localToGlobal[i].push_back(j);

        // Validate before indexing: a mismatch would otherwise read/write out of bounds
        if (localToGlobal[i].size() != mortarMatrix.N())
            DUNE_THROW(Dune::Exception,
                       "Number of flagged vertices in fineHasObstacle[" << i
                       << "] doesn't match the number of rows of mortarTransferOperator[" << i << "]!");

        for (size_t j=0; j<mortarMatrix.N(); j++) {

            const auto targetRow = view.row(gridIdx[i][0], localToGlobal[i][j]);
            const auto& mortarRow = mortarMatrix[j];

            for (auto cTIt = mortarRow.begin(); cTIt != mortarRow.end(); ++cTIt) {

                const auto& prolongationRow = mortarSubmat[cTIt.index()];

                for (auto cMIt = prolongationRow.begin(); cMIt != prolongationRow.end(); ++cMIt)
                    totalIndexSet.add(targetRow, view.col(gridIdx[i][1], cMIt.index()));
            }
        }
    }

    totalIndexSet.exportIdx(this->matrix_);
    this->matrix_ = 0;

    // Copy matrices
    for (size_t i=0; i<nGrids; i++) {

        const auto& subMatrix = *submat[i];

        for (size_t rowIdx=0; rowIdx<subMatrix.N(); rowIdx++) {

            const auto& row = subMatrix[rowIdx];

            for (auto cIt = row.begin(); cIt != row.end(); ++cIt)
                this->matrix_[view.row(i, rowIdx)][view.col(i, cIt.index())] = *cIt;
        }
    }

    // ///////////////////////////////////////////////////////////////
    //   Add additional matrix entries  $ -D^{-1}M I_{mm} $
    // ///////////////////////////////////////////////////////////////

    for (size_t i=0; i<nCouplings; i++) {

        const auto& mortarMatrix = *mortarTransferOperator[i];
        const auto& mortarSubmat = *submat[gridIdx[i][1]];

        for (size_t j=0; j<mortarMatrix.N(); j++) {

            const auto targetRow = view.row(gridIdx[i][0], localToGlobal[i][j]);
            const auto& mortarRow = mortarMatrix[j];

            for (auto cTIt = mortarRow.begin(); cTIt != mortarRow.end(); ++cTIt) {

                const auto k = cTIt.index();
                const auto& prolongationRow = mortarSubmat[k];

                for (auto cMIt = prolongationRow.begin(); cMIt != prolongationRow.end(); ++cMIt) {

                    auto& currentMatrixBlock = this->matrix_[targetRow][view.col(gridIdx[i][1], cMIt.index())];

                    // the entry in the prolongation matrix of the mortar grid
                    const auto& subMatBlock = this->matrix_[view.row(gridIdx[i][1], k)][view.col(gridIdx[i][1], cMIt.index())];

                    // the - is due to the negation formula for BT
                    Dune::MatrixVector::subtractProduct(currentMatrixBlock, *cTIt, subMatBlock);
                }
            }
        }
    }

    // ///////////////////////////////////////////////////////////////
    // Transform matrix to account for the local coordinate systems
    // ///////////////////////////////////////////////////////////////

    /** \todo Hack.  Those should really be equal */
    assert(fineLocalCoordSystems.size() <= this->matrix_.N());

    for(size_t rowIdx=0; rowIdx<fineLocalCoordSystems.size(); rowIdx++) {

        auto& row = this->matrix_[rowIdx];

        for(auto& col : row)
            col.leftmultiply(fineLocalCoordSystems[rowIdx]);
    }
}