"""
Convert an OpenStreetMap `.map` format file into a networkx directed graph.

This parser is based on the osm to networkx tool
from Loïc Messal (github : Tofull)

Added :
- python3.6 compatibility
- networkx v2 compatibility
- cache to avoid downloading the same osm tiles again and again
- distance computation to estimate the length of each way (useful to compute the shortest path)
"""

import copy
import xml.sax
import logging
import itertools

import networkx as nx
import rtree

from evrouting.graph_tools import CHARGING_COEFFICIENT_KEY, DISTANCE_KEY
from evrouting.osm.const import ms_to_kmh
from evrouting.osm.profiles import speed
from evrouting.osm.routing import point, haversine_distance

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Edge attribute key under which read_osm() stores the haversine length
# (in metres) of an edge.
HAVERSINE_KEY = 'haversine'

class OSMGraph(nx.DiGraph):
    """
    Directed graph with some conveniences for working with OSM geo data.

    On top of ``nx.DiGraph`` it keeps a spatial index (R-tree) over nodes
    and offers a lookup of the node nearest to some coordinates, e.g. to
    find an entry node. Nodes that are put into the index must carry
    ``'lat'`` and ``'lon'`` attributes.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Nodes that got a charging station assigned by
        # insert_charging_stations().
        self.charging_stations = set()

        # Data structures for spatial index.
        # Rtree only supports indexing by integers, but the
        # node ids are strings. Therefore a mapping is introduced.
        self._rtree = rtree.index.Index()
        self._int_index = itertools.count()
        self._int_index_map = {}

    def insert_into_rtree(self, node):
        """Insert node into rtree, using its 'lat'/'lon' attributes."""
        info = self.nodes[node]
        lat, lon = info['lat'], info['lon']
        index = next(self._int_index)
        # A point is stored as a degenerate bounding box (x = lon, y = lat).
        self._rtree.insert(index, (lon, lat, lon, lat))
        self._int_index_map[index] = node

    def rebuild_rtree(self):
        """Rebuild tree. Necessary because it cannot be pickled."""
        self._rtree = rtree.index.Index()
        self._int_index = itertools.count()
        self._int_index_map = {}
        for n in self.nodes:
            self.insert_into_rtree(n)

    def insert_charging_stations(self, charging_stations):
        """
        Attach charging stations to their nearest graph nodes.

        :param charging_stations: Iterable of mappings with the keys
            ``'lon'``, ``'lat'`` and ``'power'``.

        Each station is assigned to the nearest indexed node within a
        distance limit of 500; stations without such a node are skipped.
        The station's ``'power'`` is stored on the node under
        ``CHARGING_COEFFICIENT_KEY`` and ``self.charging_stations`` is
        replaced by the set of nodes that received a station.
        """
        S = set()
        for s in charging_stations:
            n = self.find_nearest((s['lat'], s['lon']), distance_limit=500)
            if n is not None:
                self.nodes[n][CHARGING_COEFFICIENT_KEY] = s['power']
                S.add(n)

        self.charging_stations = S

    def find_nearest(self, v: point, distance_limit=None):
        """
        Find nearest point to location v within radius
        of distance_limit.

        :param v: Location as ``(lat, lon)``.
        :param distance_limit: Maximum allowed haversine distance between
            ``v`` and the found node; ``None`` disables the check.
        :return: Id of the nearest indexed node, or ``None`` if the index
            is empty or the nearest node is farther away than
            ``distance_limit``.
        """
        lat_v, lon_v = v

        # Default of None avoids leaking StopIteration on an empty index.
        index_n = next(
            iter(self._rtree.nearest((lon_v, lat_v, lon_v, lat_v), 1)),
            None
        )
        if index_n is None:
            return None
        n = self._int_index_map[index_n]

        if distance_limit is not None:
            info = self.nodes[n]
            # Same argument order (lon, lat, lon, lat) as used by read_osm().
            d = haversine_distance(
                info['lon'], info['lat'], lon_v, lat_v, unit_m=True
            )
            if d > distance_limit:
                n = None

        return n

def read_osm(osm_xml_data, profile) -> OSMGraph:
    """
    Read graph in OSM format from file specified by name or by stream object.
    Create Graph containing all highways as edges.

    :param osm_xml_data: OSM XML source, handed to ``OSM``.
    :param profile: Routing profile. ``profile['highway_whitelist']``
        selects the highway types that become edges; the profile is also
        passed to ``speed`` to determine the speed on a way.
    :return: Graph whose edges carry the travel time in seconds under
        ``DISTANCE_KEY`` and the haversine length in metres under
        ``HAVERSINE_KEY``. Every node carries ``'lat'``, ``'lon'`` and
        ``'id'``.
    """
    osm = OSM(osm_xml_data)
    G = OSMGraph()

    # Add ways
    for w in osm.ways.values():
        if 'highway' not in w.tags:
            continue
        if w.tags['highway'] not in profile['highway_whitelist']:
            continue

        is_oneway = w.tags.get('oneway', 'no') == 'yes'

        for u_id, v_id in zip(w.nds[:-1], w.nds[1:]):
            u, v = osm.nodes[u_id], osm.nodes[v_id]

            # Travel-time from u to v
            d = haversine_distance(u.lon, u.lat, v.lon, v.lat, unit_m=True)  # in m
            t = d / (speed(w, profile) / ms_to_kmh)  # in s

            attributes = {DISTANCE_KEY: t, HAVERSINE_KEY: d}

            G.add_edge(u_id, v_id, **attributes)
            if not is_oneway:
                # Way may be travelled in both directions.
                G.add_edge(v_id, u_id, **attributes)

    # Complete the used nodes' information
    for n_id, data in G.nodes(data=True):
        n = osm.nodes[n_id]
        data['lat'] = n.lat
        data['lon'] = n.lon
        data['id'] = n.id

    return G

class Node:
    """A single OSM node: an id, a position (lon/lat) and its tags."""

    def __init__(self, id, lon, lat):
        self.id, self.lon, self.lat = id, lon, lat
        # Tag key -> tag value; starts out empty for every node.
        self.tags = dict()

    def __str__(self):
        """Return a short human-readable description of the node."""
        fields = (self.id, self.lon, self.lat)
        return "Node (id : %s) lon : %s, lat : %s " % fields

class Way(object):
    """An OSM way: an ordered list of node ids (``nds``) plus its tags."""

    def __init__(self, id):
        self.id = id
        self.nds = []
        self.tags = {}

    def split(self, node_pass_count):
        """
        Slice way at every crossing, i.e. when a waypoint is passed by
        multiple ways.

        :param node_pass_count: Mapping from node id to the number of times
            the node is used by ways. Inner waypoints with a count greater
            than 1 are treated as crossings; the first and last waypoint
            never split the way.
        :return: List of new ``Way`` objects, one per slice, with ids
            ``"<id>-<n>"``. Adjacent slices share the crossing node. The
            slices are shallow copies of this way, so they share its
            ``tags`` dict.
        """
        waypoints = self.nds

        # Cut iteratively instead of recursively so that ways with a lot
        # of crossings cannot exceed the recursion limit.
        slices = []
        start = 0
        for i in range(1, len(waypoints) - 1):
            if node_pass_count[waypoints[i]] > 1:
                slices.append(waypoints[start:i + 1])
                start = i
        slices.append(waypoints[start:])

        # create a way object for each node-array slice
        ret = []
        for i, nodes in enumerate(slices):
            littleway = copy.copy(self)
            littleway.id += "-%d" % i
            littleway.nds = nodes
            ret.append(littleway)

        return ret

class OSM(object):
    """
    Parsed OSM data.

    After construction ``nodes`` maps node ids to ``Node`` objects and
    ``ways`` maps way ids to ``Way`` objects that have been split at every
    crossing.
    """

    def __init__(self, osm_xml_data):
        """
        Parse OSM XML and split all ways at their crossings.

        :param osm_xml_data: Either a filename or a stream/file object;
            it is handed to ``xml.sax.parse`` as is.
        """
        nodes = {}
        ways = {}

        class OSMHandler(xml.sax.ContentHandler):
            """SAX handler collecting nodes and ways into the enclosing dicts."""

            def __init__(self):
                super().__init__()
                # Node or Way currently being parsed, None outside of one.
                self.currElem = None

            def startElement(self, name, attrs):
                if name == 'node':
                    self.currElem = Node(attrs['id'], float(attrs['lon']), float(attrs['lat']))
                elif name == 'way':
                    self.currElem = Way(attrs['id'])
                elif self.currElem is None:
                    # Child of an element that is not handled here (e.g. a
                    # relation); must not leak into a previous node or way.
                    return
                elif name == 'tag':
                    self.currElem.tags[attrs['k']] = attrs['v']
                elif name == 'nd':
                    self.currElem.nds.append(attrs['ref'])

            def endElement(self, name):
                if name == 'node':
                    nodes[self.currElem.id] = self.currElem
                    self.currElem = None
                elif name == 'way':
                    ways[self.currElem.id] = self.currElem
                    self.currElem = None

        # xml.sax.parse needs a handler instance, not the class itself.
        xml.sax.parse(osm_xml_data, OSMHandler())

        self.nodes = nodes
        # A way needs at least two nodes to form an edge; drop all others.
        # Filtering into a new dict avoids deleting from a dict while
        # iterating over it.
        self.ways = {
            way_id: way for way_id, way in ways.items() if len(way.nds) >= 2
        }

        # count times each node is used
        node_histogram = dict.fromkeys(self.nodes, 0)
        for way in self.ways.values():
            for node in way.nds:
                node_histogram[node] += 1

        # use that histogram to split all ways, replacing the member set of ways
        new_ways = {}
        for way in self.ways.values():
            for split_way in way.split(node_histogram):
                new_ways[split_way.id] = split_way
        self.ways = new_ways