#include <cassert>
#include <cmath>

#include "sliplawstateupdater.hh"
#include "../../utils/tobool.hh"

template <class ScalarVector, class Vector>
SlipLawStateUpdater<ScalarVector, Vector>::SlipLawStateUpdater(
        const ScalarVector& alpha_initial,
        const BitVector& nodes,
        const double L,
        const double V0) :
    nodes_(nodes),
    L_(L),
    V0_(V0) {
    // Builds a compact local state vector for the nodes whose bit is set in
    // `nodes`.
    //   alpha_initial : initial state, indexed by global node number; only the
    //                   entries of selected nodes are read.
    //   nodes         : selection mask; its set bits define the local ordering.
    //   L, V0         : stored for solve(), where L divides tau*V and V0 is the
    //                   reference velocity in log(V / V0).

    // Same indexing convention that extractAlpha() asserts for its output.
    assert(alpha_initial.size() == nodes_.size());

    localToGlobal_.resize(nodes_.count());
    alpha_.resize(localToGlobal_.size());

    size_t localIdx = 0;
    for (size_t i=0; i<nodes_.size(); i++) {
        if (not toBool(nodes_[i]))
            continue;

        localToGlobal_[localIdx] = i;
        alpha_[localIdx] = alpha_initial[i];
        localIdx++;
    }

    // Seed the previous-step state. Without this, alpha_o_ stays empty, and
    // a solve() issued before the first nextTimeStep() would index out of
    // bounds.
    alpha_o_ = alpha_;
}

template <class ScalarVector, class Vector>
void SlipLawStateUpdater<ScalarVector, Vector>::nextTimeStep() {
    // Snapshot the current state. Subsequent solve() calls advance from this
    // snapshot (alpha_o_) rather than from the latest iterate.
    this->alpha_o_ = this->alpha_;
}

template <class ScalarVector, class Vector>
void SlipLawStateUpdater<ScalarVector, Vector>::setup(double tau) {
    // Remember the step size; solve() multiplies it with V / L in the decay
    // exponent.
    this->tau_ = tau;
}

template <class ScalarVector, class Vector>
void SlipLawStateUpdater<ScalarVector, Vector>::solve(
    const Vector& velocity_field) {
    // Advances every local state value from alpha_o_ over a step of size
    // tau_. V is held fixed during the step, and the update is the closed-form
    // solution of
    //     d(alpha)/dt = -(V / L) * (alpha + log(V / V0)):
    //     alpha = alpha_o * exp(-tau V / L) + expm1(-tau V / L) * log(V / V0).

    for (size_t local = 0; local < alpha_.size(); ++local) {
        auto slip = velocity_field[localToGlobal_[local]];
        // Drop component 0 before taking the norm (presumably the normal
        // direction, leaving the tangential part -- confirm against caller).
        slip[0] = 0.0;

        double const V = slip.two_norm();
        if (V <= 0) {
            // No slip: log(V / V0_) would be -inf, so keep the old state.
            alpha_[local] = alpha_o_[local];
            continue;
        }

        double const decay = -tau_ * V / L_;
        alpha_[local] = std::expm1(decay) * std::log(V / V0_) +
                        alpha_o_[local] * std::exp(decay);
    }
}

template <class ScalarVector, class Vector>
void SlipLawStateUpdater<ScalarVector, Vector>::extractAlpha(
        ScalarVector& alpha) {
    // Scatters the local state back into a globally indexed vector. Only the
    // entries of nodes selected at construction are written; all other
    // entries of `alpha` are left untouched.

    assert(alpha.size() == nodes_.size());

    size_t local = 0;
    for (auto const global : localToGlobal_) {
        alpha[global] = alpha_[local];
        ++local;
    }
}

template <class ScalarVector, class Vector>
auto SlipLawStateUpdater<ScalarVector, Vector>::clone() const
-> std::shared_ptr<LocalStateUpdater<ScalarVector, Vector>> {
    // Copy-constructs an independent updater (including its current and
    // previous state) and hands it out through the base-class pointer.
    // Inside the class scope, the injected-class-name names this
    // specialization.
    return std::make_shared<SlipLawStateUpdater>(*this);
}