"""
Implementation of the CHArge algorithm [0] with two further constraints:

    1. There are no negative path costs (i.e. no recuperation).
    2. All charging stations have linear charging functions.

[0] https://dl.acm.org/doi/10.1145/2820783.2820826

"""
from typing import Dict, List, Tuple, Set, Union
from math import inf

import networkx as nx
from evrouting.T import Node, SoC, Time
from evrouting.utils import PriorityQueue
from evrouting.graph_tools import distance
from evrouting.charge.T import SoCFunction, Label
from evrouting.charge.utils import LabelPriorityQueue
from evrouting.charge.factories import (
    ChargingFunctionMap,
    SoCFunctionFactory,
    SoCProfileFactory
)


def shortest_path(G: nx.Graph, charging_stations: Set[Node], s: Node, t: Node,
                  initial_soc: SoC, final_soc: SoC, capacity: SoC) -> Dict:
    """
    Calculates the shortest path from s to t using the CHArge algorithm.

    Note that the setup step works on the passed objects in place: two
    helper nodes (a dummy final node attached to t and a dummy charging
    station) are added to ``G`` and the dummy charging station is added to
    ``charging_stations``.

    :param G: Graph to work on
    :param charging_stations: Set containing identifiers of all
        charging stations
    :param s: Start Node
    :param t: End Node
    :param initial_soc: SoC at s
    :param final_soc: SoC at t
    :param capacity: Battery capacity

    :return: Dict with the fields ``'trip_time'`` and ``'path'``, where
        ``'path'`` is a list of (node, charging time) tuples. If the
        priority queue runs empty before the final node gets settled,
        ``'trip_time'`` is ``None`` and ``'path'`` is an empty list.
    """
    # From here on t is the dummy final node created during setup and
    # not the target node passed in by the caller anymore.
    t, factories, queues = _setup(
        G, charging_stations, capacity, initial_soc, final_soc, s, t
    )

    f_soc_factory: SoCFunctionFactory = factories['f_soc']
    soc_profile_factory: SoCProfileFactory = factories['soc_profile']
    cf_map: ChargingFunctionMap = factories['cf']

    # Per node: list of settled labels and queue of unsettled labels.
    l_set: Dict[int, List[Label]] = queues['settled labels']
    l_uns: Dict[int, LabelPriorityQueue] = queues['unsettled labels']
    prio_queue: PriorityQueue = queues['priority queue']

    # Shortcut for key function
    keys = LabelPriorityQueue.keys

    # NOTE(review): relies on the truthiness of PriorityQueue, presumably
    # false once the queue is empty - confirm against its implementation.
    while prio_queue:
        # The node stays in the priority queue for now; its entry is
        # either refreshed or deleted further below.
        node_min: Node = prio_queue.peak_min()

        # Settle the minimum unsettled label of that node.
        label_node_min: Label = l_uns[node_min].delete_min()
        l_set[node_min].append(label_node_min)

        # A label of the (dummy) final node got settled, so the search
        # is done. The trip time is the minimum of its SoC function.
        if node_min == t:
            return _result(
                label_node_min, f_soc_factory(label_node_min).minimum
            )

        # Handle charging stations
        # Labels spawned below carry last_cs == node_min, so the second
        # condition keeps them from spawning yet another label here.
        if node_min in charging_stations and node_min != label_node_min.last_cs:
            f_soc: SoCFunction = f_soc_factory(label_node_min)
            t_charge = f_soc.calc_optimal_t_charge(cf_map[node_min])

            if t_charge is not None:
                # Spawn new label at t_charge
                # It stays at the same node (parent_node == node_min) and
                # makes node_min the last seen charging station.
                l_uns[node_min].insert(
                    Label(
                        t_trip=label_node_min.t_trip + t_charge,
                        soc_last_cs=f_soc(label_node_min.t_trip + t_charge),
                        last_cs=node_min,
                        soc_profile_cs_v=soc_profile_factory(node_min),
                        parent_node=node_min,
                        parent_label=label_node_min
                    )
                )

        # Update priority queue. This node might have gotten a new
        # minimum label spawned in the previous step.
        try:
            prio_queue.insert(
                item=node_min,
                **keys(f_soc_factory(l_uns[node_min].peak_min()))
            )
        except KeyError:
            # l_uns[node_min] is empty: no unsettled labels are left for
            # this node, so its entry is removed from the priority queue.
            prio_queue.delete_min()

        # scan outgoing arcs
        for n in G.neighbors(node_min):
            # Create SoC Profile for getting from the last charging
            # station via node_min to n
            soc_profile = label_node_min.soc_profile_cs_v + \
                          soc_profile_factory(node_min, n)

            # A profile value of -inf at full capacity is treated as
            # "n cannot be reached this way"; such arcs are skipped.
            if soc_profile(capacity) != -inf:
                if cf_map[label_node_min.last_cs].is_dummy \
                        and soc_profile.path_cost > label_node_min.soc_last_cs:
                    # Dummy charging stations cannot increase SoC.
                    # Therefore paths that consume more energy than the SoC
                    # when arriving at the (dummy) station are unfeasible.
                    continue

                # Same last charging station and SoC as the settled label,
                # extended by the travel time of the arc (node_min, n).
                label_neighbour: Label = Label(
                    t_trip=label_node_min.t_trip + distance(G, node_min, n),
                    soc_last_cs=label_node_min.soc_last_cs,
                    last_cs=label_node_min.last_cs,
                    soc_profile_cs_v=soc_profile,
                    parent_node=node_min,
                    parent_label=label_node_min
                )
                l_uns[n].insert(label_neighbour)

                # Update queue if entered label is the new minimum label
                # of the neighbour.
                try:
                    is_new_min: bool = label_neighbour == l_uns[n].peak_min()
                except KeyError:
                    # No minimum label available for n (presumably the
                    # insert above did not keep the label), nothing to do.
                    continue

                if is_new_min:
                    prio_queue.insert(n, **keys(f_soc_factory(label_neighbour)))

    # Queue exhausted without settling the final node: empty result.
    return _result()


