import argparse
import json
import random
import pickle
import logging
from time import perf_counter
from pathlib import Path

import yaml
from evrouting.osm.imports import read_osm
from lib.T import *
from lib.export import write_head, write_row
from lib.queries import (
    gasstation_query, charge_query, classic_query, astar_query, CACHE
)

base = Path(__file__).parent  # script directory: static/, configs/ and results/ paths are resolved relative to it


def query_benchmark(graphs, conf, result_dir):
    """Run all configured routing-query benchmarks and append results to CSVs.

    :param graphs: list of map graphs, parallel to ``conf['maps']``
    :param conf: benchmark configuration dict (needs ``charging_stations``,
        ``maps``, ``queries_per_setup`` and ``setups`` keys)
    :param result_dir: directory the per-algorithm CSV files are written to
    """
    # Charging Stations
    cs_path = base.joinpath('static').joinpath(conf['charging_stations'])
    with cs_path.open() as f:
        charging_stations = json.load(f)

    query_conf = [
        Query(query_function=gasstation_query,
              filename='gasstation.csv',
              row_dataclass=GasstationQueryRow),
        Query(charge_query, 'charge.csv', ChargeQueryRow),
        Query(classic_query, 'classic.csv', ClassicQueryRow),
        Query(astar_query, 'astar.csv', AStarQueryRow)
    ]

    # Remove existing results so every run starts from a clean slate.
    for _, filename, _ in query_conf:
        try:
            result_dir.joinpath(filename).unlink()
        except FileNotFoundError:
            pass

    for map_name, G in zip(conf['maps'], graphs):
        # Sample once per map; the same start/target split is reused for every
        # setup so results stay comparable across setups.
        nodes = random.sample(list(G.nodes), k=2 * conf['queries_per_setup'])
        half = len(nodes) // 2
        start_nodes = nodes[:half]
        target_nodes = nodes[half:]

        for setup in conf['setups']:
            insert_charging_stations(G, setup['charging_stations'], charging_stations)

            for func, filename, row_class in query_conf:
                logging.info('Running {} queries with {} on map {}'.format(
                    len(start_nodes),
                    func.__name__,
                    map_name
                ))
                out_path = result_dir.joinpath(filename)
                # Bug fix: the header used to be re-written on every
                # (map, setup) iteration, interleaving header rows into the
                # appended CSV. Write it only while the file is still empty.
                needs_head = not out_path.exists() or out_path.stat().st_size == 0
                with out_path.open('a') as f:
                    if needs_head:
                        write_head(f, row_class)
                    for s, t in zip(start_nodes, target_nodes):
                        write_row(f, func(G, setup, s, t))

                # Drop cached graphs between algorithms to bound memory and
                # avoid cross-algorithm cache effects.
                CACHE.clear()


def insert_charging_stations(graph, number, charging_stations):
    """Insert `number` charging stations into `graph` and log the runtime.

    Delegates to ``graph.insert_charging_stations``; afterwards reports how
    many stations the graph actually holds and how long the insert took.
    """
    started = perf_counter()
    graph.insert_charging_stations(charging_stations, number)
    elapsed = perf_counter() - started

    message = 'Importing {} Charging Stations took {:.2f} s'
    logging.info(message.format(len(graph.charging_stations), elapsed))


def get_map(osm_path: Path, backup_dir=None):
    """Load an OSM map as a graph, using a pickle cache when available.

    :param osm_path: path to the ``.osm`` file to import
    :param backup_dir: directory holding pickled graph caches. Bug fix: the
        declared default ``None`` used to crash with ``AttributeError`` on
        ``backup_dir.joinpath``; it now means "skip caching entirely".
    :return: the imported graph
    """
    logging.info('Importing map {}'.format(osm_path.name))

    cache_path = None
    if backup_dir is not None:
        cache_path = backup_dir.joinpath(osm_path.with_suffix('.pck').name)
        try:
            # NOTE(review): pickle.load is only safe because these caches are
            # produced locally by this script — never point backup_dir at
            # untrusted data.
            with open(cache_path, 'rb') as f:
                graph = pickle.load(f)
            logging.info('Loaded map from cache {}'.format(cache_path))
            # The rtree is not pickled, so rebuild it after a cache hit.
            graph.rebuild_rtree()
            return graph
        except FileNotFoundError:
            pass  # cache miss: fall through to a fresh import

    start = perf_counter()
    graph = read_osm(str(osm_path))
    runtime = perf_counter() - start
    logging.info('Importing map took {:.2f} s'.format(
        runtime
    ))

    if cache_path is not None:
        with open(cache_path, 'wb') as f:
            pickle.dump(graph, f)

    return graph


def apply_conversions(conf):
    """Convert every setup's energy values from kWh to Wh, in place.

    Scales capacity, the consumption coefficient, and the mu_s/mu_t charge
    bounds by 1000 and returns the (mutated) config for chaining.
    """
    for setup in conf['setups']:
        for key in ('capacity', 'mu_s', 'mu_t'):
            setup[key] *= 1000
        setup['consumption']['consumption_coefficient'] *= 1000
    return conf


if __name__ == '__main__':
    logging.basicConfig(
        format='%(asctime)s %(message)s',
        datefmt='%m/%d/%Y %I:%M:%S %p',
        level=logging.DEBUG)

    results_dir = base.joinpath('results')
    static_dir = base.joinpath('static')
    maps_dir = static_dir.joinpath('maps')
    maps_cache_dir = static_dir.joinpath('mapcache')
    # get_map pickles into this directory; create it up front so a fresh
    # checkout does not crash on the first cache write.
    maps_cache_dir.mkdir(parents=True, exist_ok=True)

    parser = argparse.ArgumentParser(description='Run Benchmark Scripts.')
    parser.add_argument(
        '--configs',
        help='List of filenames to benchmark YAML configs in ./configs.',
        type=Path,
        nargs='+',
        # Bug fix: without required=True, omitting --configs left
        # args.configs as None and the loop below raised TypeError.
        required=True
    )

    args = parser.parse_args()
    path: Path
    for path in args.configs:
        benchmark_dir = results_dir.joinpath(path.with_suffix(''))
        # parents=True: 'results' itself may not exist on a fresh checkout.
        benchmark_dir.mkdir(parents=True, exist_ok=True)

        path = path.with_suffix('.yaml')

        with base.joinpath('configs/', path).open() as f:
            # NOTE(review): yaml.Loader can construct arbitrary Python objects;
            # acceptable only because configs are local and trusted. Switch to
            # yaml.safe_load if configs could ever come from elsewhere.
            conf = apply_conversions(yaml.load(f, Loader=yaml.Loader))

        graphs = [
            get_map(maps_dir.joinpath(m), maps_cache_dir)
            for m in conf['maps']
        ]

        if conf['type'] == 'query':
            query_dir = benchmark_dir.joinpath('queries')
            query_dir.mkdir(exist_ok=True)
            query_benchmark(graphs=graphs,
                            conf=conf,
                            result_dir=query_dir)