"""
Convert an OpenStreetMap XML file (e.g. a `.osm` / `.map` export) into a networkx directed graph.

This parser is based on the osm-to-networkx tool
by Loïc Messal (GitHub: Tofull).

Added:
- Python 3.6 compatibility
- networkx v2 compatibility
- cache to avoid downloading the same OSM tiles again and again
- distance computation to estimate the length of each way (useful for computing shortest paths)

"""

import copy
import xml.sax
from math import radians, cos, sin, asin, sqrt
from collections import namedtuple

import networkx as nx
import aiohttp
import asyncio

from evrouting.graph_tools import DISTANCE_KEY

# Connection settings for an OSRM HTTP server. Only `server` and `port` are
# required; the two trailing fields fall back to version 'v1' and the
# 'driving' profile.
OsrmConf = namedtuple(
    typename='OsrmConf',
    field_names='server port version profile',
    defaults=('v1', 'driving'),
)


class CachedDistance:
    """Memoising wrapper around a pairwise distance function on a graph.

    Subclasses implement ``d(u, v)``. Indexing with ``instance[u, v]``
    computes the distance once per pair and serves later lookups from an
    in-memory dict. With ``symmetric=True`` the pair is sorted first, so
    ``[u, v]`` and ``[v, u]`` share a single cache entry; node labels must
    then be mutually orderable.
    """

    def __init__(self, graph, symmetric=True):
        self._cache = {}
        self.graph = graph
        self.symmetric = symmetric

    def d(self, u, v):
        """Return the distance from ``u`` to ``v``; subclasses must override."""
        # NotImplemented is a comparison sentinel, not an exception: raising it
        # produces a confusing TypeError instead of the intended error.
        raise NotImplementedError(f'{type(self).__name__} must implement d(u, v)')

    def __getitem__(self, item):
        """Return the (cached) distance for the pair ``item = (u, v)``."""
        if self.symmetric:
            item = sorted(item)
        u, v = item

        try:
            return self._cache[u, v]
        except KeyError:
            pass
        # Compute outside the except clause so errors raised by d() are not
        # chained onto the cache-miss KeyError.
        value = self.d(u, v)
        self._cache[u, v] = value
        return value


class AsyncCachedOSRMDistance(CachedDistance):
    """Cached OSRM route durations between nodes of ``graph``.

    Use as ``await dist[u, v]``. The first lookup of a pair queries the OSRM
    ``route`` service through ``session`` and caches the ``duration`` of the
    first returned route; later lookups of the same pair return the cached
    value. ``session`` is assumed to be an ``aiohttp.ClientSession`` (or
    compatible) — confirm against callers. Nodes of ``graph`` must carry
    ``lat`` and ``lon`` attributes (as ``read_osm`` sets them).
    """

    def __init__(self,
                 graph,
                 session,
                 symmetric=False,
                 osrm_config: OsrmConf = OsrmConf(server='0.0.0.0', port=5000)
                 ):
        super().__init__(graph, symmetric)
        self.session = session
        # URLs are built from this config by `query_url`.
        self.osrm_config = osrm_config

    def query_url(self, service, coordinates):
        """Return the OSRM URL for ``service`` over ``(lat, lon)`` pairs.

        OSRM expects ``lon,lat`` order in the URL, hence the swap.
        """
        conf = self.osrm_config
        path = ";".join(f"{lon},{lat}" for lat, lon in coordinates)
        return (f'http://{conf.server}:{conf.port}'
                f'/{service}/{conf.version}/{conf.profile}/{path}')

    async def d(self, u, v):
        """Query OSRM for the route from ``u`` to ``v`` and return its duration."""
        # Node attributes live in graph.nodes; graph[u] is the adjacency view.
        nodes = self.graph.nodes
        loc_u = (nodes[u]['lat'], nodes[u]['lon'])
        loc_v = (nodes[v]['lat'], nodes[v]['lon'])

        async with self.session.get(self.query_url('route', [loc_u, loc_v])) as resp:
            resp.raise_for_status()
            # The body must be read before the context releases the response.
            payload = await resp.json()
        return payload['routes'][0]['duration']

    async def __getitem__(self, item):
        """Awaitable cached lookup of the pair ``item = (u, v)``.

        Overrides the synchronous base version, which would cache the
        coroutine itself (a coroutine can only be awaited once).
        """
        if self.symmetric:
            item = sorted(item)
        u, v = item

        try:
            return self._cache[u, v]
        except KeyError:
            pass
        value = await self.d(u, v)
        self._cache[u, v] = value
        return value


