from typing import Set, Callable, List

import networkx as nx
from evrouting.T import Node, SoC
from evrouting.graph_tools import CONSUMPTION_KEY, DISTANCE_KEY

# A route through the graph: the ordered list of nodes visited.
Path = List[Node]
# Signature of a path-finding function: (graph, source, target) -> Path,
# e.g. ``dijkstra`` below.
DistFunction = Callable[[nx.Graph, Node, Node], Path]


def shortest_path(G: nx.Graph, s, t, b_0: float, b_t: float, U: float):
    """
    Compute a shortest path with a generalized gas station algorithm.

    Not implemented yet: this is a placeholder that returns ``None`` for
    every input.

    :param G: Input graph
    :param s: Identifier of the start node
    :param t: Identifier of the end node
    :param b_0: State of charge at the start
    :param b_t: State of charge at the end
    :param U: Battery capacity
    :return: ``None`` (placeholder)
    """
    # TODO: implement the gas station algorithm; until then callers get None.
    return None


def dijkstra(G: nx.Graph, u: Node, v: Node, weight: str = 'weight') -> Path:
    """
    Return a minimum-weight node path from ``u`` to ``v`` in ``G``.

    Thin wrapper around networkx' ``shortest_path`` that ranks paths by the
    edge attribute named ``weight``. (Per the networkx documentation, edges
    lacking that attribute count as weight 1 — verify for the installed
    version.)

    :param G: Graph to search
    :param u: Source node
    :param v: Target node
    :param weight: Name of the edge attribute used as path cost
    :return: List of nodes from ``u`` to ``v`` inclusive
    """
    return nx.shortest_path(G, source=u, target=v, weight=weight)


def fold_path(G: nx.Graph, path: Path, weight: str):
    """
    Sum one edge attribute along a path.

    :param G: Graph that contains every consecutive edge of ``path``
    :param path: Sequence of nodes; each consecutive pair is an edge of ``G``
    :param weight: Name of the edge attribute to accumulate
    :return: Sum of ``G.edges[u, v][weight]`` over consecutive node pairs
        ``(u, v)``; ``0`` for a path with fewer than two nodes. A missing
        edge or attribute propagates the lookup error.
    """
    # zip stops at the shorter argument, so this yields (path[0], path[1]),
    # (path[1], path[2]), ... without slicing off the last node explicitly.
    hops = zip(path, path[1:])
    return sum(G.edges[a, b][weight] for a, b in hops)


def contract_graph(G: nx.Graph, charging_stations: Set[Node], capacity: SoC,
                   dist: DistFunction = dijkstra) -> nx.Graph:
    """
    Contract ``G`` to a graph whose nodes are the charging stations only.

    For every pair of stations, ``dist`` supplies a path in ``G``. If the
    summed consumption of that path is at most ``capacity``, the two
    stations are joined by an edge carrying the path's total consumption
    (``CONSUMPTION_KEY``) and total distance (``DISTANCE_KEY``).

    :param G: Original graph
    :param charging_stations: Charging stations; must be nodes of ``G``.
        Any iterable is accepted; it is materialised once.
    :param capacity: Maximum battery capacity
    :param dist: Minimum distance function ``(G, source, target) -> Path``,
        necessary if G is not a fully connected graph to calculate
        consumptions between charging stations. A pair for which it raises
        ``nx.NetworkXNoPath`` or returns an empty path is treated as
        unreachable and gets no edge.
    :returns: Graph on the charging stations (node attributes copied from
        ``G``) with an edge for every station pair whose path consumption is
        within ``capacity``. A station is also paired with itself, so
        self-loops can appear.
    """
    H: nx.Graph = nx.Graph()
    # Materialise once so both loops below see the same stations, even when
    # an iterator was passed.
    stations: List[Node] = list(charging_stations)

    for cs in stations:
        H.add_node(cs, **G.nodes[cs])
        # Skip pairs that already have an edge. Pairs rejected earlier are
        # evaluated again from this side, which only matters if ``dist`` is
        # not symmetric (e.g. a directed G or ties broken differently).
        for n_cs in [n for n in stations if (n, cs) not in H.edges]:
            # NOTE(review): the default ``dist`` ranks paths by the 'weight'
            # edge attribute, not by CONSUMPTION_KEY — confirm that this is
            # the intended "minimum path".
            try:
                min_path: Path = dist(G, cs, n_cs)
            except nx.NetworkXNoPath:
                # No path between the stations: leave them unconnected
                # instead of aborting the whole contraction.
                continue
            if not min_path:
                # An empty path means "unreachable"; folding it would give a
                # bogus zero-consumption edge.
                continue
            consumption: SoC = fold_path(G, min_path, weight=CONSUMPTION_KEY)
            if consumption <= capacity:
                H.add_edge(
                    cs, n_cs,
                    **{
                        CONSUMPTION_KEY: consumption,
                        DISTANCE_KEY: fold_path(
                            G, min_path, weight=DISTANCE_KEY),
                    }
                )
    return H