from typing import Dict
from math import inf

import networkx as nx
from evrouting.T import Node, SoC, Time
from evrouting.utils import PriorityQueue
from evrouting.charge.factories import soc_profile as soc_profile_factory

from ..graph_tools import distance
from .T import SoCFunction, Label
from .utils import LabelPriorityQueue, ChargingFunctionMap


def shortest_path(G: nx.Graph, charging_stations: set, s: Node, t: Node,
                  initial_soc: SoC, final_soc: SoC, capacity: SoC):
    """
    Calculates shortest path using the CHarge algorithm.

    Labels are kept per node in an unsettled priority queue (``l_uns``) and
    moved to the settled set (``l_set``) in key order. After a label is
    settled it may spawn a "charged" label (at a faster charging station)
    and is then propagated along all outgoing arcs. The first label settled
    at ``t`` determines the result.

    :param G: Input Graph. A dummy node without incident edges is added to
        ``G`` for the duration of the call and removed again before the
        function returns (also when an exception is raised).
    :param charging_stations: Nodes of ``G`` at which charging is possible.
        The caller's collection is not modified.
    :param s: Start Node identifier
    :param t: End Node identifier
    :param initial_soc: Start SoC
    :param final_soc: End SoC. NOTE(review): not used by this
        implementation -- confirm whether a final-SoC constraint is needed.
    :param capacity: Capacity
    :return: The ``minimum`` (a Time) of the SoC function of the first label
        settled at ``t``, or ``None`` if ``t`` is never reached.
    """
    cf = ChargingFunctionMap(G=G, capacity=capacity, initial_soc=initial_soc)

    q = PriorityQueue()
    l_set: Dict[Node, set] = {v: set() for v in G}
    l_uns: Dict[Node, LabelPriorityQueue] = {
        v: LabelPriorityQueue() for v in G
    }

    # Dummy vertex without incident edges that is temporarily added to G.
    # It acts as the "last charging station" of the initial label. Probe for
    # an unused id so that an existing node (and its attributes) is never
    # overwritten, e.g. when node ids are not exactly 0..n-1.
    dummy_node: Node = len(G.nodes)
    while dummy_node in G:
        dummy_node += 1

    # Private copy: the caller's collection must not be mutated.
    stations = set(charging_stations)
    stations.add(dummy_node)

    # Charging coefficient 0 indicates dummy node
    G.add_node(dummy_node, c=0)
    try:
        l: Label = Label(
            t_trip=0,
            soc_last_cs=initial_soc,
            last_cs=dummy_node,
            soc_profile_cs_v=soc_profile_factory(G, capacity, s)
        )

        l_uns[s].insert(
            l,
            cf[l.last_cs]
        )

        q.insert(s, 0)

        # run main loop
        while True:
            try:
                minimum_node: Node = q.peak_min()
            except KeyError:
                # empty queue: t is not reachable
                break

            # Settle the best unsettled label of minimum_node. This label
            # (and only this one) is propagated further below.
            label_minimum_node: Label = l_uns[minimum_node].delete_min()
            l_set[minimum_node].add(label_minimum_node)

            if minimum_node == t:
                # The return value is computed before the finally-block
                # removes the dummy node, so cf[dummy_node] is still valid.
                return SoCFunction(
                    label_minimum_node,
                    cf[label_minimum_node.last_cs]
                ).minimum

            # handle charging stations
            if minimum_node in stations and \
                    minimum_node != label_minimum_node.last_cs:
                cf_last_cs = cf[label_minimum_node.last_cs]
                cf_minimum_node = cf[minimum_node]

                # NOTE(review): a charged label is only created when this
                # station charges faster than the last one; confirm this is
                # the intended CHarge variant.
                if cf_minimum_node.c > cf_last_cs.c:
                    # Only charge the minimum at the last charge station
                    # and continue charging at this station.
                    old_soc_function: SoCFunction = SoCFunction(
                        label_minimum_node, cf_last_cs
                    )
                    t_trip_old = label_minimum_node.t_trip
                    t_charge: Time = old_soc_function.minimum - t_trip_old

                    label_new = Label(
                        t_trip=t_trip_old + t_charge,
                        soc_last_cs=old_soc_function(t_trip_old + t_charge),
                        last_cs=minimum_node,
                        soc_profile_cs_v=soc_profile_factory(
                            G, capacity, minimum_node
                        )
                    )
                    l_uns[minimum_node].insert(
                        label_new,
                        cf_minimum_node
                    )

            # Update priority queue: minimum_node's key becomes the key of its
            # next unsettled label (possibly the charged one just inserted).
            # A separate variable is used so that the settled label above is
            # the one relaxed along the outgoing arcs.
            try:
                next_label: Label = l_uns[minimum_node].peak_min()
            except KeyError:
                # l_uns[v] empty
                q.delete_min()
            else:
                q.insert(minimum_node, next_label.key)

            # scan outgoing arcs with the label settled in this iteration
            for n in G.neighbors(minimum_node):
                # Create SoC Profile for getting from minimum_node to n
                soc_profile = label_minimum_node.soc_profile_cs_v + \
                              soc_profile_factory(G, capacity, minimum_node, n)
                # -inf at full capacity means n is unreachable even when
                # starting with a full battery at the last charging station.
                if soc_profile(capacity) != -inf:
                    # It is possible to get from minimum_node to n
                    l_new = Label(
                        label_minimum_node.t_trip
                        + distance(G, minimum_node, n),
                        label_minimum_node.soc_last_cs,
                        label_minimum_node.last_cs,
                        soc_profile
                    )
                    try:
                        l_uns[n].insert(
                            l_new,
                            cf[l_new.last_cs]
                        )
                    except ValueError:
                        # Infeasible because last_cs might be a
                        # dummy charging station. Therefore, the path might
                        # be infeasible even though one could reach it with a
                        # full battery, because charging is not possible at
                        # dummy stations.
                        #
                        # That means, the SoC and thereby the range is
                        # restricted to the SoC at the last cs (soc_last_cs).
                        pass
                    else:
                        if l_new == l_uns[n].peak_min():
                            q.insert(n, l_new.key)

        # Queue exhausted without settling a label at t.
        return None
    finally:
        # Undo the temporary modification of the caller's graph.
        G.remove_node(dummy_node)