from typing import List
from math import inf

import networkx as nx
from evrouting.T import Node, SoC, Time, ChargingCoefficient
from evrouting.utils import PriorityQueue

from .T import SoCProfile, ChargingFunction, Label


def shortest_path(G: nx.Graph, S: set, s: Node, t: Node, beta_s: SoC, beta_t: SoC, U: SoC):
    """
    Calculates shortest path using the CHarge algorithm.

    Work in progress: spawning new labels at charging stations is not
    implemented yet. Settling a label at a charging station other than the
    label's own last station raises ``NotImplementedError``.

    A dummy node is added to ``G`` for the duration of the call and is always
    removed again before returning (also on error). ``S`` is not modified.

    :param G: Input Graph; each edge's ``'weight'`` is added to the trip time
    :param S: Set of charging station node identifiers
    :param s: Start Node identifier
    :param t: End Node identifier
    :param beta_s: Start SoC
    :param beta_t: End SoC (currently unused)
    :param U: Capacity
    :return: ``ChargingFunction(G, l).get_minimum()`` for the first label
        settled at ``t``, or ``inf`` if ``t`` is never settled.
    :raises networkx.NodeNotFound: if ``s`` or ``t`` is not a node of ``G``
    :raises NotImplementedError: if a label has to be expanded at a charging
        station (see above)
    """
    for node, role in ((s, 'start'), (t, 'end')):
        if node not in G:
            raise nx.NodeNotFound(f"{role} node {node} is not in G")

    q = PriorityQueue()
    l_set = {v: set() for v in G}
    l_uns = {v: PriorityQueue() for v in G}

    # Dummy vertex without incident edges that is temporarily added to G.
    # len(G.nodes) is only guaranteed to be free when nodes are numbered
    # 0..n-1, so probe upward until an unused identifier is found. This also
    # ensures the cleanup below never deletes a real node.
    candidate = len(G.nodes)
    while Node(candidate) in G:
        candidate += 1
    v_0: Node = Node(candidate)

    # The dummy is the "charging station" every initial label departs from;
    # its charging function is the constant beta_s. Use a local copy so the
    # caller's set of stations is left untouched.
    stations = set(S)
    stations.add(v_0)

    G.add_node(v_0)
    try:
        l = Label(0, beta_s, v_0, SoCProfile(G, U, s))
        l_uns[s].insert(item=l, priority=key(l))
        q.insert(s, key(l))

        # run main loop
        while True:
            try:
                v = q.peak_min()
            except KeyError:
                # empty queue: t was never settled
                return inf

            l = l_uns[v].delete_min()
            # NOTE(review): l_set is only written here, never read; dominance
            # pruning against settled labels is not implemented yet.
            l_set[v].add(l)

            if v == t:
                return ChargingFunction(G, l).get_minimum()

            t_trip, beta_u, u, b_u_v = l

            # handle charging stations
            if v in stations and v != u:
                # TODO: insert one new label into l_uns[v] per charging-time
                # breakpoint (t_breaks) at this station.
                raise NotImplementedError(
                    "CHarge: label generation at charging stations is not implemented"
                )

            # update priority queue: v is keyed by its best unsettled label,
            # or leaves q once it has none left.
            try:
                l_next = l_uns[v].peak_min()
            except KeyError:
                # Nothing was inserted into q since peak_min(), so v is still
                # its minimum.
                q.delete_min()
            else:
                q.insert(v, key(l_next))

            # Maximum SoC available when leaving the last station u. The dummy
            # only provides beta_s; real stations are assumed to charge up to
            # capacity U.
            beta_max_u = beta_s if u == v_0 else U

            # scan outgoing arcs (v, y)
            for y in G.neighbors(v):
                b_u_y = b_u_v + SoCProfile(G, U, v, y)
                if b_u_y(beta_max_u) == -inf:
                    # y is unreachable from u even with the maximum SoC at u
                    continue
                l_new = Label(t_trip + G.edges[v, y]['weight'], beta_u, u, b_u_y)
                l_uns[y].insert(item=l_new, priority=key(l_new))
                if l_new == l_uns[y].peak_min():
                    q.insert(y, key(l_new))
    finally:
        G.remove_node(v_0)


def key(l: Label) -> Time:
    """
    Priority of a label in the label priority queues.

    :param l: Label whose priority is requested
    :return: The label's trip time ``t_trip``
    """
    trip_time = l.t_trip
    return trip_time


def t_breaks(c_old: ChargingCoefficient, c_new: ChargingCoefficient) -> List[Time]:
    """
    Charging-time breakpoints for continuing from one charging station to another.

    Not implemented yet. Presumably this should return the charging times at
    which new labels are spawned at a station with coefficient ``c_new``, given
    that the previous station had ``c_old``. TODO: confirm against the CHarge
    algorithm.

    Raising here, rather than implicitly returning ``None``, makes callers fail
    with a clear message instead of a confusing ``TypeError`` when they iterate
    the result.

    :param c_old: Charging coefficient (presumably of the previous station)
    :param c_new: Charging coefficient (presumably of the current station)
    :return: List of charging times
    :raises NotImplementedError: always, until implemented
    """
    raise NotImplementedError("t_breaks is not implemented yet")