import logging
import random
from typing import TextIO, Union, Type
from time import perf_counter
from pathlib import Path
from dataclasses import asdict, fields

import networkx as nx
from networkx.algorithms.shortest_paths.weighted import _weight_function
from evrouting.graph_tools import DISTANCE_KEY
from evaluation.lib.algorithm import _dijkstra_multisource
from evaluation.lib import queries
from evaluation.lib.config import RankConf, QueryConf, InitConf, AnyConf

logger = logging.getLogger(__name__)

# Maps algorithm names (as they appear in configuration files) to the
# query functions that execute them.
conf_algorithms = {
    'classic': queries.classic_query,
    'astar': queries.astar_query,
    'bidirectional': queries.bidirectional_query,
    'charge': queries.charge_query,
    'gasstation': queries.gasstation_query
}

# Union of all result-row dataclasses that the query functions can return;
# every row type is serialised to CSV via write_head()/write_row().
QueryRow = Union[
    queries.InsertQueryRow,
    queries.InitQueryRow,
    queries.ClassicQueryRow,
    queries.ChargeQueryRow,
    queries.GasstationQueryRow,
    queries.AStarQueryRow,
    queries.QueryRow
]

# Column separator used in all result CSV files.
SEP = ','


def write_head(f: TextIO, row_class: Type[QueryRow]):
    """Write the CSV header line for *row_class* to *f*.

    Column names are taken from the dataclass fields, joined with SEP.
    """
    column_names = (field.name for field in fields(row_class))
    f.write(SEP.join(column_names) + '\n')


def write_row(f: TextIO, row: QueryRow):
    """Serialise one result row as a SEP-separated line and write it to *f*."""
    values = asdict(row).values()
    f.write(SEP.join(map(str, values)) + "\n")


def fname(algorithm_name: str) -> str:
    """Return the result-file name for *algorithm_name* (``<name>.csv``)."""
    return algorithm_name + '.csv'


def _insert_charging_stations(G, charging_stations, number=None):
    """Insert charging stations into *G* and log how long the import took.

    :param number: Optional limit on how many stations to insert; passed
        straight through to ``G.insert_charging_stations``.
    """
    t0 = perf_counter()
    G.insert_charging_stations(charging_stations, number)
    elapsed = perf_counter() - t0

    logger.info('Importing {} Charging Stations took {:.2f} s'.format(
        len(G.charging_stations),
        elapsed
    ))


def _init_result_files(result_dir, conf: AnyConf):
    """Create (truncate) the result CSV files for this run and write headers.

    Which files are created depends on the configuration type: an
    ``InitConf`` run produces the init/insert measurement files, while
    ``RankConf``/``QueryConf`` runs produce one file per configured
    algorithm (named via :func:`fname`).
    """
    files = []
    if isinstance(conf, InitConf):
        files = [
            (queries.InitQueryRow, 'init.csv'),
            (queries.InsertQueryRow, 'insert.csv')
        ]
    elif isinstance(conf, (RankConf, QueryConf)):
        # Result-row type per algorithm. Keys must match the names used in
        # conf_algorithms — 'gasstation' was previously misspelled
        # 'gassation', which raised KeyError for gasstation runs.
        algorithm_map = {
            'classic': queries.ClassicQueryRow,
            'astar': queries.AStarQueryRow,
            'bidirectional': queries.QueryRow,
            'gasstation': queries.GasstationQueryRow,
            'charge': queries.ChargeQueryRow,
        }
        files = [(algorithm_map[alg], fname(alg)) for alg in conf.algorithms]

    # Remove existing results and write heads
    for row_class, filename in files:
        with result_dir.joinpath(filename).open('w') as f:
            write_head(f, row_class)


def _run_queries(func, start_nodes, target_nodes, file: TextIO, **kwargs):
    """Run *func* once per (start, target) pair and append result rows to *file*.

    :param func: Query function; called as ``func(s=..., t=..., **kwargs)``
        and expected to return a QueryRow dataclass instance.
    :param kwargs: Extra keyword arguments forwarded to every call
        (typically ``G`` and ``conf``).
    """
    num_total = len(start_nodes)
    logger.info(f'Running {num_total} times {func.__name__}..')
    for i, (s, t) in enumerate(zip(start_nodes, target_nodes)):
        logger.debug(f'{i + 1}/{num_total}')
        result_data = func(s=s, t=t, **kwargs)
        write_row(file, result_data)

    # Drop cached graphs so the next batch starts from a cold cache.
    # (dict.clear() replaces the former manual del-loop.)
    queries.CACHE.clear()

    logger.debug('Queries completed.')


