diff --git a/evrouting/gasstation/routing.py b/evrouting/gasstation/routing.py index ca26c2e19f67e6b87646dd8c151052d6344be196..b4038515e62d7128ef16f3d9a267d205975723b5 100644 --- a/evrouting/gasstation/routing.py +++ b/evrouting/gasstation/routing.py @@ -20,10 +20,16 @@ def insert_start_node(s: Node, graph_extended: nx.DiGraph, capacity: SoC, initial_soc: SoC, - f: AccessFunctions = AccessFunctions() + f: AccessFunctions = AccessFunctions(), + added_nodes=None ) -> nx.DiGraph: """Insert s into extended graph an create states and edges as necessary.""" + if added_nodes is None: + added_nodes = [] + + added_nodes.append((s, initial_soc)) graph_extended.add_node((s, initial_soc)) + v: Node for v in gas_stations: try: @@ -38,7 +44,8 @@ def insert_start_node(s: Node, d = f.path_distance(graph_core, shortest_p) c_v = f.charging_coefficient(graph_core, v) g = initial_soc - w - + if (v, g) not in graph_extended.nodes: + added_nodes.append((v, g)) graph_extended.add_edge((s, initial_soc), (v, g), weight=d) for u in graph_contracted.neighbors(v): c_u = f.charging_coefficient(graph_contracted, u) @@ -66,10 +73,16 @@ def insert_final_node(t: Node, graph_extended: nx.DiGraph, capacity: SoC, final_soc: SoC, - f: AccessFunctions = AccessFunctions() + f: AccessFunctions = AccessFunctions(), + added_nodes=None ) -> nx.DiGraph: """Insert terminal node into extended graph an create states and edges as necessary.""" + if added_nodes is None: + added_nodes = [] + graph_extended.add_node((t, final_soc)) + added_nodes.append((t, final_soc)) + u: Node for u in gas_stations: try: @@ -242,6 +255,8 @@ def shortest_path(G: nx.Graph, charge_path=[(n, 0) for n in _path] ) + added_nodes = [] + contracted_graph: nx.Graph = contracted_graph or contract_graph(G, charging_stations, capacity, f) extended_graph = extended_graph or state_graph(contracted_graph, capacity, f) @@ -253,7 +268,8 @@ def shortest_path(G: nx.Graph, graph_extended=extended_graph, capacity=capacity, initial_soc=initial_soc, - f=f + f=f, + added_nodes=added_nodes ) extended_graph = insert_final_node( @@ -263,17 +279,21 @@ def shortest_path(G: nx.Graph, graph_extended=extended_graph, capacity=capacity, final_soc=final_soc, - f=f + f=f, + added_nodes=added_nodes ) try: path: List[State] = nx.shortest_path(extended_graph, (s, initial_soc), (t, final_soc)) except nx.NetworkXNoPath: - return EmptyResult() + res = EmptyResult() + else: + res = compose_result( + graph_core=G, + extended_graph=extended_graph, + path=path, + f=f + ) - return compose_result( - graph_core=G, - extended_graph=extended_graph, - path=path, - f=f - ) + extended_graph.remove_nodes_from(added_nodes) + return res diff --git a/tests/gasstation/test_transformations.py b/tests/gasstation/test_transformations.py index 8091f9ae82ee8b4fa9ba8a5a7bdddd09ecf6b325..945e82a91967bed933409b69cf1fe8101c970a87 100644 --- a/tests/gasstation/test_transformations.py +++ b/tests/gasstation/test_transformations.py @@ -7,7 +7,8 @@ from evrouting.gasstation.routing import ( state_graph, insert_final_node, insert_start_node, - compose_result + compose_result, + shortest_path ) from evrouting.graph_tools import ( label, @@ -325,6 +326,21 @@ class TestIntegration(Integration): assert len(inserted_t.edges) == 11 assert inserted_t.edges[u, v]['weight'] == weight + def test_shortest_path_reset_state_graph(self, graph, graph_config, extended_graph, contracted_graph): + before = len(extended_graph.nodes) + shortest_path(G=graph, + charging_stations=graph_config['charging_stations'], + s=graph_config['s'], + t=graph_config['t'], + initial_soc=graph_config['initial_soc'], + final_soc=graph_config['final_soc'], + capacity=graph_config['capacity'], + extended_graph=extended_graph, + contracted_graph=contracted_graph + ) + + assert before == len(extended_graph.nodes) + class TestResult(Integration): def test_compose_result(self, graph, inserted_t):