from typing import Dict, Optional

import networkx as nx

from .T import SoCProfile, SoCFunction, ChargingFunction, Label
from ..T import Node, SoC, Time
from ..graph_tools import charging_cofficient, consumption


class SoCProfileFactory:
    """Creates SoC profiles from the graph's consumption values.

    Every call builds a fresh ``SoCProfile``; no results are cached.
    """

    def __init__(self, G: nx.Graph, capacity: SoC):
        # Graph queried for consumption values between nodes.
        self.G: nx.Graph = G
        # Battery capacity passed unchanged to every created profile.
        self.capacity: SoC = capacity

    def __call__(self, u: Node, v: Optional[Node] = None) -> SoCProfile:
        """Return the SoC profile for going from ``u`` to ``v``.

        :param u: Start node.
        :param v: End node. When omitted (``None``) the path cost is 0.
        :return: ``SoCProfile(path_cost, capacity)`` where ``path_cost`` is
            ``consumption(G, u, v)``, or 0 when ``v`` is ``None``.
        """
        # Without a target node there is no path, hence nothing is consumed.
        path_cost = 0 if v is None else consumption(self.G, u, v)

        return SoCProfile(path_cost, self.capacity)


class ChargingFunctionMap:
    """Lazily builds and memoizes one charging function per node.

    A node's ``ChargingFunction`` is created on first access from the
    node's charging coefficient in ``G``. Later accesses reuse it.
    """

    def __init__(self, G: nx.Graph, capacity: SoC, initial_soc: SoC = None):
        self.G: nx.Graph = G
        self.capacity: SoC = capacity
        self.initial_soc: SoC = initial_soc
        # node -> ChargingFunction, populated on demand by __getitem__.
        self.map: Dict[Node, ChargingFunction] = {}

    def __getitem__(self, node: Node) -> ChargingFunction:
        """Return the charging function of ``node``, creating it on first use.

        A newly created function uses ``charging_cofficient(G, node)`` as its
        coefficient, together with the map's ``capacity`` and ``initial_soc``.
        """
        if node not in self.map:
            self.map[node] = ChargingFunction(
                c=charging_cofficient(self.G, node),
                capacity=self.capacity,
                initial_soc=self.initial_soc,
            )
        return self.map[node]


class SoCFunctionFactory:
    """Builds a label's SoC function from the charging function of its last
    charging station."""

    def __init__(self, cf: ChargingFunctionMap):
        # Per-node charging functions; held by reference, not copied.
        self.cf: ChargingFunctionMap = cf

    def __call__(self, label: Label) -> SoCFunction:
        """Return the SoC function for ``label``.

        :param label: Label whose ``last_cs`` selects the charging function.
        :return: ``SoCFunction(label, cf[label.last_cs])``.
        """
        return SoCFunction(label, self.cf[label.last_cs])


class LabelsFactory:
    """Creates follow-up labels when a charging station is reached."""

    def __init__(self,
                 f_soc: SoCFunctionFactory,
                 soc_profile: SoCProfileFactory):
        self.soc_profile: SoCProfileFactory = soc_profile
        self.f_soc: SoCFunctionFactory = f_soc

    def spawn_label(self, current_node: Node, current_label: Label, t_charge: Time):
        """Derive a new label that makes ``current_node`` the last charging station.

        The returned ``Label`` has:
        - ``t_trip``: ``current_label.t_trip + t_charge``;
        - ``soc_last_cs``: the SoC function of ``current_label`` evaluated at that time;
        - ``last_cs``: ``current_node``;
        - ``soc_profile_cs_v``: ``soc_profile(current_node)``.
        """
        # Charge only the minimum at the previous charging station; any
        # further charging continues at this station.
        departure = current_label.t_trip + t_charge
        soc_at_departure = self.f_soc(current_label)(departure)

        return Label(
            t_trip=departure,
            soc_last_cs=soc_at_departure,
            last_cs=current_node,
            soc_profile_cs_v=self.soc_profile(current_node),
        )