def _get_target_with_rank(graph, start_node, rank):
    """Return a target node whose Dijkstra rank from *start_node* is *rank*.

    Delegates to the modified Dijkstra routine; callers handle
    ``nx.NetworkXNoPath`` when no such node exists.
    """
    return _dijkstra_multisource(
        graph,
        [start_node],
        weight=_weight_function(graph, DISTANCE_KEY),
        rank=rank
    )


def _get_ranked_tasks(G, rank, number):
    """
    Generate <number> start and target nodes with Dijkstra Rank of <rank>.

    Start nodes are sampled at random; for each one a target with the
    required rank is searched via a modified Dijkstra routine. Up to
    three sampling rounds are attempted; if not enough pairs can be
    found, the (possibly shorter) lists collected so far are returned.
    """
    start_nodes = []
    target_nodes = []
    for _attempt in range(3):
        if len(target_nodes) >= number:
            break
        for candidate in random.sample(list(G.nodes), number):
            try:
                target = _get_target_with_rank(G, candidate, rank)
            except nx.NetworkXNoPath:
                # No node with this rank reachable from the candidate.
                continue
            target_nodes.append(target)
            start_nodes.append(candidate)
            if len(target_nodes) == number:
                break

    return start_nodes, target_nodes


def query(G, charging_stations, conf: QueryConf, result_dir):
    """Benchmark the configured algorithms on random start/target pairs.

    For each configured charging-station count, sample random node pairs,
    insert the stations, and append one result row per query to each
    algorithm's CSV file.
    """
    _init_result_files(result_dir, conf)
    for n_cs in conf.charging_stations:
        # Draw 2 * queries_per_row distinct nodes; first half are starts,
        # second half are targets.
        sample = random.sample(list(G.nodes), k=2 * conf.queries_per_row)
        half = int(len(sample) / 2)
        start_nodes, target_nodes = sample[:half], sample[half:]

        # Randomly add n_cs charging stations.
        _insert_charging_stations(G, charging_stations, number=n_cs)

        for alg in conf.algorithms:
            out_path = result_dir.joinpath(fname(alg))
            with out_path.open('a') as out_file:
                _run_queries(conf_algorithms[alg],
                             start_nodes,
                             target_nodes,
                             out_file,
                             G=G,
                             conf=conf)


def rank(G, charging_stations, conf: RankConf, result_dir: Path):
    """Benchmark the configured algorithms on Dijkstra-ranked query sets."""
    _init_result_files(result_dir, conf)
    # Rank experiments run with every available charging station inserted.
    _insert_charging_stations(G, charging_stations)

    for r in conf.ranks:
        logger.debug(f'Getting {conf.queries_per_rank} targets with rank {r}.')
        t0 = perf_counter()
        start_nodes, target_nodes = _get_ranked_tasks(G, r, conf.queries_per_rank)
        elapsed = perf_counter() - t0
        logger.debug(f'Ranked nodes generated in {elapsed:.2f} s')

        for alg in conf.algorithms:
            out_path = result_dir.joinpath(fname(alg))
            with out_path.open('a') as out_file:
                _run_queries(conf_algorithms[alg],
                             start_nodes,
                             target_nodes,
                             file=out_file,
                             G=G,
                             conf=conf)


def init(G, charging_stations, conf: InitConf, result_dir: Path):
    """Measure state-graph initialisation and node-insertion times.

    For each configured charging-station count, records one init row
    and one insert row per sampled node pair.
    """
    _init_result_files(result_dir, conf)

    # Nodes used for the insertion benchmark: first half starts,
    # second half targets.
    sample = random.sample(list(G.nodes), k=2 * conf.queries_per_row)
    half = int(len(sample) / 2)
    start_nodes, target_nodes = sample[:half], sample[half:]

    init_path = result_dir.joinpath('init.csv')
    insert_path = result_dir.joinpath('insert.csv')
    with init_path.open('a') as out_init, insert_path.open('a') as out_insert:
        for n_cs in conf.charging_stations:
            # Randomly add n_cs charging stations.
            _insert_charging_stations(G, charging_stations, n_cs)
            write_row(out_init, queries.init_gasstation_queries(G, conf))

            # Time the insertion of the sampled nodes into the state graph.
            _run_queries(
                queries.insert_nodes_into_state_graph,
                start_nodes,
                target_nodes,
                out_insert,
                G=G,
                conf=conf,
            )