"""
Implementation of the CHArge algorithm [0] with two further constraints:

    1. There are no negative path costs (i.e. no recuperation).
    2. All charging stations have linear charging functions.

[0] https://dl.acm.org/doi/10.1145/2820783.2820826

"""
from typing import Dict, List, Tuple, Set
from math import inf

import networkx as nx
from evrouting.T import Node, SoC, Time, Result, EmptyResult, ConsumptionFunction
from evrouting.utils import PriorityQueue
from evrouting.graph_tools import distance, consumption, DISTANCE_KEY, CONSUMPTION_KEY
from evrouting.charge.T import SoCFunction, Label
from evrouting.charge.utils import LabelPriorityQueue
from evrouting.charge.factories import (
    ChargingFunctionMap,
    SoCFunctionFactory,
    SoCProfileFactory
)


def shortest_path(G: nx.DiGraph, charging_stations: Set[Node], s: Node, t: Node,
                  initial_soc: SoC, final_soc: SoC, capacity: SoC, c=consumption) -> Result:
    """
    Calculates shortest path using the CHArge algorithm.

    For the duration of the search ``G`` is extended by a dummy final node
    (modelling ``final_soc``) and ``charging_stations`` by a dummy charging
    station (the last charging station "seen" before ``s``). Both are removed
    again before this function returns, also when the search raises.

    :param G: Graph to work on
    :param charging_stations: Set containing identifiers of all
        charging stations
    :param s: Start Node
    :param t: End Node
    :param initial_soc: SoC at s
    :param final_soc: SoC at t
    :param capacity: Battery capacity
    :param c: Consumption function handed to the SoC profile factory
        (defaults to ``consumption``)

    :return: ``Result`` holding the trip time and the path with charging
        times per node, or ``EmptyResult`` if no feasible path was found.
    """
    t, dummy_cs, factories, queues = _setup(
        G=G,
        charging_stations=charging_stations,
        capacity=capacity,
        initial_soc=initial_soc,
        final_soc=final_soc,
        s=s,
        t=t,
        c=c
    )

    try:
        return _run_search(G, charging_stations, t, capacity, factories, queues)
    finally:
        # Always revert the temporary modifications of the caller's graph
        # and charging station set, even if the search fails with an error.
        _cleanup(G, t, charging_stations, dummy_cs)


