from typing import Set, List

import networkx as nx
from evrouting.T import Node, SoC, Result, EmptyResult, Time
from evrouting.gasstation.T import State
from evrouting.graph_tools import (
    CONSUMPTION_KEY,
    DISTANCE_KEY,
    AccessFunctions
)


def insert_start_node(s: Node,
                      graph_core: nx.Graph,
                      graph_contracted: nx.Graph,
                      gas_stations: Set[Node],
                      graph_extended: nx.DiGraph,
                      capacity: SoC,
                      initial_soc: SoC,
                      f: AccessFunctions = AccessFunctions()
                      ) -> nx.DiGraph:
    """Insert the start state into the extended graph and create states and edges as necessary.

    Adds the state ``(s, initial_soc)``. For every gas station ``v`` whose
    shortest core-graph path from ``s`` consumes ``w <= initial_soc``, an edge
    to the arrival state ``(v, initial_soc - w)`` weighted with the path
    distance is added. From that arrival state, edges to the contracted-graph
    neighbours ``u`` of ``v`` are added: if ``u`` is more expensive than ``v``
    the battery is filled up completely at ``v``; otherwise just enough is
    charged at ``v`` to reach ``u`` (possibly nothing), arriving empty.

    :param s: Start node.
    :param graph_core: Original graph, used for the paths from ``s``.
    :param graph_contracted: Graph between charging stations
        (as built by ``contract_graph``).
    :param gas_stations: Charging stations to try to reach from ``s``.
    :param graph_extended: State graph; modified in place.
    :param capacity: Battery capacity.
    :param initial_soc: SoC at ``s``.
    :param f: Accessor functions for shortest paths and graph attributes.
    :returns: ``graph_extended`` (the same object).
    """
    graph_extended.add_node((s, initial_soc))
    v: Node
    for v in gas_stations:
        try:
            shortest_p: List[Node] = f.shortest_path(graph_core, s, v)
        except nx.NetworkXNoPath:
            continue

        w = f.path_consumption(graph_core, shortest_p)
        if w > initial_soc:
            continue  # v is not reachable without charging on the way

        d = f.path_distance(graph_core, shortest_p)
        c_v = f.charging_coefficient(graph_core, v)
        g = initial_soc - w  # SoC on arrival at v

        graph_extended.add_edge((s, initial_soc), (v, g), weight=d)
        for u in graph_contracted.neighbors(v):
            c_u = f.charging_coefficient(graph_contracted, u)
            w_v_u = f.consumption(graph_contracted, u, v)
            d_v_u = f.distance(graph_contracted, u, v)
            if c_v < c_u:
                # u is more expensive: fill up completely at v.
                graph_extended.add_edge(
                    (v, g),
                    (u, capacity - w_v_u),
                    weight=(capacity - g) * c_v + d_v_u
                )
            elif g <= w_v_u:
                # u is at most as expensive: charge only what is missing to
                # reach u (nothing when g == w_v_u) and arrive empty.
                graph_extended.add_edge(
                    (v, g),
                    (u, 0),
                    weight=(w_v_u - g) * c_v + d_v_u
                )

    return graph_extended


def insert_final_node(t: Node,
                      graph_core: nx.Graph,
                      gas_stations: Set[Node],
                      graph_extended: nx.DiGraph,
                      capacity: SoC,
                      final_soc: SoC,
                      f: AccessFunctions = AccessFunctions()
                      ) -> nx.DiGraph:
    """Insert the terminal state into the extended graph and create edges as necessary.

    Adds the state ``(t, final_soc)``. For every gas station ``u`` whose
    shortest core-graph path to ``t`` consumes ``w`` with
    ``w + final_soc <= capacity``, every existing state ``(u, g)`` with
    ``g <= w + final_soc`` gets an edge to ``(t, final_soc)``. The edge weight
    is the charging coefficient of ``u`` times the missing SoC
    ``w + final_soc - g``, plus the path distance.

    :param t: Target node.
    :param graph_core: Original graph, used for the paths between ``u`` and ``t``.
    :param gas_stations: Charging stations that may precede ``t``.
    :param graph_extended: State graph; modified in place.
    :param capacity: Battery capacity.
    :param final_soc: SoC required on arrival at ``t``.
    :param f: Accessor functions for shortest paths and graph attributes.
    :returns: ``graph_extended`` (the same object).
    """
    graph_extended.add_node((t, final_soc))

    # Group the SoC values of all states by node once. The loop below only
    # adds edges between already existing states, so this index stays valid
    # and avoids rescanning every state for every gas station.
    socs_by_node = {}
    for n, g in graph_extended.nodes:
        socs_by_node.setdefault(n, []).append(g)

    u: Node
    for u in gas_stations:
        try:
            # Looked up from t to u; graph_core is typed as an undirected nx.Graph.
            shortest_p: List[Node] = f.shortest_path(graph_core, t, u)
        except nx.NetworkXNoPath:
            continue

        w = f.path_consumption(graph_core, shortest_p)
        required_soc = w + final_soc  # SoC needed when leaving u towards t
        if required_soc > capacity:
            continue

        d_u_t = f.path_distance(graph_core, shortest_p)
        c_u = f.charging_coefficient(graph_core, u)
        for g in socs_by_node.get(u, ()):
            if g > required_soc:
                continue
            graph_extended.add_edge(
                (u, g),
                (t, final_soc),
                weight=(required_soc - g) * c_u + d_u_t
            )

    return graph_extended


