from typing import Dict, List, Set
from math import inf

import networkx as nx
from evrouting.T import Node, SoC, Time
from evrouting.utils import PriorityQueue
from evrouting.charge.factories import (
    LabelsFactory,
    ChargingFunctionMap,
    SoCFunctionMap,
    soc_profile_factory
)

from ..graph_tools import distance
from .T import SoCProfile, SoCFunction, Label
from .utils import LabelPriorityQueue

# Public API: only the routing entry point is exported; the remaining
# functions in this module are underscore-prefixed helpers.
__all__ = ['shortest_path']


def shortest_path(G: nx.Graph, charging_stations: set, s: Node, t: Node,
                  initial_soc: SoC, final_soc: SoC, capacity: SoC):
    """
    Calculates shortest path using the CHarge algorithm.

    Each node keeps its settled labels in ``l_set`` and its unsettled labels
    in ``l_uns`` (one ``LabelPriorityQueue`` per node). A node-level
    ``PriorityQueue`` picks the node whose minimum unsettled label is
    settled next. Its priority is ``f_soc[label].minimum`` of that label
    (see ``_key``).

    NOTE: the inputs are mutated and not restored afterwards.
    ``_apply_final_constraints`` adds a temporary target node (plus an edge
    from ``t``) to ``G``. ``_create_entry_label`` adds a dummy start station
    to ``G`` and to ``charging_stations``.

    :param G: Graph to route on; modified in place (see note above).
    :param charging_stations: Set of charging stations in G; modified in
        place (see note above).
    :param s: Start node.
    :param t: Target node.
    :param initial_soc: SoC at the beginning of the route. It is used as
        ``soc_last_cs`` of the entry label, whose ``last_cs`` is the dummy
        start station.
    :param final_soc: Final SoC constraint at ``t``. It is stored as
        attribute ``c`` of the zero-weight edge from ``t`` to the temporary
        target node.
    :param capacity: The restricting battery capacity.
    :return: ``f_soc[label].minimum`` of the first label settled at the
        temporary target node, or ``None`` if the node priority queue runs
        empty before that happens.
    """
    # Route to a temporary target behind t that carries the final SoC
    # constraint; from here on ``t`` refers to that temporary node.
    t = _apply_final_constraints(G, t, final_soc)

    # Charging functions per node, SoC functions per label, and the factory
    # that spawns labels for charging stops.
    cf = ChargingFunctionMap(G=G, capacity=capacity, initial_soc=initial_soc)
    f_soc = SoCFunctionMap(cf)
    label_factory = LabelsFactory(G, capacity, f_soc, initial_soc)

    # Init maps to manage labels:
    # l_set[v] holds the settled labels of v, l_uns[v] the unsettled ones.
    # Each l_uns[v] is constructed with l_set[v] (presumably to discard
    # labels dominated by settled ones -- TODO confirm in LabelPriorityQueue).
    # The dummy start station created below is added to G only after these
    # maps are built; it has no edges, so it never needs an entry.
    l_set: Dict[int, Set[Label]] = {v: set() for v in G}
    l_uns: Dict[int, LabelPriorityQueue] = {v: LabelPriorityQueue(cf, l_set[v]) for v in G}

    # Init environment: the entry label at s pretends the route starts at a
    # dummy charging station holding initial_soc.
    entry_label = _create_entry_label(G, charging_stations,
                                      s, initial_soc, capacity)
    l_uns[s].insert(entry_label)

    # A priority queue defines which node to visit next.
    # The key is the trip time.
    prio_queue = PriorityQueue()
    prio_queue.insert(s, priority=0, count=0)

    while True:
        try:
            minimum_node: Node = prio_queue.peak_min()
        except KeyError:
            # Empty queue: the target was not reached. The loop ends and the
            # function falls through, returning None.
            break

        # Settle the minimum unsettled label of the chosen node.
        label_minimum_node: Label = l_uns[minimum_node].delete_min()
        l_set[minimum_node].add(label_minimum_node)

        # The first label settled at the (temporary) target is the answer.
        if minimum_node == t:
            return f_soc[label_minimum_node].minimum

        # Handle charging stations: spawn labels for charging here, unless
        # this station is already the label's last charging station.
        if minimum_node in charging_stations and \
                not minimum_node == label_minimum_node.last_cs:
            for t_charge in _calc_optimal_t_charge(cf, label_minimum_node, minimum_node, capacity):
                label_new = label_factory.spawn_label(minimum_node,
                                                      label_minimum_node,
                                                      t_charge)
                l_uns[minimum_node].insert(label_new)

        # Update priority queue. This node might have gotten a new
        # minimum label spawned in the previous step, or might have no
        # unsettled label left at all.
        _update_priority_queue(f_soc, prio_queue, l_uns, minimum_node)

        # scan outgoing arcs
        for n in G.neighbors(minimum_node):
            # Create SoC Profile for getting to n: extend the label's profile
            # (soc_profile_cs_v) by the arc minimum_node -> n.
            soc_profile = label_minimum_node.soc_profile_cs_v + \
                          soc_profile_factory(G, capacity, minimum_node, n)

            # Skip arcs that cannot be traversed even with a full battery.
            if _is_feasible_path(soc_profile, capacity):
                # Same last charging station and SoC there; trip time and
                # profile are extended by the arc.
                l_new = Label(
                    t_trip=label_minimum_node.t_trip + distance(G, minimum_node, n),
                    soc_last_cs=label_minimum_node.soc_last_cs,
                    last_cs=label_minimum_node.last_cs,
                    soc_profile_cs_v=soc_profile
                )
                try:
                    l_uns[n].insert(l_new)
                except ValueError:
                    # Infeasible because last_cs might be a
                    # dummy charging station. Therefore, the path might
                    # be infeasible even though one could reach it with a full
                    # battery, because charging is not possible at dummy
                    # stations.
                    #
                    # That means, the SoC and thereby the range is restricted
                    # to the SoC at the last cs (soc_last_cs).
                    continue

                # Only (re-)key n in the node queue if the new label became
                # its minimum unsettled label. peak_min raises KeyError when
                # l_uns[n] is empty.
                try:
                    is_new_min_label: bool = l_new == l_uns[n].peak_min()
                except KeyError:
                    continue

                if is_new_min_label:
                    key, count = _key(l_new, f_soc)
                    prio_queue.insert(n, priority=key, count=count)


