import numpy as np
import matplotlib.pyplot as plt

def _load_without_first(h5file, key):
    """Read dataset *key* from *h5file* as an array and drop its first entry."""
    # np.delete without an axis works on the flattened input, so the result
    # is always one-dimensional.
    return np.delete(np.array(h5file[key]), 0)


def iterations(h5file, FINAL_TIME, interval=None):
    """Plot solver iteration counts and the time increment over time.

    Reads the datasets ``relativeTime``, ``relativeTimeIncrement``,
    ``iterations/fixedPoint/final`` and ``iterations/multiGrid/final`` from
    *h5file*, discards the first entry of each, and draws three stacked
    subplots (fixed-point iterations, multigrid iterations, ``tau`` on a log
    scale) restricted to the requested time window. Summary statistics are
    printed to stdout.

    Parameters
    ----------
    h5file
        Mapping-like object indexed by dataset path (presumably an open
        ``h5py.File`` -- confirm against the caller).
    FINAL_TIME
        Factor that scales the stored relative times and relative time
        increments.
    interval
        Sequence whose first two items are the inclusive lower and upper
        time bounds of the plotted window. ``None`` or an empty sequence
        selects ``[0, FINAL_TIME]``.

    Notes
    -----
    The printed FPI average and maximum are computed over all stored steps
    (after dropping the first), not only those inside *interval*.
    """
    # read time
    time = _load_without_first(h5file, 'relativeTime') * FINAL_TIME

    if interval is None or len(interval) == 0:
        interval = [0, FINAL_TIME]
    lower, upper = interval[0], interval[1]

    # Integer positions of the samples inside the window. Positions (rather
    # than the boolean mask itself) are used to index the other arrays so
    # that datasets of a different length are handled as before.
    selected = np.flatnonzero((time >= lower) & (time <= upper))
    time = time[selected]

    tau = _load_without_first(h5file, 'relativeTimeIncrement') * FINAL_TIME

    # read fpi iterations
    fpi_final = _load_without_first(h5file, 'iterations/fixedPoint/final')

    print("FPI average: " + str(np.average(fpi_final)))
    print("FPI max: " + str(np.max(fpi_final)))

    # read multigrid iterations
    multigrid_final = _load_without_first(h5file, 'iterations/multiGrid/final')

    # plot
    fig = plt.figure()

    ax_fpi = fig.add_subplot(3, 1, 1)
    ax_fpi.plot(time, fpi_final[selected], color='black', linestyle='-')
    ax_fpi.set_ylabel('fpi')

    ax_mg = fig.add_subplot(3, 1, 2)
    ax_mg.plot(time, multigrid_final[selected], color='black', linestyle='-')
    ax_mg.set_ylabel('multigrid iter.')
    ax_mg.set_xlabel('time [s]')

    tau_t = tau[selected]
    ax_tau = fig.add_subplot(3, 1, 3)
    ax_tau.plot(time, tau_t, color='black', linestyle='-')
    ax_tau.set_ylabel('tau')
    ax_tau.set_yscale('log')

    if tau_t.size:
        print("tau_max / tau_min: " + str(np.max(tau_t)/np.min(tau_t)))
    else:
        # np.max / np.min raise ValueError on an empty array; report the
        # empty window instead of failing after the figure has been built.
        print("tau_max / tau_min: n/a (no samples in interval)")

    fig.canvas.draw()