def _unused_node_id(G: nx.Graph) -> Node:
    """
    Returns an integer that is not yet used as a node identifier in G.

    For graphs whose nodes are numbered 0..len(G)-1 this is exactly len(G).
    For any other numbering the first free integer >= len(G) is returned,
    so that an already existing node never gets reused by accident.
    """
    candidate = len(G)
    while candidate in G:
        candidate += 1
    return candidate


def _setup(G: nx.Graph, charging_stations: Set[Node], capacity: SoC,
           initial_soc: SoC, final_soc: SoC, s: Node, t: Node
           ) -> Tuple[Node, Dict, Dict]:
    """
    Initialises the data structures and graph setup.

    Side effects: two nodes are added to ``G`` (a dummy final node that is
    connected to ``t`` and a dummy charging station without any edges) and
    the dummy charging station is added to ``charging_stations``.

    :param G: Graph to work on, gets extended in place.
    :param charging_stations: Set of charging station identifiers, gets
        extended in place by the dummy charging station.
    :param capacity: Battery capacity
    :param initial_soc: SoC at s
    :param final_soc: SoC that should be left at t
    :param s: Start Node
    :param t: End Node

    :returns: Tuple(t, factories, queues):
        :t: The new dummy final node taking care of the final SoC.
        :factories: A dict containing factory functions for:
            :```factories['f_soc']```: The SoC Functions
            :```factories['cf']```: The Charging Functions
            :```factories['soc_profile']```: The SoC Profiles
        :queues: A dict containing initialized queues for the algorithm.
            :```queues['settled labels']```: Per node a list of labels.
            :```queues['unsettled labels']```: Per node a label queue.
            :```queues['priority queue']```: Queue of nodes to visit.
    """
    # Add node that is only connected to the final node and takes no time
    # to travel but consumes exactly the amount of energy that should be
    # left at t (final_soc). The node becomes the new final node.
    # The identifier must be unused: taking len(G) blindly would hit an
    # existing node whenever the nodes are not numbered 0..len(G)-1.
    dummy_final_node: Node = _unused_node_id(G)
    G.add_node(dummy_final_node)
    G.add_edge(t, dummy_final_node, weight=0, c=final_soc)
    t = dummy_final_node

    # Init factories
    cf_map = ChargingFunctionMap(G=G, capacity=capacity, initial_soc=initial_soc)
    f_soc_factory = SoCFunctionFactory(cf_map)
    soc_profile_factory = SoCProfileFactory(G, capacity)

    # Init maps to manage labels. They cover all nodes present at this
    # point, ie including the dummy final node but not the dummy charging
    # station added below (which has no edges).
    l_set: Dict[int, List[Label]] = {v: [] for v in G}
    l_uns: Dict[int, LabelPriorityQueue] = {
        v: LabelPriorityQueue(f_soc_factory, l_set[v]) for v in G
    }

    # Add dummy charging station with charging function
    # cf(t) = initial_soc (ie charging coefficient is zero).
    dummy_node: Node = _unused_node_id(G)
    G.add_node(dummy_node, c=0)
    charging_stations.add(dummy_node)

    # Register dummy charging station as the last
    # seen charging station before s.
    l_uns[s].insert(Label(
        t_trip=0,
        soc_last_cs=initial_soc,
        last_cs=dummy_node,
        soc_profile_cs_v=soc_profile_factory(s),
        parent_node=None,
        parent_label=None
    ))

    # A priority queue defines which node to visit next.
    # The key is the trip time.
    prio_queue: PriorityQueue = PriorityQueue()
    prio_queue.insert(s, priority=0, count=0)

    return (t,  # New final Node
            {  # factories
                'f_soc': f_soc_factory,
                'cf': cf_map,
                'soc_profile': soc_profile_factory
            },
            {  # queues
                'settled labels': l_set,
                'unsettled labels': l_uns,
                'priority queue': prio_queue
            }
            )


def _result(label: Label = None, f_soc_min: Time = None) -> Dict:
    """
    Builds the result dict by walking the label chain back to the start.

    :param label: The final label of the algorithm, ie the label settled
        at the inserted dummy final node.
    :param f_soc_min: The min time of the SoC Function of the final label

    :return Time result['trip_time']: The overall trip time ```f_soc_min```
        or ```None``` if one of the arguments is missing.
    :return List[Tuple[Node, Time]] result['path']: List of Nodes and their
        according charging time along the path, ordered from start to
        goal. Empty if one of the arguments is missing.
    """
    if label is None or f_soc_min is None:
        return {'trip_time': None, 'path': []}

    # Charging time per charging station. The time still missing between
    # the trip time of the final label and the overall trip time is spent
    # at the last charging station before the goal.
    charge_times = {label.last_cs: f_soc_min - label.t_trip}

    # The final label belongs to the inserted extra node, which is not
    # part of the path. Start with its predecessor.
    current_node, current = label.parent_node, label.parent_label

    visited = []  # collected from goal to start
    while current is not None:
        predecessor = current.parent_label
        if current_node == current.parent_node:
            # The label did not move: it got spawned when the charging
            # time at the last charging station of its parent was fixed.
            # The difference in trip time is that charging time.
            charge_times[predecessor.last_cs] = \
                current.t_trip - predecessor.t_trip
        else:
            visited.append((current_node, charge_times.get(current_node, 0)))
        current_node, current = current.parent_node, predecessor

    visited.reverse()
    return {'trip_time': f_soc_min, 'path': visited}