From 050da9af7b37353a0e98e4bbffb600170961d86b Mon Sep 17 00:00:00 2001
From: "niehues.mark@gmail.com" <niehues.mark@gmail.com>
Date: Thu, 19 Mar 2020 16:29:50 +0100
Subject: [PATCH] consolidated factories

---
 evrouting/charge/factories.py | 56 ++++++++++++-----------------------
 evrouting/charge/routing.py   | 37 ++++++++++++-----------
 evrouting/charge/utils.py     | 18 +++++------
 3 files changed, 46 insertions(+), 65 deletions(-)

diff --git a/evrouting/charge/factories.py b/evrouting/charge/factories.py
index 3cdf2f8..15ffdb6 100644
--- a/evrouting/charge/factories.py
+++ b/evrouting/charge/factories.py
@@ -6,33 +6,21 @@ from ..T import Node, SoC, Time
 from ..graph_tools import charging_cofficient, consumption
 
 
-def charging_function_factory(
-        G: nx.Graph,
-        n: Node,
-        capacity: SoC,
-        initial_soc: SoC = None) -> ChargingFunction:
-    """Create charging function of node."""
-    return ChargingFunction(charging_cofficient(G, n), capacity, initial_soc)
+class SoCProfileFactory:
+    """Maps Nodes to their (cached) charging functions."""
 
+    def __init__(self, G: nx.Graph, capacity: SoC):
+        self.G: nx.Graph = G
+        self.capacity: SoC = capacity
 
-def soc_profile_factory(
-        G: nx.Graph,
-        capacity: SoC,
-        u: Node,
-        v: Node = None,
-) -> SoCProfile:
-    """
-    Return SoC Profile of the path from u to v.
-
-    If no v is provided, the path of u is definded as no cost path.
+    def __call__(self, u: Node, v: Node = None) -> SoCProfile:
+        path_cost = 0 if v is None else consumption(self.G, u, v)
 
-    """
-    path_cost = 0 if v is None else consumption(G, u, v)
-    return SoCProfile(path_cost, capacity)
+        return SoCProfile(path_cost, self.capacity)
 
 
 class ChargingFunctionMap:
-    """Maps Nodes to their charging functions."""
+    """Maps Nodes to their (cached) charging functions."""
 
     def __init__(self, G: nx.Graph, capacity: SoC, initial_soc: SoC = None):
         self.map: Dict[Node, ChargingFunction] = {}
@@ -48,9 +36,8 @@ class ChargingFunctionMap:
         try:
             cf = self.map[node]
         except KeyError:
-            cf = charging_function_factory(
-                G=self.G,
-                n=node,
+            cf = ChargingFunction(
+                c=charging_cofficient(self.G, node),
                 capacity=self.capacity,
                 initial_soc=self.initial_soc
             )
@@ -59,37 +46,32 @@ class ChargingFunctionMap:
         return cf
 
 
-class SoCFunctionMap:
+class SoCFunctionFactory:
     """Maps Nodes to their charging functions."""
 
     def __init__(self, cf: ChargingFunctionMap):
         self.cf: ChargingFunctionMap = cf
 
-    def __getitem__(self, label: Label) -> SoCFunction:
+    def __call__(self, label: Label) -> SoCFunction:
         return SoCFunction(label, self.cf[label.last_cs])
 
 
 class LabelsFactory:
 
     def __init__(self,
-                 G: nx.Graph,
-                 capacity: SoC,
-                 f_soc: SoCFunctionMap,
-                 initial_soc: SoC = None):
-        self.G: nx.Graph = G
-        self.capacity: SoC = capacity
-        self.f_soc: SoCFunctionMap = f_soc
-        self.initial_soc: SoC = initial_soc
+                 f_soc: SoCFunctionFactory,
+                 soc_profile: SoCProfileFactory):
+        self.f_soc: SoCFunctionFactory = f_soc
+        self.soc_profile: SoCProfileFactory = soc_profile
 
     def spawn_label(self, current_node: Node, current_label: Label, t_charge: Time):
         # Only charge the minimum at the last charge station
         # and continue charging at this station.
