From 7358366c20953965a0aea12ead59bb6cd2836455 Mon Sep 17 00:00:00 2001
From: "niehues.mark@gmail.com" <niehues.mark@gmail.com>
Date: Thu, 19 Mar 2020 15:30:19 +0100
Subject: [PATCH] keeping the invariant

---
 evrouting/charge/T.py               | 36 +++++++++++++++++++----------
 evrouting/charge/routing.py         | 21 ++++++++++-------
 evrouting/charge/utils.py           | 25 +++++++++++++++++++-
 tests/charge/test_charge_routing.py |  8 +++++++
 4 files changed, 69 insertions(+), 21 deletions(-)

diff --git a/evrouting/charge/T.py b/evrouting/charge/T.py
index 7ba45fb..0fffe01 100644
--- a/evrouting/charge/T.py
+++ b/evrouting/charge/T.py
@@ -138,18 +138,6 @@ class ChargingFunction:
         """Comparison for dominance check."""
         return self.c < other.c
 
-    def __le__(self, other) -> bool:
-        """Comparison for dominance check."""
-        return self.c <= other.c
-
-    def __eq__(self, other) -> bool:
-        """Comparison for dominance check."""
-        return self.c == other.c
-
-    def __ge__(self, other):
-        """Comparison for dominance check."""
-        return self.c >= other.c
-
     def __gt__(self, other):
         """Comparison for dominance check."""
         return self.c > other.c
@@ -245,6 +233,30 @@ class SoCFunction:
             self.cf_cs(t - self.t_trip, self.soc_last_cs)
         )
 
+    def __lt__(self, other: 'SoCFunction') -> bool:
+        """Comparison for dominance check."""
+        for t_i, soc_i in self.breakpoints:
+            if other(t_i) < soc_i:
+                return False
+
+        for t_i, soc_i in other.breakpoints:
+            if soc_i < self(t_i):
+                return False
+
+        return True
+
+    def __gt__(self, other: 'SoCFunction') -> bool:
+        """Comparison for dominance check."""
+        for t_i, soc_i in self.breakpoints:
+            if other(t_i) > soc_i:
+                return False
+
+        for t_i, soc_i in other.breakpoints:
+            if soc_i > self(t_i):
+                return False
+
+        return True
+
     @property
     def minimum(self) -> Time:
         """
diff --git a/evrouting/charge/routing.py b/evrouting/charge/routing.py
index 7a63bbf..7c10419 100644
--- a/evrouting/charge/routing.py
+++ b/evrouting/charge/routing.py
@@ -1,4 +1,4 @@
-from typing import Dict, List
+from typing import Dict, List, Set
 from math import inf
 
 import networkx as nx
@@ -39,8 +39,8 @@ def shortest_path(G: nx.Graph, charging_stations: set, s: Node, t: Node,
     label_factory = LabelsFactory(G, capacity, f_soc, initial_soc)
 
     # Init maps to manage labels
-    l_set: Dict[int, set] = {v: set() for v in G}
-    l_uns: Dict[int, LabelPriorityQueue] = {v: LabelPriorityQueue(cf) for v in G}
+    l_set: Dict[int, Set[Label]] = {v: set() for v in G}
+    l_uns: Dict[int, LabelPriorityQueue] = {v: LabelPriorityQueue(cf, l_set[v]) for v in G}
 
     # Init environment
     entry_label = _create_entry_label(G, charging_stations,
@@ -102,11 +102,16 @@ def shortest_path(G: nx.Graph, charging_stations: set, s: Node, t: Node,
                     #
                     # That means, the SoC and thereby the range is restricted
                     # to the SoC at the last cs (soc_last_cs).
-                    pass
-                else:
-                    if l_new == l_uns[n].peak_min():
-                        key, count = _key(l_new, f_soc)
-                        prio_queue.insert(n, priority=key, count=count)
+                    continue
+
+                try:
+                    is_new_min_label: bool = l_new == l_uns[n].peak_min()
+                except KeyError:
+                    continue
+
+                if is_new_min_label:
+                    key, count = _key(l_new, f_soc)
+                    prio_queue.insert(n, priority=key, count=count)
 
 
 def _calc_optimal_t_charge(cf: ChargingFunctionMap, label_v: Label, v: Node, capacity: SoC) -> List[Time]:
diff --git a/evrouting/charge/utils.py b/evrouting/charge/utils.py
index 376dd83..ba79f57 100644
--- a/evrouting/charge/utils.py
+++ b/evrouting/charge/utils.py
@@ -1,3 +1,4 @@
+from typing import Set, Any
 from math import inf
 
 from evrouting.utils import PriorityQueue
@@ -8,9 +9,10 @@ from .factories import ChargingFunctionMap
 
 
 class LabelPriorityQueue(PriorityQueue):
-    def __init__(self, cf: ChargingFunctionMap):
+    def __init__(self, cf: ChargingFunctionMap, l_set: Set[Label]):
         super().__init__()
         self.cf: ChargingFunctionMap = cf
+        self.l_set: Set[Label] = l_set
 
     def insert(self, label: Label):
         """Breaking ties with lowest soc at t_min."""
@@ -32,3 +34,24 @@ class LabelPriorityQueue(PriorityQueue):
             priority=t_min,
             count=soc_min
         )
+
+        if self.peak_min() == label:
+            self.dominance_check()
+
+    def delete_min(self) -> Any:
+        min_label = super().delete_min()
+        self.dominance_check()
+        return min_label
+
+    def dominance_check(self):
+        try:
+            label: Label = self.peak_min()
+        except KeyError:
+            return
+
+        soc = SoCFunction(label, self.cf[label.last_cs])
+
+        for other_label in self.l_set:
+            if SoCFunction(other_label, self.cf[other_label.last_cs]) > soc:
+                self.remove_item(label)
+                return
diff --git a/tests/charge/test_charge_routing.py b/tests/charge/test_charge_routing.py
index 1f3a8f3..b01a7fc 100644
--- a/tests/charge/test_charge_routing.py
+++ b/tests/charge/test_charge_routing.py
@@ -49,6 +49,14 @@ class TestWithFinalSoC:
 
         assert path == 5
 
+    def test_path_impossilbe(self):
+        """Not possible to end with full battery."""
+        conf = init_config(edge_case)
+        conf['final_soc'] = 4
+        path = shortest_path(**conf)
+
+        assert path is None
+
     def test_shortest_path_charge_at_s_only(self):
         """Charging at s and a to reach final_soc."""
         conf = init_config(edge_case_a_slow)
-- 
GitLab