import networkx as nx
from copy import copy
from typing import Set

from evrouting.T import Node, ChargingCoefficient
from evrouting.graph_tools import (
    node_convert, edge_convert, TemplateEdge, TemplateNode, charging_cofficient
)

# Names of the scenario configs defined below; each string matches the name
# of a module-level dict in this file.
# NOTE(review): presumably resolved by name elsewhere (e.g. test parametrization) — confirm.
config_list = ['edge_case', 'edge_case_start_node_no_cs', 'edge_case_a_slow']

# Triangle scenario on node ids 0 ('s'), 1 ('a') and 2 ('t'):
# edges s--a, s--t and a--t, each with distance 1.
# The direct s--t edge consumes 4, which equals the capacity U, while each
# edge of the detour through 'a' consumes only 1.
# 's' (coeff 1) and 'a' (coeff 2) have a charging coefficient.
# 't' is created without one (presumably None, i.e. not a charging station).
edge_case = {
    'beta_s': 0,  # initial SoC (exposed as 'initial_soc' by init_config)
    'beta_t': 0,  # final SoC (exposed as 'final_soc' by init_config)
    'U': 4,  # capacity (exposed as 'capacity' by init_config)
    's': 0,  # start node id = index of TemplateNode('s') in 'nodes'
    't': 2,  # target node id = index of TemplateNode('t') in 'nodes'
    'nodes': [
        TemplateNode('s', charging_coeff=1),
        TemplateNode('a', charging_coeff=2),
        TemplateNode('t'),
    ],
    'edges': [
        # TemplateEdge(u, v, ...) with u/v being indices into 'nodes'.
        TemplateEdge(0, 1, distance=1, consumption=1),
        TemplateEdge(0, 2, distance=1, consumption=4),
        TemplateEdge(1, 2, distance=1, consumption=1),
    ]
}

# Same triangle as `edge_case`, but with the charging coefficients swapped:
# 's' has coeff 2 and 'a' has coeff 1.
# NOTE(review): the name implies 'a' is the slower charger, which holds if
# charging_coeff is a rate (larger = faster) — confirm against evrouting.T.
edge_case_a_slow = {
    'beta_s': 0,  # initial SoC (exposed as 'initial_soc' by init_config)
    'beta_t': 0,  # final SoC (exposed as 'final_soc' by init_config)
    'U': 4,  # capacity (exposed as 'capacity' by init_config)
    's': 0,  # start node id = index of TemplateNode('s') in 'nodes'
    't': 2,  # target node id = index of TemplateNode('t') in 'nodes'
    'nodes': [
        TemplateNode('s', charging_coeff=2),
        TemplateNode('a', charging_coeff=1),
        TemplateNode('t'),
    ],
    'edges': [
        # TemplateEdge(u, v, ...) with u/v being indices into 'nodes'.
        TemplateEdge(0, 1, distance=1, consumption=1),
        TemplateEdge(0, 2, distance=1, consumption=4),
        TemplateEdge(1, 2, distance=1, consumption=1),
    ]
}

# Same triangle as `edge_case`, but the start node 's' has no charging
# coefficient, so it is not counted as a charging station.
# Only 'a' (coeff 2) has one.
# Combined with beta_s = 0, the route starts at 's' with zero initial
# charge and no charging at 's'.
# NOTE(review): confirm the expected solver outcome for this case.
edge_case_start_node_no_cs = {
    'beta_s': 0,  # initial SoC (exposed as 'initial_soc' by init_config)
    'beta_t': 0,  # final SoC (exposed as 'final_soc' by init_config)
    'U': 4,  # capacity (exposed as 'capacity' by init_config)
    's': 0,  # start node id = index of TemplateNode('s') in 'nodes'
    't': 2,  # target node id = index of TemplateNode('t') in 'nodes'
    'nodes': [
        TemplateNode('s'),
        TemplateNode('a', charging_coeff=2),
        TemplateNode('t'),
    ],
    'edges': [
        # TemplateEdge(u, v, ...) with u/v being indices into 'nodes'.
        TemplateEdge(0, 1, distance=1, consumption=1),
        TemplateEdge(0, 2, distance=1, consumption=4),
        TemplateEdge(1, 2, distance=1, consumption=1),
    ]
}


def get_graph(config: dict) -> nx.Graph:
    """Build an undirected networkx graph from a scenario config.

    Each entry of ``config['nodes']`` becomes a node whose id is its
    position in that list, with attributes produced by ``node_convert``.
    Each entry of ``config['edges']`` becomes an edge between its ``u``
    and ``v`` ids, with attributes produced by ``edge_convert``.
    All nodes are added before any edge.
    """
    graph = nx.Graph()

    # (node, attr_dict) pairs: the attribute dict is merged into the node data.
    graph.add_nodes_from(
        (position, node_convert(template))
        for position, template in enumerate(config['nodes'])
    )

    # (u, v, attr_dict) triples: the attribute dict is merged into the edge data.
    graph.add_edges_from(
        (template.u, template.v, edge_convert(template))
        for template in config['edges']
    )

    return graph


def get_charging_stations(config: dict) -> Set[Node]:
    """Return the ids of the config nodes that carry a charging coefficient.

    Ids are positions in ``config['nodes']``, the same numbering that
    ``get_graph`` uses. A node is included when its ``charging_coeff``
    attribute is not ``None``.
    """
    stations = set()
    for position, template in enumerate(config['nodes']):
        if template.charging_coeff is None:
            continue
        stations.add(position)
    return stations


def init_config(config: dict) -> dict:
    """Translate a scenario config into the arguments for a routing run.

    The graph is built by ``get_graph``, so node ids are positions in
    ``config['nodes']``. A node id counts as a charging station when
    ``charging_cofficient`` returns a non-None value for it in that graph.

    :param config: Scenario dict with the keys 'nodes', 'edges', 's', 't',
        'beta_s', 'beta_t' and 'U'. 'nodes' must be a sequence, because it
        is read more than once.
    :return: Dict with the keys 'G', 'charging_stations', 's', 't',
        'initial_soc', 'final_soc' and 'capacity'.
    :raises KeyError: If one of the required config keys is missing.
    """
    # Reuse get_graph rather than duplicating its node/edge construction.
    G = get_graph(config)

    # Only query the ids created from config['nodes']. Ids that appear only
    # as edge endpoints were never checked by the original loop.
    S = {
        node_id
        for node_id in range(len(config['nodes']))
        if charging_cofficient(G, node_id) is not None
    }

    return {
        'G': G,
        'charging_stations': S,
        's': config['s'],
        't': config['t'],
        'initial_soc': config['beta_s'],
        'final_soc': config['beta_t'],
        'capacity': config['U']
    }