def _calc_optimal_t_charge(cf: ChargingFunctionMap, label_v: Label, v: Node, capacity: SoC) -> List[Time]:
    """
    Determine the charging time(s) at ``v`` for which a new label is spawned.

    Candidates are breakpoints of the SoC function of ``label_v``, built
    with the charging function of the label's last charging station. Each
    returned value is a breakpoint time minus the label's trip time.

    :param cf: Charging functions per node.
    :param label_v: Label that was just settled at ``v``.
    :param v: Charging station the label is at.
    :param capacity: Battery capacity.
    :return: An empty list, or a list with exactly one charging time.
    """
    breakpoints = SoCFunction(label_v, cf[label_v.last_cs]).breakpoints

    if cf[v] > cf[label_v.last_cs]:
        # v is the faster charging station: charge here as soon as
        # possible, i.e. at the earliest breakpoint.
        return [breakpoints[0].t - label_v.t_trip]

    last = breakpoints[-1]
    if last.soc < capacity:
        # v is slower, but the SoC reached at the last breakpoint is below
        # capacity (the trip costs are not refilled yet), so topping up
        # here may still pay off.
        return [last.t - label_v.t_trip]

    return []


def _key(label, f_soc):
    soc_function = f_soc[label]

    t_min = soc_function.minimum
    soc_min = soc_function(t_min)

    return t_min, soc_min