def _run_search(G: nx.DiGraph, charging_stations: Set[Node], t: Node,
                capacity: SoC, factories: Dict, queues: Dict) -> Result:
    """
    Runs the CHArge label-setting loop on the graph prepared by ``_setup``.

    Does not revert the modifications made by ``_setup``; the caller is
    responsible for calling ``_cleanup``.

    :param G: Graph as extended by ``_setup``
    :param charging_stations: Charging stations including the dummy station
    :param t: The dummy final node returned by ``_setup``
    :param capacity: Battery capacity
    :param factories: Factory dict returned by ``_setup``
    :param queues: Queue dict returned by ``_setup``

    :return: ``Result`` built from the first label settled at ``t``, or
        ``EmptyResult`` if the priority queue runs empty before that.
    """
    f_soc_factory: SoCFunctionFactory = factories['f_soc']
    soc_profile_factory: SoCProfileFactory = factories['soc_profile']
    cf_map: ChargingFunctionMap = factories['cf']

    l_set: Dict[int, List[Label]] = queues['settled labels']
    l_uns: Dict[int, LabelPriorityQueue] = queues['unsettled labels']
    prio_queue: PriorityQueue = queues['priority queue']

    # Shortcut for key function
    keys = LabelPriorityQueue.keys

    while prio_queue:
        node_min: Node = prio_queue.peak_min()

        label_node_min: Label = l_uns[node_min].delete_min()
        l_set[node_min].append(label_node_min)

        if node_min == t:
            return _result(
                label_node_min, f_soc_factory(label_node_min).minimum
            )

        # Handle charging stations
        if node_min in charging_stations and node_min != label_node_min.last_cs:
            f_soc: SoCFunction = f_soc_factory(label_node_min)
            t_charge = f_soc.calc_optimal_t_charge(cf_map[node_min])

            if t_charge is not None:
                # Spawn new label at t_charge
                l_uns[node_min].insert(
                    Label(
                        t_trip=label_node_min.t_trip + t_charge,
                        soc_last_cs=f_soc(label_node_min.t_trip + t_charge),
                        last_cs=node_min,
                        soc_profile_cs_v=soc_profile_factory(node_min),
                        parent_node=node_min,
                        parent_label=label_node_min
                    )
                )

        # Update priority queue. This node might have gotten a new
        # minimum label spawned in the previous step.
        try:
            prio_queue.insert(
                item=node_min,
                **keys(f_soc_factory(l_uns[node_min].peak_min()))
            )
        except KeyError:
            # l_uns[v] empty
            prio_queue.delete_min()

        # scan outgoing arcs
        for n in G.neighbors(node_min):
            # Create SoC Profile for getting from minimum_node to n
            soc_profile = label_node_min.soc_profile_cs_v + \
                          soc_profile_factory(node_min, n)

            if soc_profile(capacity) != -inf:
                if cf_map[label_node_min.last_cs].is_dummy \
                        and soc_profile.path_cost > label_node_min.soc_last_cs:
                    # Dummy charging stations cannot increase SoC.
                    # Therefore paths that consume more energy than the SoC
                    # when arriving at the (dummy) station are unfeasible.
                    continue

                label_neighbour: Label = Label(
                    t_trip=label_node_min.t_trip + distance(G, node_min, n),
                    soc_last_cs=label_node_min.soc_last_cs,
                    last_cs=label_node_min.last_cs,
                    soc_profile_cs_v=soc_profile,
                    parent_node=node_min,
                    parent_label=label_node_min
                )
                l_uns[n].insert(label_neighbour)

                # Update queue if entered label is the new minimum label
                # of the neighbour.
                try:
                    is_new_min: bool = label_neighbour == l_uns[n].peak_min()
                except KeyError:
                    # The neighbour's unsettled queue is empty after the insert.
                    continue

                if is_new_min:
                    prio_queue.insert(n, **keys(f_soc_factory(label_neighbour)))

    return EmptyResult()


def _cleanup(G, t, charging_stations, dummy_cs):
    G.remove_node(t)
    G.remove_node(dummy_cs)
    charging_stations.remove(dummy_cs)


def _unused_node_id(G) -> int:
    """
    Returns the smallest integer ``>= len(G)`` that is not a node of ``G``.

    For graphs whose nodes are exactly ``0 .. len(G) - 1`` this is simply
    ``len(G)``. For any other labelling, ``len(G)`` may already be a node,
    and reusing it would attach the dummy edges to (and later delete) a real
    node.
    """
    candidate = len(G)
    while candidate in G:
        candidate += 1
    return candidate