def haversine_distance(lon1, lat1, lon2, lat2, unit_m=True):
    """
    Calculate the great circle distance between two points
    on the earth (specified in decimal degrees).

    Note the argument order: longitude before latitude for each point.
    The earth is modelled as a sphere of radius 6371 km.

    :param unit_m: if True (the default) the result is in metres,
        otherwise in kilometres.
    :return: the great-circle distance between the two points.
    """
    # convert decimal degrees to radians
    lon1, lat1, lon2, lat2 = map(radians, [lon1, lat1, lon2, lat2])

    # haversine formula
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
    # Rounding can push sqrt(a) marginally above 1 for (near-)antipodal
    # points, which would make asin raise "math domain error".
    c = 2 * asin(min(1.0, sqrt(a)))
    r = 6371  # Radius of the Earth in kilometers. Use 3956 for miles
    if unit_m:
        r *= 1000
    return c * r


def read_osm(osm_xml_data,
             osrm_config: OsrmConf = OsrmConf(server='localhost', port=5000)
             ) -> nx.DiGraph:
    """Read an OSM XML file into a directed graph weighted through OSRM.

    Parameters
    ----------
    osm_xml_data : str
        Path of the OSM XML file to parse (forwarded to ``OSM``).
    osrm_config : OsrmConf
        OSRM server used to weight the edges. When its ``profile`` is
        ``'driving'``, only ways tagged ``highway`` are kept.

    Returns
    -------
    G : nx.DiGraph
        Nodes are relabelled to ``(lon, lat)`` tuples and carry ``lat``,
        ``lon`` and ``id`` (OSM node id) attributes. There is one edge per
        consecutive node pair of each kept way, in both directions unless the
        way is one-way. Each edge carries the way ``id`` and the OSRM route
        duration stored under ``DISTANCE_KEY`` by ``augment_distances``.
    """
    only_roads = osrm_config.profile == 'driving'

    def query_url(service, coordinates):
        # `coordinates` are (lat, lon) pairs; OSRM wants "lon,lat" in the URL.
        return f'http://{osrm_config.server}:{osrm_config.port}' \
               f'/{service}/{osrm_config.version}/{osrm_config.profile}/' \
               f'{";".join([f"{lon},{lat}" for lat, lon in coordinates])}'

    osm = OSM(osm_xml_data)
    G = nx.DiGraph()

    ## Add ways
    for w in osm.ways.values():
        if only_roads and 'highway' not in w.tags:
            continue

        # OSM `oneway` semantics (also what OSRM's car profile reads):
        # yes/true/1 -> only along the node order, -1 -> only against it,
        # anything else (no, absent, reversible, ...) -> both directions.
        oneway = w.tags.get('oneway')
        if oneway in ('yes', 'true', '1'):
            nx.add_path(G, w.nds, id=w.id)
        elif oneway == '-1':
            nx.add_path(G, w.nds[::-1], id=w.id)
        else:
            nx.add_path(G, w.nds, id=w.id)
            nx.add_path(G, w.nds[::-1], id=w.id)

    # Complete the used nodes' information
    coordinates_map = {}
    for n_id in G.nodes():
        n = osm.nodes[n_id]
        G.nodes[n_id]['lat'] = n.lat
        G.nodes[n_id]['lon'] = n.lon
        G.nodes[n_id]['id'] = n.id
        coordinates_map[n_id] = (n.lon, n.lat)

    asyncio.run(augment_distances(G, query_url))
    # NOTE(review): OSM nodes with identical coordinates collapse into a
    # single graph node after this relabelling.
    G = nx.relabel_nodes(G, coordinates_map)
    return G


async def augment_distances(G, url_factory):
    """Weight every edge of ``G`` with an OSRM route result.

    For each edge ``(src, dst)`` one ``route`` request is issued, one at a
    time, to the URL built by ``url_factory(service, [(lat, lon), ...])``
    from the endpoints' ``lat``/``lon`` node attributes. The ``duration`` of
    the first returned route is stored on the edge under ``DISTANCE_KEY``.
    OSRM reports ``duration`` in seconds, so despite the key name this is a
    travel time. An HTTP error status raises via ``raise_for_status``.
    """
    async with aiohttp.ClientSession() as session:
        for src, dst, _attrs in G.edges(data=True):
            endpoints = [
                (G.nodes[src]['lat'], G.nodes[src]['lon']),
                (G.nodes[dst]['lat'], G.nodes[dst]['lon']),
            ]
            request_url = url_factory('route', endpoints)
            async with session.get(request_url) as response:
                response.raise_for_status()
                payload = await response.json()
                travel_time = payload['routes'][0]['duration']
                G.add_weighted_edges_from([(src, dst, travel_time)], weight=DISTANCE_KEY)


