From 067c10fd78674b9fdd117a6073b7feccdd813090 Mon Sep 17 00:00:00 2001
From: "niehues.mark@gmail.com" <niehues.mark@gmail.com>
Date: Mon, 27 Apr 2020 19:25:32 +0200
Subject: [PATCH] wip

---
 evrouting/osm/imports.py     | 55 +++++++++++++++++++++++-------------
 tests/osm/test_osm_charge.py |  2 +-
 2 files changed, 37 insertions(+), 20 deletions(-)

diff --git a/evrouting/osm/imports.py b/evrouting/osm/imports.py
index 16c598d..6eb3e3d 100644
--- a/evrouting/osm/imports.py
+++ b/evrouting/osm/imports.py
@@ -16,40 +16,52 @@ import copy
 import xml.sax
 import logging
 import itertools
-from collections import namedtuple
 
 import networkx as nx
 import rtree
 
-from evrouting.graph_tools import CHARGING_COEFFICIENT_KEY, CONSUMPTION_KEY, DISTANCE_KEY
+from evrouting.graph_tools import CHARGING_COEFFICIENT_KEY, DISTANCE_KEY
 from evrouting.osm.const import ms_to_kmh
 from evrouting.osm.profiles import speed
 from evrouting.osm.routing import point, haversine_distance
 
 logger = logging.getLogger(__name__)
 
-OsrmConf = namedtuple('OsrmConf', ['server', 'port', 'version', 'profile'],
-                      defaults=('v1', 'driving'))
 
+class OSMGraph(nx.DiGraph):
+    """
+    Adding some functionality to the graph for convenience when
+    working with actual geo data from osm, such as a spatial index and
+    a method to get an entry node by the spacial index for some coordinates.
 
-def query_url(service, coordinates, osrm_config: OsrmConf):
-    """Construct query url."""
-    return f'http://{osrm_config.server}:{osrm_config.port}' \
-           f'/{service}/{osrm_config.version}/{osrm_config.profile}/' \
-           f'{";".join([f"{lon},{lat}" for lat, lon in coordinates])}'
-
+    """
 
-class OSMGraph(nx.DiGraph):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.charging_stations = set()
+
+        # Data structures for spatial index.
+        # Rtree only supports indexing by integers, but the
+        # node ids are strings. Therefore a mapping is introduced.
         self._rtree = rtree.index.Index()
         self._int_index = itertools.count()
+        self._int_index_map = {}
 
-    def update_rtree(self, node):
+    def insert_into_rtree(self, node):
+        """Insert node into rtree."""
         info = self.nodes[node]
         lat, lon = info['lat'], info['lon']
-        self._rtree.insert(next(self._int_index), (lon, lat, lon, lat), obj=node)
+        index = next(self._int_index)
+        self._rtree.insert(index, (lon, lat, lon, lat))
+        self._int_index_map[index] = node
+
+    def rebuild_rtree(self):
+        """Rebuild tree. Necessary because it cannot be pickled."""
+        self._rtree = rtree.index.Index()
+        self._int_index = itertools.count()
+        self._int_index_map = {}
+        for n in self.nodes:
+            self.insert_into_rtree(n)
 
     def insert_charging_stations(self, charging_stations):
         """Insert Charging Stations"""
@@ -71,9 +83,10 @@ class OSMGraph(nx.DiGraph):
         """
         lat_v, lon_v = v
 
-        n = next(self._rtree.nearest(
-            (lon_v, lat_v, lon_v, lat_v), 1, objects=True
-        )).object
+        index_n = next(self._rtree.nearest(
+            (lon_v, lat_v, lon_v, lat_v), 1
+        ))
+        n = self._int_index_map[index_n]
 
         if distance_limit is not None:
             d = haversine_distance(
@@ -89,7 +102,7 @@ class OSMGraph(nx.DiGraph):
         return n
 
 
-def read_osm(osm_xml_data, profile) -> nx.DiGraph:
+def read_osm(osm_xml_data, profile) -> OSMGraph:
     """
     Read graph in OSM format from file specified by name or by stream object.
     Create Graph containing all highways as edges.
@@ -121,7 +134,7 @@ def read_osm(osm_xml_data, profile) -> nx.DiGraph:
                 })
             else:
                 # BOTH DIRECTION
-                G.add_edge(u_id, v_id,**{
+                G.add_edge(u_id, v_id, **{
                     DISTANCE_KEY: d
                 })
                 G.add_edge(v_id, u_id, **{
@@ -134,7 +147,7 @@ def read_osm(osm_xml_data, profile) -> nx.DiGraph:
         G.nodes[n_id]['lat'] = n.lat
         G.nodes[n_id]['lon'] = n.lon
         G.nodes[n_id]['id'] = n.id
-        G.update_rtree(n_id)
+        G.insert_into_rtree(n_id)
 
     return G
 
@@ -199,6 +212,10 @@ class OSM(object):
         ways = {}
 
         class OSMHandler(xml.sax.ContentHandler):
+            def __init__(self):
+                super().__init__()
+                self.currElem = None
+
             @classmethod
             def setDocumentLocator(self, loc):
                 pass
diff --git a/tests/osm/test_osm_charge.py b/tests/osm/test_osm_charge.py
index becf5d5..94be86a 100644
--- a/tests/osm/test_osm_charge.py
+++ b/tests/osm/test_osm_charge.py
@@ -21,7 +21,7 @@ def graph():
         lat, lon = coordinates
         # Add two nodes, that exist in osm test map
         G.add_node(n_id, lat=lat, lon=lon)
-        G.update_rtree(n_id)
+        G.insert_into_rtree(n_id)
 
     yield G
     del G
-- 
GitLab