def _setup(G: nx.Graph, charging_stations: Set[Node], capacity: SoC,
           initial_soc: SoC, final_soc: SoC, s: Node, t: Node,
           c: ConsumptionFunction
           ) -> Tuple[Node, Node, Dict, Dict]:
    """
    Initialises the data structures and graph setup.

    Mutates ``G`` (adds a dummy final node and a dummy charging station) and
    ``charging_stations`` (adds the dummy station); ``_cleanup`` reverts this.

    :returns: Tuple(t, dummy_cs, factories, queues):
        :t: The new dummy final node taking care of the final SoC.
        :dummy_cs: The dummy charging station registered as the last
            charging station before s.
        :factories: A dict containing factory functions for:
            :```factories['f_soc']```: The SoC Functions
            :```factories['cf']```: The Charging Functions
            :```factories['soc_profile']```: The SoC Profiles
        :queues: A dict containing initialized queues for the algorithm.
            :```queues['settled labels']```:
            :```queues['unsettled labels']```:
            :```queues['priority queue']```:
    """
    # Add node that is only connected to the final node and takes no time
    # to travel but consumes exactly the amount of energy that should be
    # left at t (final_soc). The node becomes the new final node.
    # A fresh id is required so that no existing node gets reused.
    dummy_final_node: Node = _unused_node_id(G)
    G.add_node(dummy_final_node)
    G.add_edge(t, dummy_final_node, **{
        DISTANCE_KEY: 0,
        CONSUMPTION_KEY: final_soc
    })
    t = dummy_final_node

    # Init factories
    cf_map = ChargingFunctionMap(G=G, capacity=capacity, initial_soc=initial_soc)
    f_soc_factory = SoCFunctionFactory(cf_map)
    soc_profile_factory = SoCProfileFactory(G, capacity, c)

    # Init maps to manage labels. The dummy charging station added below
    # gets no incident edges, so it never needs label containers.
    l_set: Dict[int, List[Label]] = {v: [] for v in G}
    l_uns: Dict[int, LabelPriorityQueue] = {
        v: LabelPriorityQueue(f_soc_factory, l_set[v]) for v in G
    }

    # Add dummy charging station with charging function
    # cf(t) = initial_soc (ie charging coefficient is zero).
    dummy_node: Node = _unused_node_id(G)
    G.add_node(dummy_node, c=0)
    charging_stations.add(dummy_node)

    # Register dummy charging station as the last
    # seen charging station before s.
    l_uns[s].insert(Label(
        t_trip=0,
        soc_last_cs=initial_soc,
        last_cs=dummy_node,
        soc_profile_cs_v=soc_profile_factory(s),
        parent_node=None,
        parent_label=None
    ))

    # A priority queue defines which node to visit next.
    # The key is the trip time.
    prio_queue: PriorityQueue = PriorityQueue()
    prio_queue.insert(s, priority=0, count=0)

    return (t,  # New final Node
            dummy_node,
            {  # factories
                'f_soc': f_soc_factory,
                'cf': cf_map,
                'soc_profile': soc_profile_factory
            },
            {  # queues
                'settled labels': l_set,
                'unsettled labels': l_uns,
                'priority queue': prio_queue
            }
            )


def _result(label: Label, f_soc_min: Time) -> Result:
    """
    Builds the ``Result`` of a successful search from its final label.

    The label chain is walked backwards via ``parent_label``. A label whose
    ``parent_node`` equals the node it is attached to was spawned by fixing
    a charging time. Such a label contributes that charging time (for its
    parent's ``last_cs``) instead of a path entry.

    :param label: The final label of the algorithm
    :param f_soc_min: The min time of the SoC Function of the final label

    :return: ``Result`` whose ``trip_time`` is ``f_soc_min`` and whose
        ``charge_path`` is a list of ``(node, charging_time)`` tuples in
        travel order. Nodes without a recorded charging time get ``0``.
    """
    # The charging time at the final charging stop is the gap between the
    # minimum of the final SoC function and the final label's trip time.
    charge_times = {label.last_cs: f_soc_min - label.t_trip}

    # Skip the extra final node inserted by the setup.
    current_node = label.parent_node
    current_label = label.parent_label

    backwards_path = []
    while current_label is not None:
        parent_label = current_label.parent_label
        if current_node != current_label.parent_node:
            backwards_path.append((current_node, charge_times.get(current_node, 0)))
        else:
            # Spawned by fixing the charging time of the parent label's
            # last_cs; the elapsed trip time between both labels is that
            # charging time.
            charge_times[parent_label.last_cs] = current_label.t_trip - parent_label.t_trip
        current_node, current_label = current_label.parent_node, parent_label

    backwards_path.reverse()
    return Result(trip_time=f_soc_min, charge_path=backwards_path)