def contract_graph(G: nx.Graph, charging_stations: Set[Node], capacity: SoC,
                   f: AccessFunctions = AccessFunctions()) -> nx.Graph:
    """Build the graph of charging stations linked by feasible shortest paths.

    Every charging station becomes a node, with its attributes copied from
    ``G``. Two stations are connected if a shortest path between them exists
    in ``G`` and its consumption does not exceed ``capacity``. Such an edge
    stores that path's consumption and distance.

    :param G: Original graph
    :param charging_stations: Charging stations (nodes of ``G``)
    :param capacity: Maximum battery capacity
    :param f: Accessor functions for shortest paths and graph attributes
    :returns: Undirected graph over the charging stations only; empty if
        there are no charging stations.
    """
    H: nx.Graph = nx.Graph()

    all_cs = list(charging_stations)
    for cs in all_cs:
        H.add_node(cs, **G.nodes[cs])

    for i, cs in enumerate(all_cs):
        # H is undirected, so every unordered pair of stations is checked once.
        for n_cs in all_cs[i + 1:]:
            try:
                path = f.shortest_path(G, cs, n_cs)
            except nx.NetworkXNoPath:
                continue
            w_cs_n: SoC = f.path_consumption(G, path)
            if w_cs_n <= capacity:
                H.add_edge(
                    cs, n_cs,
                    **{
                        CONSUMPTION_KEY: w_cs_n,
                        DISTANCE_KEY: f.path_distance(G, path)
                    }
                )

    return H


def get_possible_arriving_soc(G: nx.Graph, u: Node, capacity: SoC, f: AccessFunctions = AccessFunctions()) -> List[SoC]:
    """
    Collect the SoC values with which node ``u`` may be reached.

    Arriving empty (``0``) is always possible. In addition, for every
    neighbour whose charging coefficient is lower than that of ``u``, arriving
    after a full charge there, i.e. with ``capacity`` minus the edge
    consumption, is possible as long as that value is positive.

    :returns: All possible SoC when arriving at node u, according to
        the optimal fuelling strategy (no duplicates).
    """
    own_price = f.charging_coefficient(G, u)

    def _full_charge_arrivals():
        # Yields arrival SoCs from cheaper neighbours, in neighbour order.
        for nbr in G.neighbors(u):
            remaining = capacity - f.consumption(G, u, nbr)
            if remaining > 0 and f.charging_coefficient(G, nbr) < own_price:
                yield remaining

    reachable: Set[SoC] = {0}
    reachable.update(_full_charge_arrivals())
    return list(reachable)


def state_graph(G: nx.Graph, capacity: SoC, f: AccessFunctions = AccessFunctions()) -> nx.DiGraph:
    """Calculate the graph connecting ``(Node, Arrival SoC)`` states.

    For every node ``u`` of ``G`` (typically the contracted charging-station
    graph) and every neighbour ``v`` whose edge consumption ``w`` does not
    exceed ``capacity``, edges are added from each possible arrival state
    ``(u, g)`` (see ``get_possible_arriving_soc``):

    * ``v`` at most as expensive as ``u`` and ``g <= w``: charge just what is
      missing at ``u`` and arrive at ``(v, 0)``;
    * ``v`` more expensive than ``u``: fill up at ``u`` and arrive at
      ``(v, capacity - w)``.

    Edge weights are charging coefficient times charged SoC plus distance.

    :param G: Graph whose nodes are charging stations.
    :param capacity: Battery capacity.
    :param f: Accessor functions for graph attributes; also used when
        determining the possible arrival SoCs.
    :returns: Directed state graph.
    """
    H: nx.DiGraph = nx.DiGraph()

    for u in G.nodes:
        c_u = f.charging_coefficient(G, u)
        # Arrival SoCs at u depend only on u: compute them once, with the same
        # accessor functions as everything else.
        arriving_socs = get_possible_arriving_soc(G, u, capacity, f)
        for v in G.neighbors(u):
            w = f.consumption(G, u, v)
            if w > capacity:
                continue
            c_v = f.charging_coefficient(G, v)
            d = f.distance(G, u, v)
            for g in arriving_socs:
                if c_v <= c_u and g <= w:
                    # Charge only what is missing (nothing if g == w).
                    H.add_edge((u, g), (v, 0), weight=(w - g) * c_u + d)
                elif c_v > c_u:
                    # v is more expensive: fill up completely at u.
                    H.add_edge((u, g), (v, capacity - w),
                               weight=(capacity - g) * c_u + d)

    return H