-        soc_function: SoCFunction = self.f_soc[current_label]
+        soc_function: SoCFunction = self.f_soc(current_label)
 
         return Label(
             t_trip=current_label.t_trip + t_charge,
             soc_last_cs=soc_function(current_label.t_trip + t_charge),
             last_cs=current_node,
-            soc_profile_cs_v=soc_profile_factory(
-                self.G, self.capacity, current_node)
+            soc_profile_cs_v=self.soc_profile(current_node)
         )
diff --git a/evrouting/charge/routing.py b/evrouting/charge/routing.py
index 7c10419..ab13fe8 100644
--- a/evrouting/charge/routing.py
+++ b/evrouting/charge/routing.py
@@ -7,8 +7,8 @@ from evrouting.utils import PriorityQueue
 from evrouting.charge.factories import (
     LabelsFactory,
     ChargingFunctionMap,
-    SoCFunctionMap,
-    soc_profile_factory
+    SoCFunctionFactory,
+    SoCProfileFactory
 )
 
 from ..graph_tools import distance
@@ -34,17 +34,19 @@ def shortest_path(G: nx.Graph, charging_stations: set, s: Node, t: Node,
     """
     t = _apply_final_constraints(G, t, final_soc)
 
-    cf = ChargingFunctionMap(G=G, capacity=capacity, initial_soc=initial_soc)
-    f_soc = SoCFunctionMap(cf)
-    label_factory = LabelsFactory(G, capacity, f_soc, initial_soc)
+    # Init factories
+    cf_map = ChargingFunctionMap(G=G, capacity=capacity, initial_soc=initial_soc)
+    f_soc_factory = SoCFunctionFactory(cf_map)
+    soc_profile_factory = SoCProfileFactory(G, capacity)
+    label_factory = LabelsFactory(f_soc_factory, soc_profile_factory)
 
     # Init maps to manage labels
     l_set: Dict[int, Set[Label]] = {v: set() for v in G}
-    l_uns: Dict[int, LabelPriorityQueue] = {v: LabelPriorityQueue(cf, l_set[v]) for v in G}
+    l_uns: Dict[int, LabelPriorityQueue] = {v: LabelPriorityQueue(f_soc_factory, l_set[v]) for v in G}
 
     # Init environment
     entry_label = _create_entry_label(G, charging_stations,
-                                      s, initial_soc, capacity)
+                                      s, initial_soc, soc_profile_factory)
     l_uns[s].insert(entry_label)
 
     # A priority queue defines which node to visit next.
@@ -63,12 +65,12 @@ def shortest_path(G: nx.Graph, charging_stations: set, s: Node, t: Node,
         l_set[minimum_node].add(label_minimum_node)
 
         if minimum_node == t:
-            return f_soc[label_minimum_node].minimum
+            return f_soc_factory(label_minimum_node).minimum
 
         # handle charging stations
         if minimum_node in charging_stations and \
                 not minimum_node == label_minimum_node.last_cs:
-            for t_charge in _calc_optimal_t_charge(cf, label_minimum_node, minimum_node, capacity):
+            for t_charge in _calc_optimal_t_charge(cf_map, label_minimum_node, minimum_node, capacity):
                 label_new = label_factory.spawn_label(minimum_node,
                                                       label_minimum_node,
                                                       t_charge)
@@ -76,13 +78,13 @@ def shortest_path(G: nx.Graph, charging_stations: set, s: Node, t: Node,
 
         # Update priority queue. This node might have gotten a new
         # minimum label spawned is th previous step.
-        _update_priority_queue(f_soc, prio_queue, l_uns, minimum_node)
+        _update_priority_queue(f_soc_factory, prio_queue, l_uns, minimum_node)
 
         # scan outgoing arcs
         for n in G.neighbors(minimum_node):
             # Create SoC Profile for getting from minimum_node to n
             soc_profile = label_minimum_node.soc_profile_cs_v + \
-                          soc_profile_factory(G, capacity, minimum_node, n)
+                          soc_profile_factory(minimum_node, n)
 
             if _is_feasible_path(soc_profile, capacity):
                 l_new = Label(
@@ -110,7 +112,7 @@ def shortest_path(G: nx.Graph, charging_stations: set, s: Node, t: Node,
                     continue
 
                 if is_new_min_label:
-                    key, count = _key(l_new, f_soc)
+                    key, count = _key(l_new, f_soc_factory)
                     prio_queue.insert(n, priority=key, count=count)
 
 
@@ -130,8 +132,8 @@ def _calc_optimal_t_charge(cf: ChargingFunctionMap, label_v: Label, v: Node, cap
     return t_charge
 
 
-def _key(label, f_soc):
-    soc_function = f_soc[label]
+def _key(label, f_soc_factory):
+    soc_function = f_soc_factory(label)
 
     t_min = soc_function.minimum
     soc_min = soc_function(t_min)
@@ -144,7 +146,8 @@ def _create_entry_label(
         charging_stations: set,
         s: Node,
         initial_soc: SoC,
-        capacity: SoC) -> Label:
+        soc_profile_factory: SoCProfileFactory
+) -> Label:
     """
     Create dummy charging station with initial soc as constant charging
     function.
@@ -168,7 +171,7 @@ def _create_entry_label(
         t_trip=0,
         soc_last_cs=initial_soc,
         last_cs=dummy_node,
-        soc_profile_cs_v=soc_profile_factory(G, capacity, s)
+        soc_profile_cs_v=soc_profile_factory(s)
     )
 
 
@@ -178,7 +181,7 @@ def _is_feasible_path(soc_profile: SoCProfile, capacity: SoC) -> bool:
 
 
 def _update_priority_queue(
-        f_soc: SoCFunctionMap,
+        f_soc: SoCFunctionFactory,
         prio_queue: PriorityQueue,
         l_uns: Dict[int, LabelPriorityQueue],
         node: Node):
diff --git a/evrouting/charge/utils.py b/evrouting/charge/utils.py
index ba79f57..bc42556 100644
--- a/evrouting/charge/utils.py
+++ b/evrouting/charge/utils.py
@@ -4,23 +4,19 @@ from math import inf
 from evrouting.utils import PriorityQueue
 from evrouting.T import SoC, Time
 
-from .T import Label, SoCFunction
-from .factories import ChargingFunctionMap
+from .T import Label
+from .factories import SoCFunctionFactory
 
 
 class LabelPriorityQueue(PriorityQueue):
-    def __init__(self, cf: ChargingFunctionMap, l_set: Set[Label]):
+    def __init__(self, f_soc: SoCFunctionFactory, l_set: Set[Label]):
         super().__init__()
-        self.cf: ChargingFunctionMap = cf
+        self.f_soc_factory: SoCFunctionFactory = f_soc
         self.l_set: Set[Label] = l_set
 
     def insert(self, label: Label):
         """Breaking ties with lowest soc at t_min."""
-        soc_function = SoCFunction(
-            label,
-            self.cf[label.last_cs]
-        )
-
+        soc_function = self.f_soc_factory(label)
         t_min: Time = soc_function.minimum
 
         # Might happen because of dummy charge stations
@@ -49,9 +45,9 @@ class LabelPriorityQueue(PriorityQueue):
         except KeyError:
             return
 
-        soc = SoCFunction(label, self.cf[label.last_cs])
+        soc = self.f_soc_factory(label)
 
         for other_label in self.l_set:
-            if SoCFunction(other_label, self.cf[other_label.last_cs]) > soc:
+            if self.f_soc_factory(other_label) > soc:
                 self.remove_item(label)
                 return
-- 
GitLab