#include "gradient_projection.h"

namespace GradientProjection {

// Debug helper: writes a GSL vector to stdout on a single line in the
// form "name = <v0, v1, ...>    (size)".
//
//   v    - vector to print; every element in [0, v->size) is read
//   name - label printed in front of the values
void display(const gsl_vector* v, const char* name) {
  std::cout << name << " = <";
  for(unsigned int idx = 0; idx < v->size; ++idx) {
    // The separator goes in front of every element except the first,
    // so nothing trails the last element.
    if (idx != 0) {
      std::cout << ", ";
    }
    std::cout << gsl_vector_get(v, idx);
  }
  std::cout << ">    (" << v->size << ")" << std::endl;
}

// Debug helper: writes a GSL matrix to stdout, one matrix row per
// output line with tab-separated entries between '|' characters,
// followed by a line giving the dimensions.
//
//   m    - matrix to print; all size1 x size2 entries are read
//   name - label printed in front of the first row
void display(const gsl_matrix* m, const char* name) {
  std::cout << name << "\t = |";
  for(unsigned int row = 0; row < m->size1; ++row) {
    // The first row shares its line with the label; later rows get a
    // tab-based prefix in its place.
    if(row > 0) {
      std::cout << "\t   |";
    }
    for(unsigned int col = 0; col < m->size2; ++col) {
      std::cout << gsl_matrix_get(m, row, col) << "\t";
    }
    std::cout << "|" << std::endl;
  }
  std::cout << "                                SIZE: " << m->size1 << " x " << m->size2 << std::endl;
}

// Builds the projection matrix, the projected descent direction and the
// constraint correction from the active-constraint matrix N.
//
//   activeConstraints - N, one column per active constraint
//                       (size1 rows, size2 columns)
//   g                 - vector multiplied by -N (N^T N)^{-1} to obtain
//                       the correction
//   grad              - gradient that gets projected
//   projection        - output: P = I - N (N^T N)^{-1} N^T
//   direction         - output: -P * grad (resized to size1 of N)
//   correction        - output: -N (N^T N)^{-1} * g (resized to size1 of N)
void createProjection(const gsl::matrix& activeConstraints,
                      const gsl::vector& g,
                      const gsl::vector& grad,
                      gsl::matrix& projection,
                      gsl::vector& direction,
                      gsl::vector& correction) {
  const int numRows = activeConstraints.size1();
  const int numActive = activeConstraints.size2();

  correction.resize(numRows);
  direction.resize(numRows);

  // A Cholesky or QR factorisation could replace the explicit inverse
  // below, but the original author could not get that to work.  This
  // runs infrequently and the matrices are not *that* big, so the
  // explicit inverse is acceptable.
  gsl::matrix scratch(numActive, numActive);

  // scratch = N^T N
  gsl_blas_dgemm(CblasTrans, CblasNoTrans, 1.0,
                 activeConstraints.gslobj(), activeConstraints.gslobj(),
                 0.0, scratch.gslobj());

  // gramInverse = (N^T N)^{-1}
  gsl::matrix gramInverse = scratch.LU_invert();

  // The scratch matrix is re-dimensioned and reused for the next
  // product:  scratch = -N (N^T N)^{-1}
  scratch.set_dimensions(numRows, numActive);
  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, -1.0,
                 activeConstraints.gslobj(), gramInverse.gslobj(),
                 0.0, scratch.gslobj());

  // correction = -N (N^T N)^{-1} g
  gsl_blas_dgemv(CblasNoTrans, 1.0, scratch.gslobj(), g.gslobj(), 0.0,
                 correction.gslobj());

  // P = I - N (N^T N)^{-1} N^T.  beta = 1.0 makes dgemm accumulate the
  // product on top of what identity() left in `projection`
  // (presumably the numRows x numRows identity -- confirm against the
  // gsl::matrix wrapper).
  projection.identity(numRows);
  gsl_blas_dgemm(CblasNoTrans, CblasTrans, 1.0,
                 scratch.gslobj(), activeConstraints.gslobj(),
                 1.0, projection.gslobj());

  // direction = -P * grad
  gsl_blas_dgemv(CblasNoTrans, -1.0, projection.gslobj(), grad.gslobj(), 0.0,
                 direction.gslobj());
}

