from typing import Dict

import networkx as nx

from evrouting.T import Node, SoC, ConsumptionFunction
from evrouting.graph_tools import charging_cofficient
from evrouting.charge.T import SoCProfile, SoCFunction, ChargingFunction, Label


class SoCProfileFactory:
    """Creates SoC profiles for single nodes or for edges between two nodes."""

    def __init__(self, G: nx.Graph, capacity: SoC, c: ConsumptionFunction):
        """
        :param G: Graph that is handed to the consumption function.
        :param capacity: Capacity passed on to every created profile.
        :param c: Function to calc consumption for an edge,
            called as ``c(G, u, v)``.
        """
        self.c = c
        self.capacity: SoC = capacity
        self.G: nx.Graph = G

    def __call__(self, u: Node, v: Node = None) -> SoCProfile:
        """
        Create the profile for ``u`` alone or for the edge ``(u, v)``.

        :param u: Start node.
        :param v: End node. If omitted, the cost is 0 and the
            consumption function is not called.
        :return: Profile built from the cost and the stored capacity.
        """
        # Without a target node there is nothing to traverse.
        if v is None:
            return SoCProfile(0, self.capacity)

        return SoCProfile(self.c(self.G, u, v), self.capacity)


class ChargingFunctionMap:
    """Lazily filled lookup table from nodes to their charging functions."""

    def __init__(self, G: nx.Graph, capacity: SoC, initial_soc: SoC = None):
        """
        :param G: Graph used to look up the charging coefficient of a node.
        :param capacity: Capacity given to every created charging function.
        :param initial_soc: Initial SoC given to every created
            charging function.
        """
        self.G: nx.Graph = G
        self.capacity: SoC = capacity
        self.initial_soc: SoC = initial_soc
        # Cache of already created charging functions, keyed by node.
        self.map: Dict[Node, ChargingFunction] = {}

    def __getitem__(self, node: Node) -> ChargingFunction:
        """
        Return the charging function of ``node``.

        The function is created on first access and stored in the cache,
        so later lookups of the same node return the very same object.
        """
        if node not in self.map:
            coefficient = charging_cofficient(self.G, node)
            self.map[node] = ChargingFunction(
                c=coefficient,
                capacity=self.capacity,
                initial_soc=self.initial_soc
            )

        return self.map[node]


class SoCFunctionFactory:
    """Creates SoC functions for labels."""

    def __init__(self, cf: ChargingFunctionMap):
        """
        :param cf: Lookup of charging functions, indexed by node.
        """
        self.cf: ChargingFunctionMap = cf

    def __call__(self, label: Label) -> SoCFunction:
        """
        :param label: Label whose ``last_cs`` selects the charging function.
        :return: SoC function built from the label and that
            charging function.
        """
        charging_function = self.cf[label.last_cs]
        return SoCFunction(label, charging_function)