class Node:
    """A single OSM node: its id, coordinates and tag dictionary."""

    def __init__(self, id, lon, lat):
        self.id, self.lon, self.lat = id, lon, lat
        # key -> value tags filled in by the parser
        self.tags = {}

    def __str__(self):
        return f"Node (id : {self.id}) lon : {self.lon}, lat : {self.lat} "


class Way(object):
    """An OSM way: an ordered list of node ids (``nds``) plus its tags."""

    def __init__(self, id, osm):
        self.osm = osm  # owning OSM collection
        self.id = id
        self.nds = []
        self.tags = {}

    def split(self, dividers):
        """Split this way at every interior node referenced more than once.

        :param dividers: mapping node id -> number of references to that node
            (as built by ``OSM``). It must contain every interior node id of
            ``self.nds``.
        :return: list of shallow copies of this way, one per segment, with ids
            ``"<id>-0"``, ``"<id>-1"``, ... Consecutive segments share their
            boundary node. A way without split points yields one segment.
        """
        # Iterative on purpose: a recursive split recurses once per split
        # point and can exceed Python's recursion limit on long ways.
        segments = []
        start = 0
        for i in range(1, len(self.nds) - 1):
            if dividers[self.nds[i]] > 1:
                segments.append(self.nds[start:i + 1])
                start = i
        segments.append(self.nds[start:])

        # create a way object for each node-array segment
        ret = []
        for i, segment in enumerate(segments):
            littleway = copy.copy(self)  # shallow: `tags` dict is shared with self
            littleway.id += "-%d" % i
            littleway.nds = segment
            ret.append(littleway)
        return ret


class OSM(object):
    """Nodes and ways parsed from an OSM XML document.

    Attributes:
        nodes: dict mapping node id (str) -> ``Node``.
        ways: dict mapping way id -> ``Way``. Ways are already split at shared
            nodes, so their ids look like ``"<osm way id>-<n>"``.
    """

    def __init__(self, osm_xml_data):
        """Parse ``osm_xml_data`` and build ``self.nodes`` and ``self.ways``.

        ``osm_xml_data`` can be a filename/path or an open file object
        (stream). Paths are opened in binary mode so the XML parser honours
        the document's own encoding declaration.

        Ways with fewer than two nodes are discarded. The rest are split at
        every interior node referenced more than once across all ways (see
        ``Way.split``).
        """
        nodes = {}
        ways = {}

        superself = self

        class OSMHandler(xml.sax.ContentHandler):
            """SAX handler collecting <node> and <way> elements with their tags."""

            def __init__(self):
                super().__init__()
                # Element being built; None outside nodes/ways (e.g. in a
                # <relation>) so stray <tag>/<nd> children are ignored.
                self.currElem = None

            def startElement(self, name, attrs):
                if name == 'node':
                    self.currElem = Node(attrs['id'], float(attrs['lon']), float(attrs['lat']))
                elif name == 'way':
                    self.currElem = Way(attrs['id'], superself)
                elif name == 'relation':
                    # Relations are not modelled; without this reset their
                    # tags would be attached to the last parsed way/node.
                    self.currElem = None
                elif name == 'tag':
                    if self.currElem is not None:
                        self.currElem.tags[attrs['k']] = attrs['v']
                elif name == 'nd':
                    if isinstance(self.currElem, Way):
                        self.currElem.nds.append(attrs['ref'])

            def endElement(self, name):
                if name == 'node':
                    nodes[self.currElem.id] = self.currElem
                elif name == 'way':
                    ways[self.currElem.id] = self.currElem
                if name in ('node', 'way', 'relation'):
                    self.currElem = None

        handler = OSMHandler()
        if hasattr(osm_xml_data, 'read'):
            xml.sax.parse(osm_xml_data, handler)
        else:
            with open(osm_xml_data, mode='rb') as f:
                xml.sax.parse(f, handler)

        self.nodes = nodes
        # Drop degenerate ways (fewer than two nodes: no edge possible). A new
        # dict is built because deleting while iterating raises RuntimeError.
        self.ways = {way_id: way for way_id, way in ways.items() if len(way.nds) >= 2}

        # count times each node is used; tolerate refs to nodes absent from the file
        node_histogram = dict.fromkeys(self.nodes, 0)
        for way in self.ways.values():
            for node in way.nds:
                node_histogram[node] = node_histogram.get(node, 0) + 1

        # use that histogram to split all ways, replacing the member set of ways
        new_ways = {}
        for way in self.ways.values():
            for split_way in way.split(node_histogram):
                new_ways[split_way.id] = split_way
        self.ways = new_ways