// Finds the constraints that x is within SAFETY_BOX of (or beyond) and
// builds the matching constraint matrix and value vector.
//
// Two kinds of constraint are checked:
//   * the sum constraint, active when x.sum() >= 1 - SAFETY_BOX
//   * one bound per coordinate, active when x[i] <= SAFETY_BOX
//
//   x - current point
//   n - output: one column per active constraint.  The sum constraint,
//       when active, is column 0 and holds -1 in every row; each active
//       coordinate bound holds 1 in that coordinate's row.
//       NOTE(review): the remaining entries are never written here, so
//       this relies on set_dimensions() zero-filling -- confirm.
//   g - output: n^T x minus the per-constraint offset, where the offset
//       is -(1 - SAFETY_BOX) for the sum constraint and SAFETY_BOX for
//       a coordinate bound.
//
// Returns true when at least one constraint is active.  When none are,
// returns false and leaves n and g untouched.
bool createActiveConstraints(const gsl::vector& x,
                             gsl::matrix& n,
                             gsl::vector& g) {
  const int dimension = x.size();
  const double margin = SAFETY_BOX;

  const bool sumActive = (x.sum() >= 1.0 - margin);

  int boundsActive = 0;
  for(int idx = 0; idx < dimension; ++idx) {
    if (x[idx] <= margin) {
      ++boundsActive;
    }
  }

  const int activeCount = boundsActive + (sumActive ? 1 : 0);
  if(activeCount <= 0) {
    return false;
  }

  n.set_dimensions(dimension, activeCount);
  g.resize(activeCount);
  // Start every entry at the coordinate-bound offset; the sum
  // constraint's entry is overwritten just below.
  g.set_all(SAFETY_BOX);

  int column = 0;
  if(sumActive) {
    g[0] = -(1.0 - SAFETY_BOX);
    for(int idx = 0; idx < dimension; ++idx) {
      n(idx, 0) = -1.0;
    }
    ++column;
  }

  // One column per active coordinate bound, in coordinate order.
  for(int idx = 0; idx < dimension; ++idx) {
    if(x[idx] <= margin) {
      n(idx, column) = 1.0;
      ++column;
    }
  }
  assert(column == activeCount);

  // g <- n^T x - g
  gsl_blas_dgemv(CblasTrans, 1.0, n.gslobj(), x.gslobj(), -1.0, g.gslobj());

  //display(n.gslobj(), "N");
  //display(g.gslobj(), "g");

  return true;
}

// Takes one step from x along direction s, plus the given correction.
//
// The step length is alpha = -(gamma * obj_value) / (s . grad).
//
//   x          - point to update in place: x += alpha * s + correction
//   s          - direction; overwritten with the full move that was
//                applied (alpha * s + correction)
//   gamma      - multiplier on obj_value in the step length
//   obj_value  - objective value used in the step length
//   correction - vector added to the scaled direction
//   grad       - gradient, used only for the dot product with s
//
// Returns |alpha|.  If s . grad is exactly zero, x and s are left
// unchanged (the correction is not applied either) and that zero is
// returned.
double descend(gsl::vector& x,
               gsl::vector& s,
               const double gamma,
               const double obj_value,
               const gsl::vector& correction,
               const gsl::vector& grad) {
  double slope = 0.0;
  gsl_blas_ddot(s.gslobj(), grad.gslobj(), &slope);
  //std::cout << "dot prod= " << slope << " ";

  // Exact comparison: only guards the division below.
  if(slope == 0) {
    return slope;
  }

  const double stepLength = -(gamma * obj_value) / slope;
  //std::cout << " alpha= " << stepLength << " ";

  s *= stepLength;
  s += correction;
  x += s;

  //display(s.gslobj(), "final move");

  return (stepLength < 0) ? -stepLength : stepLength;
}

// Moves x one step, using a projected step when constraints are active
// and a plain scaled-gradient step otherwise.
//
//   x     - point to update in place
//   gamma - step-size parameter
//   grad  - gradient at x (taken by value, as in the declaration)
//   f     - objective value at x; forwarded to descend()
//
// Returns descend()'s result on the constrained path, or
// gamma * (step . step) on the unconstrained path.
double updateState(gsl::vector& x,
                   const double gamma,
                   const gsl::vector grad,
                   const double f) {
  const int dim = x.size();
  gsl::matrix activeNormals;
  gsl::vector constraintValues;
  gsl::vector step;
  gsl::vector correction(dim);

  // Are we up against any constraints?
  if(!createActiveConstraints(x, activeNormals, constraintValues)) {
    //std::cout << "No constraints violated." << std::endl;
    // No: step = -gamma * GRADIENT_DESCENT_SLOWDOWN * grad
    step.copy(grad);
    step *= -gamma * GRADIENT_DESCENT_SLOWDOWN;
    x += step;

    double stepNormSq;
    gsl_blas_ddot(step.gslobj(), step.gslobj(), &stepNormSq);
    return stepNormSq * gamma;
  }

  // Yes: build the projection and let descend() apply the step.
  //std::cout << "Constraints violated." << std::endl;
  step.resize(dim);
  gsl::matrix projection;

  createProjection(activeNormals, constraintValues, grad,
                   projection, step, correction);

  //display(projection.gslobj(), "p");
  //display(step.gslobj(), "s");
  //display(correction.gslobj(), "correction");
  return descend(x, step, gamma, f, correction, grad);
}


}