def compose_result(graph_core: nx.Graph, extended_graph: nx.DiGraph,
                   path: List[State], f: AccessFunctions = AccessFunctions()) -> Result:
    """Translate a path of extended-graph states into a ``Result``.

    Each hop ``(u, g_u) -> (v, g_v)`` adds its edge weight to the trip time.
    The shortest core-graph path from ``u`` to ``v`` is looked up. Its
    distance is subtracted from the hop weight to obtain the charging time
    at ``u``, and its intermediate nodes are listed with zero charging time.
    The last node of *path* is appended with zero charging time.

    :param graph_core: Original graph.
    :param extended_graph: State graph containing every hop of *path*.
    :param path: Sequence of ``(node, soc)`` states.
    :param f: Accessor functions for shortest paths and distances.
    :returns: ``Result`` with total trip time and ``(node, charge_time)`` list.
    """
    total_time: Time = 0
    stops = []
    for (here, soc_here), (there, soc_there) in zip(path, path[1:]):
        hop_time: Time = extended_graph.edges[(here, soc_here), (there, soc_there)]['weight']
        total_time += hop_time
        core_route = f.shortest_path(graph_core, here, there)
        stops.append((here, hop_time - f.path_distance(graph_core, core_route)))
        stops.extend((node, 0) for node in core_route[1:-1])

    stops.append((path[-1][0], 0))  # Final Node

    return Result(trip_time=total_time, charge_path=stops)


def shortest_path(G: nx.Graph,
                  charging_stations: Set[Node],
                  s: Node,
                  t: Node,
                  initial_soc: SoC,
                  final_soc: SoC,
                  capacity: SoC,
                  f: AccessFunctions = AccessFunctions(),
                  extended_graph=None,
                  contracted_graph=None
                  ) -> Result:
    """
    Calculates shortest path using a generalized gas station algorithm.

    :param G: Road graph.
    :param charging_stations: Nodes of ``G`` where charging is possible.
    :param s: Start node.
    :param t: Target node.
    :param initial_soc: SoC at ``s``.
    :param final_soc: SoC required on arrival at ``t``.
    :param capacity: Battery capacity.
    :param f: Accessor functions for shortest paths and graph attributes.
    :param extended_graph: Optional precomputed state graph (``state_graph``
        of the contracted graph). It is copied before start and target
        states are inserted, so the caller's graph stays reusable.
    :param contracted_graph: Optional precomputed ``contract_graph`` result.
    :return: ``Result`` with trip time and charge path, or ``EmptyResult``
        if ``t`` cannot be reached.
    """
    # Check if t is reachable from s
    try:
        _path = f.shortest_path(G, s, t)
    except nx.NetworkXNoPath:
        return EmptyResult()

    # Drive the direct path without charging if it still leaves at least
    # final_soc on arrival at t.
    _w = f.path_consumption(G, _path)
    if _w <= initial_soc - final_soc:
        return Result(
            trip_time=f.path_distance(G, _path),
            charge_path=[(n, 0) for n in _path]
        )

    # Compare against None: an empty precomputed graph is still a valid input.
    if contracted_graph is None:
        contracted_graph = contract_graph(G, charging_stations, capacity, f)
    if extended_graph is None:
        extended_graph = state_graph(contracted_graph, capacity, f)
    else:
        # Start/target states are inserted in place below; work on a copy so
        # a cached state graph is not polluted across queries.
        extended_graph = extended_graph.copy()

    extended_graph = insert_start_node(
        s=s,
        graph_core=G,
        graph_contracted=contracted_graph,
        gas_stations=charging_stations,
        graph_extended=extended_graph,
        capacity=capacity,
        initial_soc=initial_soc,
        f=f
    )

    extended_graph = insert_final_node(
        t=t,
        graph_core=G,
        gas_stations=charging_stations,
        graph_extended=extended_graph,
        capacity=capacity,
        final_soc=final_soc,
        f=f
    )

    try:
        # weight='weight' makes networkx run Dijkstra on the edge costs;
        # without it the minimum-hop (BFS) path would be returned.
        path: List[State] = nx.shortest_path(
            extended_graph, (s, initial_soc), (t, final_soc), weight='weight'
        )
    except nx.NetworkXNoPath:
        return EmptyResult()

    return compose_result(
        graph_core=G,
        extended_graph=extended_graph,
        path=path,
        f=f
    )