From 02cdd26f97cc6b69e0763369847f363f694e2671 Mon Sep 17 00:00:00 2001
From: "niehues.mark@gmail.com" <niehues.mark@gmail.com>
Date: Wed, 29 Apr 2020 15:26:42 +0200
Subject: [PATCH]  better dependecie injection

---
 evrouting/gasstation/routing.py          | 102 ++++++++++-------------
 evrouting/graph_tools.py                 |   4 +-
 tests/gasstation/test_transformations.py |  11 +--
 3 files changed, 49 insertions(+), 68 deletions(-)

diff --git a/evrouting/gasstation/routing.py b/evrouting/gasstation/routing.py
index 03e87c7..324dacc 100644
--- a/evrouting/gasstation/routing.py
+++ b/evrouting/gasstation/routing.py
@@ -2,16 +2,13 @@ from typing import Set, List
 
 import networkx as nx
 from evrouting.T import Node, SoC, Result, EmptyResult, Time
-from evrouting.gasstation.T import State
+from evrouting.gasstation.T import State, DistFunction
 from evrouting.graph_tools import (
     CONSUMPTION_KEY,
     DISTANCE_KEY,
-    consumption,
-    distance,
+    AccessFunctions
 )
 
-from evrouting.graph_tools import sum_weights as fold_path
-
 
 def insert_start_node(s: Node,
                       graph_core: nx.Graph,
@@ -20,27 +17,26 @@ def insert_start_node(s: Node,
                       graph_extended: nx.DiGraph,
                       capacity: SoC,
                       initial_soc: SoC,
-                      c: float = 1.,
-                      charging_coefficient=charging_cofficient
+                      f: AccessFunctions = AccessFunctions()
                       ) -> nx.DiGraph:
     """Insert s into extended graph an create states and edges as necessary."""
     graph_extended.add_node((s, initial_soc))
     v: Node
     for v in gas_stations:
-        shortest_p: List[Node] = nx.shortest_path(graph_core, s, v, weight=DISTANCE_KEY)
-        d = fold_path(graph_core, shortest_p, weight=DISTANCE_KEY)
-        w = c * d
+        shortest_p: List[Node] = nx.shortest_path(graph_core, s, v, weight=CONSUMPTION_KEY)
+        w = f.path_consumption(graph_core, shortest_p)
         if w > initial_soc:
             continue
 
-        c_v = charging_cofficient(graph_core, v)
+        d = f.path_distance(graph_core, shortest_p)
+        c_v = f.charging_coefficient(graph_core, v)
         g = initial_soc - w
 
         graph_extended.add_edge((s, initial_soc), (v, g), weight=d)
         for u in graph_contracted.neighbors(v):
-            c_u = charging_cofficient(graph_contracted, u)
-            w_v_u = consumption(graph_contracted, u, v)
-            d_v_u = distance(graph_contracted, u, v)
+            c_u = f.charging_coefficient(graph_contracted, u)
+            w_v_u = f.consumption(graph_contracted, u, v)
+            d_v_u = f.distance(graph_contracted, u, v)
             if c_v < c_u:
                 graph_extended.add_edge(
                     (v, g),
@@ -63,19 +59,19 @@ def insert_final_node(t: Node,
                       graph_extended: nx.DiGraph,
                       capacity: SoC,
                       final_soc: SoC,
-                      c: float = 1.
-                      charging_cofficient=charging_cofficient
+                      f: AccessFunctions = AccessFunctions()
                       ) -> nx.DiGraph:
     """Insert terminal node into extended graph an create states and edges as necessary."""
     graph_extended.add_node((t, final_soc))
     u: Node
     for u in gas_stations:
-        shortest_p: List[Node] = nx.shortest_path(graph_core, t, u, weight=DISTANCE_KEY)
-        d_u_t = fold_path(graph_core, shortest_p, weight=DISTANCE_KEY)
-        w = c * d_u_t
+        shortest_p: List[Node] = nx.shortest_path(graph_core, t, u, weight=CONSUMPTION_KEY)
+        w = f.path_consumption(graph_core, shortest_p)
         if w + final_soc > capacity:
             continue
-        c_u = charging_cofficient(graph_core, u)
+
+        d_u_t = f.path_distance(graph_core, shortest_p)
+        c_u = f.charging_coefficient(graph_core, u)
         for g in [g for n, g in graph_extended.nodes if n == u]:
             if g > w + final_soc:
                 continue
@@ -89,7 +85,7 @@ def insert_final_node(t: Node,
 
 
 def contract_graph(G: nx.Graph, charging_stations: Set[Node], capacity: SoC,
-                   c: float = 1.) -> nx.Graph:
+                   f: AccessFunctions = AccessFunctions()) -> nx.Graph:
     """
     :param G: Original graph
     :param charging_stations: Charging stations
@@ -110,14 +106,14 @@ def contract_graph(G: nx.Graph, charging_stations: Set[Node], capacity: SoC,
         H.add_node(cs, **G.nodes[cs])
         # Iterate unvisited charging stations
         for n_cs in all_cs[i + 1:]:
-            t_min = nx.algorithms.shortest_path_length(G, cs, n_cs, weight=DISTANCE_KEY)
-            w_cs_n: SoC = c * t_min
+            path = nx.algorithms.shortest_path(G, cs, n_cs, weight=DISTANCE_KEY)
+            w_cs_n: SoC = f.path_consumption(G, path)
             if w_cs_n <= capacity:
                 H.add_edge(
                     cs, n_cs,
                     **{
                         CONSUMPTION_KEY: w_cs_n,
-                        DISTANCE_KEY: t_min
+                        DISTANCE_KEY: f.path_distance(G, path)
                     }
                 )
     H.add_node(all_cs[-1], **G.nodes[all_cs[-1]])
@@ -125,39 +121,39 @@ def contract_graph(G: nx.Graph, charging_stations: Set[Node], capacity: SoC,
     return H
 
 
-def get_possible_arriving_soc(G: nx.Graph, u: Node, capacity: SoC) -> List[SoC]:
+def get_possible_arriving_soc(G: nx.Graph, u: Node, capacity: SoC, f: AccessFunctions = AccessFunctions()) -> List[SoC]:
     """
     :returns: All possible SoC  when arriving at node u, according to
         the optimal fuelling strategy.
     """
     possible_arriving_soc: Set[SoC] = {0}
-    c_u = charging_cofficient(G, u)
+    c_u = f.charging_coefficient(G, u)
 
     for n in G.neighbors(u):
-        arriving_soc = capacity - consumption(G, u, n)
-        if arriving_soc > 0 and charging_cofficient(G, n) < c_u and \
+        arriving_soc = capacity - f.consumption(G, u, n)
+        if arriving_soc > 0 and f.charging_coefficient(G, n) < c_u and \
                 arriving_soc not in possible_arriving_soc:
             possible_arriving_soc.add(arriving_soc)
 
     return list(possible_arriving_soc)
 
 
-def state_graph(G: nx.Graph, capacity: SoC) -> nx.DiGraph:
+def state_graph(G: nx.Graph, capacity: SoC, f: AccessFunctions = AccessFunctions()) -> nx.DiGraph:
     """Calculate Graph connecting (Node, Arrival SoC) states."""
     H: nx.DiGraph = nx.DiGraph()
 
     for u in G.nodes:
-        c_u = charging_cofficient(G, u)
+        c_u = f.charging_coefficient(G, u)
         for v in G.neighbors(u):
-            w = consumption(G, u, v)
+            w = f.consumption(G, u, v)
             if w <= capacity:
                 for g in get_possible_arriving_soc(G, u, capacity):
-                    c_v = charging_cofficient(G, v)
+                    c_v = f.charging_coefficient(G, v)
                     if c_v <= c_u and g < w:
-                        weight = (w - g) * c_u + distance(G, u, v)
+                        weight = (w - g) * c_u + f.distance(G, u, v)
                         H.add_edge((u, g), (v, 0), weight=weight)
                     elif c_v > c_u:
-                        weight = (capacity - g) * c_u + distance(G, u, v)
+                        weight = (capacity - g) * c_u + f.distance(G, u, v)
                         H.add_edge((u, g), (v, capacity - w), weight=weight)
 
     return H
@@ -176,9 +172,10 @@ def compose_result(graph_core: nx.Graph, extended_graph: nx.DiGraph,
         v, g_v = path[i + 1]
         t: Time = extended_graph.edges[(u, g_u), (v, g_v)]['weight']
         trip_time += t
-        charge_time_u: Time = t - fold_path(
+        charge_time_u: Time = t - nx.shortest_path_length(
             graph_core,
-            nx.shortest_path(graph_core, u, v, weight=DISTANCE_KEY),
+            u,
+            v,
             weight=DISTANCE_KEY
         )
         charge_path.append((u, charge_time_u))
@@ -188,17 +185,8 @@ def compose_result(graph_core: nx.Graph, extended_graph: nx.DiGraph,
     return Result(trip_time=trip_time, charge_path=charge_path)
 
 
-def shortest_path(G: nx.Graph,
-                  charging_stations: Set[Node],
-                  s: Node,
-                  t: Node,
-                  initial_soc: SoC,
-                  final_soc: SoC,
-                  capacity: SoC,
-                  c: float,
-                  extended_graph=None,
-                  contracted_graph=None
-                  ) -> Result:
+def shortest_path(G: nx.Graph, charging_stations: Set[Node], s: Node, t: Node,
+                  initial_soc: SoC, final_soc: SoC, capacity: SoC, f: AccessFunctions = AccessFunctions()) -> Result:
     """
     Calculates shortest path using a generalized gas station algorithm.
 
@@ -213,21 +201,19 @@ def shortest_path(G: nx.Graph,
     """
     # Check if t is reachable from s
     try:
-        _path = nx.shortest_path(G, s, t, weight=DISTANCE_KEY)
+        _path = nx.shortest_path(G, s, t, weight=CONSUMPTION_KEY)
     except nx.NetworkXNoPath:
         return EmptyResult()
 
-    _t = fold_path(G, _path, weight=DISTANCE_KEY)
-    _w = c * _t
+    _w = f.path_consumption(G, _path)
     if _w <= initial_soc:
         return Result(
-            trip_time=_t,
-            charge_path=[(n, 0) for n in _path]
+            trip_time=f.path_distance(G, _path),
+            charge_path=[(s, 0), (t, 0)]
         )
 
-    if extended_graph is None or contracted_graph is None:
-        contracted_graph: nx.Graph = contract_graph(G, charging_stations, capacity, c=c)
-        extended_graph = state_graph(contracted_graph, capacity)
+    contracted_graph: nx.Graph = contract_graph(G, charging_stations, capacity)
+    extended_graph = state_graph(contracted_graph, capacity)
 
     extended_graph = insert_start_node(
         s=s,
@@ -236,8 +222,7 @@ def shortest_path(G: nx.Graph,
         gas_stations=charging_stations,
         graph_extended=extended_graph,
         capacity=capacity,
-        initial_soc=initial_soc,
-        c=c
+        initial_soc=initial_soc
     )
 
     extended_graph = insert_final_node(
@@ -246,8 +231,7 @@ def shortest_path(G: nx.Graph,
         gas_stations=charging_stations,
         graph_extended=extended_graph,
         capacity=capacity,
-        final_soc=final_soc,
-        c=c
+        final_soc=final_soc
     )
 
     try:
diff --git a/evrouting/graph_tools.py b/evrouting/graph_tools.py
index b724014..6ae2d90 100644
--- a/evrouting/graph_tools.py
+++ b/evrouting/graph_tools.py
@@ -82,7 +82,7 @@ class AccessFunctions:
         self.charging_coefficient = charging_coefficient
 
     def path_distance(self, G, path):
-        sum_weights(G, path, weight=DISTANCE_KEY)
+        return sum_weights(G, path, weight=DISTANCE_KEY)
 
     def path_consumption(self, G, path):
-        sum_weights(G, path, weight=CONSUMPTION_KEY)
+        return sum_weights(G, path, weight=CONSUMPTION_KEY)
diff --git a/tests/gasstation/test_transformations.py b/tests/gasstation/test_transformations.py
index ea09b58..8091f9a 100644
--- a/tests/gasstation/test_transformations.py
+++ b/tests/gasstation/test_transformations.py
@@ -35,14 +35,12 @@ class TestContraction:
     @pytest.mark.parametrize('u,v,weight,value', [
         ('s', 'a', CONSUMPTION_KEY, 1),
         ('s', 'a', DISTANCE_KEY, 1),
-        ('s', 'f', CONSUMPTION_KEY, 1),  # Not exist
-        ('s', 'f', DISTANCE_KEY, 1),  # Not exist
+        ('s', 'f', '', None),  # Not exist
         ('s', 'b', '', None),  # Not exist
         ('s', 'd', '', None),  # Not exist
         ('s', 'c', CONSUMPTION_KEY, 2),
         ('s', 'c', DISTANCE_KEY, 2),
-        ('s', 'e', CONSUMPTION_KEY, 2),
-        ('s', 'e', DISTANCE_KEY, 2)
+        ('s', 'e', '', None),
     ])
     def test_contraction_edges(self, u, v, weight, value):
         """
@@ -56,7 +54,7 @@ class TestContraction:
         conf: dict = init_config(gasstation)
         H: nx.Graph = contract_graph(conf['G'],
                                      conf['charging_stations'],
-                                     conf['capacity'], c=1)
+                                     conf['capacity'])
 
         label_map = self.label_map(H)
         try:
@@ -242,8 +240,7 @@ class Integration:
         return contract_graph(
             G=graph_config['G'],
             charging_stations=graph_config['charging_stations'],
-            capacity=graph_config['capacity'],
-            c=.5
+            capacity=graph_config['capacity']
         )
 
     @pytest.fixture
-- 
GitLab