def _create_entry_label(
        G: nx.Graph,
        charging_stations: set,
        s: Node,
        initial_soc: SoC,
        capacity: SoC) -> Label:
    """
    Create dummy charging station with initial soc as constant charging
    function.

    The dummy station is added to ``G`` as a new node without edges, with
    charging coefficient ``c=0``. It is also added to ``charging_stations``;
    both arguments are mutated in place. Its id is the smallest integer
    ``>= len(G.nodes)`` that is not already a node of ``G``, so an existing
    node is never reused. A plain ``len(G.nodes)`` would collide whenever
    the node ids are not exactly ``0..n-1``.

    :param G: Graph
    :param charging_stations: Set of charging stations in Graph G
    :param s: Starting Node
    :param initial_soc: Initial SoC at the beginning of the route
    :param capacity: The restricting battery capacity
    :return: Label for the starting Node
    """
    # Start at len(G.nodes) (same id as before for 0..n-1 labelled graphs)
    # and skip ids that are already taken. Reusing an existing node would
    # overwrite its 'c' attribute and turn it into a charging station.
    dummy_node: Node = len(G.nodes)
    while dummy_node in G:
        dummy_node += 1

    # Charging coefficient 0 indicates dummy node
    G.add_node(dummy_node, c=0)
    charging_stations.add(dummy_node)

    # Register dummy charging station as the last
    # seen charging station before s.
    return Label(
        t_trip=0,
        soc_last_cs=initial_soc,
        last_cs=dummy_node,
        soc_profile_cs_v=soc_profile_factory(G, capacity, s)
    )


def _is_feasible_path(soc_profile: SoCProfile, capacity: SoC) -> bool:
    """
    Check whether the path can be traversed when starting with a full battery.

    :param soc_profile: SoC profile of the path; called with a starting SoC.
    :param capacity: Battery capacity, used as the starting SoC.
    :return: ``False`` exactly when the profile yields ``-inf`` for a full
        battery.
    """
    soc_with_full_battery = soc_profile(capacity)
    return soc_with_full_battery != -inf


def _update_priority_queue(
        f_soc: SoCFunctionMap,
        prio_queue: PriorityQueue,
        l_uns: Dict[int, LabelPriorityQueue],
        node: Node):
    """
    Re-key ``node`` in ``prio_queue`` according to its minimum unsettled label.

    If ``node`` has no unsettled label left (``peak_min`` raises
    ``KeyError``), the minimum entry of ``prio_queue`` is deleted instead.
    NOTE: this assumes ``node`` is the current minimum of ``prio_queue``.
    That holds at the call site in ``shortest_path``, where ``node`` is the
    result of ``prio_queue.peak_min()``.

    :param f_soc: SoC function map used to compute the key.
    :param prio_queue: Node-level priority queue.
    :param l_uns: Unsettled labels per node.
    :param node: Node whose key is updated.
    """
    try:
        best_label: Label = l_uns[node].peak_min()
    except KeyError:
        # No unsettled label left for this node: drop it from the queue.
        prio_queue.delete_min()
        return

    priority, secondary = _key(best_label, f_soc)
    prio_queue.insert(node, priority=priority, count=secondary)


def _apply_final_constraints(G: nx.Graph, t: Node, final_soc: SoC) -> Node:
    """
    Attach a temporary target node to ``t`` that carries the final SoC
    constraint.

    A new node is added to ``G``, together with an edge ``(t, new_node)``
    that has ``weight=0`` and ``c=final_soc``. The new node's only neighbour
    is ``t``, so every route to it ends with that edge. (Presumably ``c`` is
    the edge's consumption attribute, so reaching the new node requires
    ``final_soc`` to be left at ``t`` -- TODO confirm against the SoC
    profile code.)

    The new node id is the smallest integer ``>= len(G)`` that is not yet a
    node of ``G``. Using ``len(G)`` blindly would collide with an existing
    node whenever the node ids are not exactly ``0..n-1``. The constraint
    edge would then be wired to a real node, and the search would target
    that node.

    :param G: Graph; modified in place.
    :param t: Original target node.
    :param final_soc: Final SoC constraint at ``t``.
    :return: The temporary target node.
    """
    # Same id as before for 0..n-1 labelled graphs; otherwise skip taken ids.
    temp_final_node = len(G)
    while temp_final_node in G:
        temp_final_node += 1

    G.add_node(temp_final_node)
    G.add_edge(t, temp_final_node, weight=0, c=final_soc)

    return temp_final_node