from typing import Dict, Optional

import networkx as nx

from .T import SoCProfile, SoCFunction, ChargingFunction, Label
from ..T import Node, SoC, Time
from ..graph_tools import charging_cofficient, consumption


def charging_function_factory(
        G: nx.Graph,
        n: Node,
        capacity: SoC,
        initial_soc: Optional[SoC] = None) -> ChargingFunction:
    """
    Create the charging function of node ``n``.

    :param G: Graph containing ``n``; the node's charging coefficient is
        obtained via ``charging_cofficient(G, n)``.
    :param n: Node to build the charging function for.
    :param capacity: Capacity forwarded unchanged to ``ChargingFunction``.
    :param initial_soc: Optional initial SoC forwarded unchanged to
        ``ChargingFunction``; ``None`` when not given.
    :return: A newly created ``ChargingFunction``.
    """
    return ChargingFunction(charging_cofficient(G, n), capacity, initial_soc)


def soc_profile_factory(
        G: nx.Graph,
        capacity: SoC,
        u: Node,
        v: Optional[Node] = None,
) -> SoCProfile:
    """
    Return the SoC profile of the path from ``u`` to ``v``.

    If no ``v`` is provided, the path is treated as a zero-cost path at
    ``u`` (its cost is ``0``).

    :param G: Graph used to look up the consumption between ``u`` and ``v``.
    :param capacity: Capacity forwarded unchanged to ``SoCProfile``.
    :param u: Start node of the path.
    :param v: End node of the path, or ``None`` for a zero-cost profile.
    :return: ``SoCProfile`` built from the path cost and ``capacity``.
    """
    if v is None:
        path_cost = 0
    else:
        path_cost = consumption(G, u, v)
    return SoCProfile(path_cost, capacity)


class ChargingFunctionMap:
    """
    Lazily maps nodes to their charging functions.

    A node's charging function is built with ``charging_function_factory``
    the first time it is requested and cached for later lookups.
    """

    def __init__(self, G: nx.Graph, capacity: SoC,
                 initial_soc: Optional[SoC] = None):
        # Cache of charging functions that have already been built, by node.
        self.map: Dict[Node, ChargingFunction] = {}
        self.G: nx.Graph = G
        self.capacity: SoC = capacity
        # Forwarded to every charging function created by this map.
        self.initial_soc: Optional[SoC] = initial_soc

    def __getitem__(self, node: Node) -> ChargingFunction:
        """
        Return the charging function of ``node``.

        Try to get the charging function from the cache; otherwise create it
        from this map's graph, capacity and initial SoC, add it to the cache
        and return it.
        """
        try:
            cf = self.map[node]
        except KeyError:
            cf = charging_function_factory(
                G=self.G,
                n=node,
                capacity=self.capacity,
                initial_soc=self.initial_soc
            )
            self.map[node] = cf

        return cf


class SoCFunctionMap:
    """
    Maps labels to their SoC functions.

    Every lookup builds a new ``SoCFunction`` from the label and the
    charging function of the label's last charging station
    (``label.last_cs``). Only the charging functions are cached, and that
    caching happens in the wrapped ``ChargingFunctionMap``.
    """

    def __init__(self, cf: ChargingFunctionMap):
        # Supplies the charging function of each node on demand.
        self.cf: ChargingFunctionMap = cf

    def __getitem__(self, label: Label) -> SoCFunction:
        """Return a freshly built SoC function for ``label``."""
        station_cf = self.cf[label.last_cs]
        return SoCFunction(label, station_cf)


class LabelsFactory:
    """Create follow-up labels that restart at a new charging station."""

    def __init__(self,
                 G: nx.Graph,
                 capacity: SoC,
                 cf: ChargingFunctionMap,
                 initial_soc: Optional[SoC] = None):
        self.G: nx.Graph = G
        self.capacity: SoC = capacity
        # Source of charging functions, looked up by the label's last_cs.
        self.cf: ChargingFunctionMap = cf
        # NOTE(review): stored but not read anywhere in this class.
        self.initial_soc: Optional[SoC] = initial_soc

    def spawn_label(self, current_node: Node, current_label: Label) -> Label:
        """
        Spawn a new label with ``current_node`` as its last charging station.

        Only the minimum is charged at the last charge station of
        ``current_label``, and charging continues at ``current_node``.

        :param current_node: Node that becomes ``last_cs`` of the new label.
        :param current_label: Label being extended.
        :return: New ``Label`` whose trip time is advanced by the charging
            time. Its SoC at the last charging station is the SoC function
            evaluated at that trip time, and its SoC profile from
            ``current_node`` is the zero-cost profile.
        """
        soc_function: SoCFunction = SoCFunction(
            current_label, self.cf[current_label.last_cs]
        )

        t_trip_old = current_label.t_trip
        # NOTE(review): presumably ``minimum`` is the earliest feasible trip
        # time of the SoC function; there is no guard against a negative
        # charging time if it were below ``t_trip_old`` — confirm.
        t_charge: Time = soc_function.minimum - t_trip_old
        # Computed as old trip time plus charging time (not read directly
        # from ``minimum``) to keep the original arithmetic.
        t_trip_new = t_trip_old + t_charge

        return Label(
            t_trip=t_trip_new,
            soc_last_cs=soc_function(t_trip_new),
            last_cs=current_node,
            soc_profile_cs_v=soc_profile_factory(
                self.G, self.capacity, current_node)
        )