from typing import Dict
from math import inf

import networkx as nx
from evrouting.T import Node, SoC, Time
from evrouting.utils import PriorityQueue
from evrouting.charge import factories as factories

from ..graph_tools import distance
from .T import SoCFunction, Label
from .utils import LabelPriorityQueue


def shortest_path(G: nx.Graph, charging_stations: set, s: Node, t: Node,
                  initial_soc: SoC, final_soc: SoC, capacity: SoC):
    """
    Calculates shortest path using the CHarge algorithm.

    A dummy charging station (charging coefficient ``c=0``, no incident
    edges) is added to ``G`` and ``charging_stations`` for the duration of
    the search. It is removed again before returning, so the caller's graph
    and station set are left as they were.

    :param G: Input Graph
    :param charging_stations: Set of nodes of ``G`` that are charging
        stations (temporarily extended by the dummy station)
    :param s: Start Node identifier
    :param t: End Node identifier
    :param initial_soc: Start SoC
    :param final_soc: End SoC (NOTE(review): not used by the search yet)
    :param capacity: Capacity
    :return: ``SoCFunction(...).minimum`` of the first label settled at
        ``t`` (presumably the minimal feasible trip time -- TODO confirm),
        or ``None`` if the queue runs empty before ``t`` is settled.
    """
    q = PriorityQueue()
    l_set: Dict[int, set] = {v: set() for v in G}
    l_uns: Dict[int, LabelPriorityQueue] = {
        v: LabelPriorityQueue() for v in G
    }

    # Dummy vertex without incident edges that is temporarily added to G.
    # len(G.nodes) alone may collide with an existing node when node ids are
    # not exactly 0..n-1, which would overwrite that node's charging
    # coefficient, so advance to the first id not already in use.
    dummy_node: Node = len(G.nodes)
    while dummy_node in G or dummy_node in charging_stations:
        dummy_node += 1
    # Charging coefficient 0 indicates dummy node
    G.add_node(dummy_node, c=0)
    charging_stations.add(dummy_node)

    try:
        initial_label: Label = Label(
            t_trip=0,
            soc_last_cs=initial_soc,
            last_cs=dummy_node,
            soc_profile_cs_v=factories.soc_profile(G, capacity, s)
        )

        l_uns[s].insert(
            initial_label,
            factories.charging_function(
                G, initial_label.last_cs, capacity, initial_soc
            )
        )

        q.insert(s, 0)

        # run main loop
        while True:
            try:
                minimum_node: Node = q.peak_min()
            except KeyError:
                # empty queue: t was never settled
                break

            # Settle the best unsettled label of minimum_node. This label is
            # the one that must be propagated along the outgoing arcs below.
            label_minimum_node: Label = l_uns[minimum_node].delete_min()
            l_set[minimum_node].add(label_minimum_node)

            if minimum_node == t:
                return SoCFunction(
                    label_minimum_node,
                    factories.charging_function(
                        G,
                        label_minimum_node.last_cs,
                        capacity,
                        initial_soc
                    )
                ).minimum

            # handle charging stations
            if (minimum_node in charging_stations
                    and minimum_node != label_minimum_node.last_cs):
                cf_last_cs = factories.charging_function(
                    G,
                    label_minimum_node.last_cs,
                    capacity,
                    initial_soc  # Use here in case cs is a dummy station
                )
                cf_minimum_node = factories.charging_function(
                    G,
                    minimum_node,
                    capacity,
                    initial_soc  # Use here in case cs is a dummy station
                )

                if cf_minimum_node.c > cf_last_cs.c:
                    # Only charge the minimum at the last charge station
                    # and continue charging at this station.
                    old_soc_function: SoCFunction = SoCFunction(
                        label_minimum_node, cf_last_cs
                    )
                    t_trip_old = label_minimum_node.t_trip
                    t_charge: Time = old_soc_function.minimum - t_trip_old

                    label_new = Label(
                        t_trip=t_trip_old + t_charge,
                        soc_last_cs=old_soc_function(t_trip_old + t_charge),
                        last_cs=minimum_node,
                        soc_profile_cs_v=factories.soc_profile(
                            G, capacity, minimum_node
                        )
                    )
                    l_uns[minimum_node].insert(
                        label_new,
                        cf_minimum_node
                    )

            # update priority queue. The peeked label gets its own name so it
            # does not overwrite label_minimum_node (the settled label), which
            # is still needed for scanning the outgoing arcs.
            try:
                next_label: Label = l_uns[minimum_node].peak_min()
            except KeyError:
                # l_uns[v] empty
                q.delete_min()
            else:
                q.insert(minimum_node, key(next_label))

            # scan outgoing arcs of the settled label
            for n in G.neighbors(minimum_node):
                # Create SoC Profile for getting from minimum_node to n
                soc_profile = label_minimum_node.soc_profile_cs_v + \
                    factories.soc_profile(G, capacity, minimum_node, n)
                if soc_profile(capacity) != -inf:
                    # It is possible to get from minimum_node to n
                    l_new = Label(
                        label_minimum_node.t_trip + distance(G, minimum_node, n),
                        label_minimum_node.soc_last_cs,
                        label_minimum_node.last_cs,
                        soc_profile
                    )
                    try:
                        l_uns[n].insert(
                            l_new,
                            factories.charging_function(
                                G,
                                l_new.last_cs,
                                capacity,
                                initial_soc
                            )
                        )
                    except ValueError:
                        # NOTE(review): presumably insert() rejects l_new here
                        # (e.g. dominated label) -- TODO confirm; the label is
                        # deliberately dropped.
                        pass
                    else:
                        # Only touch q if l_new became the best label of n.
                        if l_new == l_uns[n].peak_min():
                            q.insert(n, key(l_new))

        return None
    finally:
        # Undo the temporary dummy station so callers' G and
        # charging_stations are unchanged.
        G.remove_node(dummy_node)
        charging_stations.discard(dummy_node)


def key(l: Label) -> Time:
    """
    Return the priority used for a label in the node priority queue.

    :param l: Label to rank
    :return: The label's trip time ``l.t_trip``
    """
    trip_time: Time = l.t_trip
    return trip_time