Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
main.py 18.15 KiB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed May 17 15:48:17 2023

@author: schoelleh96
"""

# %% library imports

# This will move the console to the right working directory.
from os.path import  dirname, abspath, exists
from os import chdir, makedirs
chdir(dirname(abspath(__file__)))

from sys import argv, path
import numpy as np
path.append("../wp21/pyscripts/LIB")
import plot as pp
import calc as cc
import data as dd
import matplotlib.pyplot as plt
from datetime import datetime
from matplotlib.widgets import (Slider, CheckButtons, RadioButtons,
                                Button, TextBox)
from warnings import catch_warnings, simplefilter

# %% Get the data

# Subsampling stride: only every n-th trajectory is used.
n = 1

# Passed on to dd.io_bounds / dd.io_eigen; presumably forces recomputation
# instead of reusing stored results — TODO confirm against dd.
force_calc = False

date_0 = datetime(2016, 5, 2, 0)
fname = "../wp21/era5/traj/" + date_0.strftime('%Y/traj_%Y%m%d_%H') + ".npy"

B = dd.get_block_dat(date_0)
# Coastline extent for the plots depends on the case year.
if date_0.year == 2016:
    coasts = dd.get_coasts(10, 90, 0, -360)
elif date_0.year == 2017:
    coasts = dd.get_coasts(10, 90, 120, -60)
else:
    # No extent configured for this year: plot without coastlines (None is
    # also what the plots receive when "Plot features?" is unticked) rather
    # than leaving `coasts` undefined.
    coasts = None

# Per-case, per-stride working directory for stored results.
# NOTE: this rebinds `path`, which was imported from sys above; sys.path is
# not used again in this file after the import section.
path = "./" + date_0.strftime("dat_%Y%m%d_%H") + "/" + str(n) + "/"
makedirs(path, exist_ok=True)

trajs = np.load(fname)

# Time window: first CLI argument. Defaults to 'backward' in interactive
# sessions (argv[0] == '') or when the script is run without arguments.
if argv[0] == '' or len(argv) < 2:
    direc = 'backward'
else:
    direc = argv[1]

_n_t = trajs.shape[1]
_t_mid = _n_t // 2   # index of the middle time step of the trajectories

if direc == "whole":
    t_sel = np.arange(0, _n_t)
    plot_0 = _t_mid
elif direc == "forward":
    t_sel = np.arange(_t_mid, _n_t)
    plot_0 = 0
    # Trim the first 12 steps of the block field and keep the other two
    # entries of B unchanged.
    # NOTE(review): assumes B[0] is the time-indexed (t, ., .) block array and
    # that a 12-step offset aligns it with the forward window — confirm.
    B = [B[0][12:, :, :], B[1], B[2]]
elif direc == "backward":
    t_sel = np.arange(0, _t_mid + 1)
    plot_0 = _t_mid
else:
    raise ValueError("Unknown time direction {!r}; expected 'whole', "
                     "'forward' or 'backward'".format(direc))

# Full-period arrays for the boundary calculation (every n-th trajectory).
Lon, Lat, P = (trajs[field][::n] for field in ("lon", "lat", "p"))  # P: hPa


# The same trajectories restricted to the selected time steps; used for the
# visualization and the coherent-set calculation.
lon, lat, p = (arr[:, t_sel] for arr in (Lon, Lat, P))  # p: hPa
# Units: U, V in m/s; OMEGA in Pa/s; T in K.
u, v, omg, T = (trajs[field][::n, t_sel]
                for field in ("U", "V", "OMEGA", "T"))

# %% Initials

# Starting view of the 3D scatter plots.
azim, elev, dist = 150, 45, 7
t_i = 0

# Default alpha (hull) and epsilon (passed to dd.io_eigen), plus the options
# offered by the corresponding radio buttons.
alpha, e = 1e-3, 20000
epsilon_opt = np.array([20000, 50000, 100000, 200000, 500000])
alph_opt = np.array([1e-4, 2e-4, 5e-4, 0.001, 0.002, 0.005, 0.01, 0.05,
                     0.1, 0.5, 1, 5, 10])
# Default number of clusters handed to cc.kcluster_idx.
N_k = 6

# Fixed vertical scaling for pressure coordinates. Alternative derived from
# the winds (U, V in m/s, OMEGA in Pa/s; result in km/hPa):
#   cc.calc_k(trajs['U']/1000, trajs['V']/1000, trajs['OMEGA']/100)
k_p = 15

# Visualization coordinates start out as plain lon, lat, p.
X, Y, Z = lon, lat, p
# Height from pressure via cc.ptoh (standard atmosphere).
H = cc.ptoh(p, 1013.25, 8.435)
# Vertical scaling for height coordinates, computed from the full
# (non-subsampled) fields; OMEGA is divided by 100 before cc.omg2w.
k_h = cc.calc_k(trajs['U'], trajs['V'],
                cc.omg2w(trajs['OMEGA'] / 100, trajs['T'], trajs['p']))

# %% Boundary

# Interactive explorer for the boundary (hull) points.
# NOTE(review): argv[1] is also consumed as the time direction in "Get the
# data", where only "whole"/"forward"/"backward" give a usable t_sel, so the
# "Boundary" alternative looks unreachable from the command line; in practice
# this section runs in interactive sessions (argv[0] == '').
if argv[0] == '' or (len(argv) > 1 and argv[1] == "Boundary"):

    # Coordinate systems offered both for the calculation and for the
    # visualization radio groups.
    coord_opts = ["Lon, Lat, p", "Lon, Lat, h", "x, y, p (stereo)",
                  "x, y, h (stereo)", "x, y, z (3d)"]

    # Initialize

    fig_bound = plt.figure(figsize=(14, 8))
    ax3d = fig_bound.add_axes([0.2, 0.025, 0.8, 0.9], projection="3d")

    # Default calculation coordinates: stereographic x, y with pressure
    # scaled by k_p (the "x, y, p (stereo)" option). These are full-period.
    x, y, z = cc.coord_trans(Lon, Lat, P, "None", k_p, proj="stereo")

    # Boundaries are always computed for the whole time period, so a time
    # step t of the selected window corresponds to column t + t_sel[0].
    bounds, hulls = dd.io_bounds(path + "stereop", "$\\alpha$", x,
                                 y, z, alpha, force_calc)

    ax3d.view_init(azim=azim, elev=elev)
    with catch_warnings():
        # Assigning Axes3D.dist warns on newer matplotlib versions.
        simplefilter("ignore")
        ax3d.dist = dist
    # x, y, z and bounds are both full-period, so they share the same
    # absolute time index.
    t_abs0 = t_i + t_sel[0]
    ax3d.scatter(x[:, t_abs0], y[:, t_abs0], z[:, t_abs0],
                 c=bounds[:, t_abs0])
    ax3d.invert_zaxis()

    ax3d.set_xlabel("X")
    ax3d.set_ylabel("Y")
    ax3d.set_zlabel("p")

    # Create Interactives
    t_slider_ax = fig_bound.add_axes([0.05, 0.95, 0.15, 0.025])
    t_slider = Slider(t_slider_ax, 'T', 0, lon.shape[1]-1, valinit=t_i,
                      valstep=1)
    dist_slider_ax = fig_bound.add_axes([0.05, 0.92, 0.15, 0.025])
    dist_slider = Slider(dist_slider_ax, 'Zoom', 0, 10, valinit=dist)

    meth_ax = fig_bound.add_axes([0.02, 0.81, 0.15, 0.07])
    meth_check = RadioButtons(meth_ax, ["Convex", "$\\alpha$",
                                        "opt. $\\alpha$"], 1)
    meth_ax.set_title("Hull Method")

    alph_ax = fig_bound.add_axes([0.02, 0.64, 0.15, 0.13])
    alph_check = RadioButtons(alph_ax, alph_opt, 3)

    alph_ax.set_title("$\\alpha$")

    coord_ax = fig_bound.add_axes([0.02, 0.5, 0.15, 0.1])
    coord_check = RadioButtons(coord_ax, coord_opts, 2)
    coord_ax.set_title("Coordinates")

    norm_ax = fig_bound.add_axes([0.02, 0.45, 0.15, 0.05])
    norm_check = CheckButtons(norm_ax, ["normalize?"], [False])

    run_ax = fig_bound.add_axes([0.02, 0.4, 0.15, 0.05])
    run_check = Button(run_ax, "Calculate")

    hull_ax = fig_bound.add_axes([0.02, 0.25, 0.15, 0.05])
    hull_check = CheckButtons(hull_ax, ["Plot hull?"], [False])

    vis_ax = fig_bound.add_axes([0.02, 0.1, 0.15, 0.1])
    vis_check = RadioButtons(vis_ax, coord_opts, 2)
    vis_ax.set_title("Visualization")

    plot_ax = fig_bound.add_axes([0.02, 0.02, 0.15, 0.05])
    plot_check = Button(plot_ax, "Plot")

    # Make sliders interactive (no need to recalculate)
    def slider_changed(val):
        """Redraw the scatter (and optionally the hull) at the slider time.

        Uses the current visualization coordinates X, Y, Z (selected times)
        and the current bounds/hulls (full period); nothing is recalculated.
        """
        t = int(t_slider.val)
        t_abs = t + t_sel[0]   # bounds/hulls are indexed over the full period

        azim_i, elev_i = ax3d.azim, ax3d.elev
        ax3d.clear()

        ax3d.view_init(azim=azim_i, elev=elev_i)
        with catch_warnings():
            simplefilter("ignore")
            ax3d.dist = dist_slider.val

        ax3d.scatter(X[:, t], Y[:, t], Z[:, t], c=bounds[:, t_abs])
        ax3d.invert_zaxis()

        # Axis labels come from the visualization choice, e.g.
        # "x, y, p (stereo)" -> "x", "y", "p".
        labels = vis_check.value_selected.split(", ")
        ax3d.set_xlabel(labels[0])
        ax3d.set_ylabel(labels[1])
        ax3d.set_zlabel(labels[2].split(" ")[0])

        if hull_check.get_status()[0]:
            # NOTE(review): hulls were built from the *calculation*
            # coordinates; presumably they only line up with the scatter when
            # the calculation and visualization choices agree.
            hull = hulls[t_abs]
            ax3d.plot_trisurf(hull.vertices[:, 0], hull.vertices[:, 1],
                              triangles=hull.faces, Z=hull.vertices[:, 2],
                              alpha=0.5)
        fig_bound.canvas.draw_idle()

    dist_slider.on_changed(slider_changed)
    t_slider.on_changed(slider_changed)

    # Make calculate button interactive
    def calc(val):
        """Recalculate bounds/hulls for the selected coordinates and method.

        Always uses the full-period arrays Lon, Lat, P. The path prefix given
        to dd.io_bounds encodes the coordinate choice and normalization.
        """
        global bounds, hulls

        # Coordinates for boundary calculation
        coords = coord_check.value_selected
        if coords == "Lon, Lat, p":
            x, y, z = Lon, Lat, P
            d_path = path + "lonlatp"
        elif coords == "Lon, Lat, h":
            x, y = Lon, Lat
            z = cc.ptoh(P, 1013.25, 8.435)
            d_path = path + "lonlath"
        elif coords == "x, y, p (stereo)":
            x, y, z = cc.coord_trans(Lon, Lat, P, "None", k_p, proj="stereo")
            d_path = path + "stereop"
        elif coords == "x, y, h (stereo)":
            x, y, z = cc.coord_trans(Lon, Lat, P, "std_atm", k_h, proj="stereo")
            d_path = path + "stereoh"
        elif coords == "x, y, z (3d)":
            x, y, z = cc.coord_trans(Lon, Lat, P, "std_atm", -1)
            d_path = path + "xyz"

        # Normalization
        if norm_check.get_status()[0]:
            x, y, z = cc.norm(x), cc.norm(y), cc.norm(z)
            d_path = d_path + "norm"

        # Boundary calculation
        bounds, hulls = dd.io_bounds(d_path, meth_check.value_selected, x, y,
                                     z, float(alph_check.value_selected),
                                     force_calc)

        print("Calculation Done")

        # Warn about time steps at which every point is flagged as boundary.
        all_bound = (bounds == 1).all(axis=0)
        if all_bound.any():
            print("All points boundary at time {}; ".format(
                      np.where(all_bound)[0]) +
                  "please normalize or adjust $\\alpha$")

    run_check.on_clicked(calc)

    # Make visualization button interactive
    def plot(val):
        """Rebuild X, Y, Z from the selected-time arrays and redraw."""
        global X, Y, Z

        vis = vis_check.value_selected
        if vis == "Lon, Lat, p":
            X, Y, Z = lon, lat, p
        elif vis == "Lon, Lat, h":
            X, Y = lon, lat
            Z = cc.ptoh(p, 1013.25, 8.435)
        elif vis == "x, y, p (stereo)":
            X, Y, Z = cc.coord_trans(lon, lat, p, "None", k_p, proj="stereo")
        elif vis == "x, y, h (stereo)":
            X, Y, Z = cc.coord_trans(lon, lat, p, "std_atm", k_h, proj="stereo")
        elif vis == "x, y, z (3d)":
            X, Y, Z = cc.coord_trans(lon, lat, p, "std_atm", -1)

        # Normalization
        if norm_check.get_status()[0]:
            X, Y, Z = cc.norm(X), cc.norm(Y), cc.norm(Z)

        slider_changed(None)
        fig_bound.canvas.draw_idle()
        print("Plotting Done")

    plot_check.on_clicked(plot)

# %% Coherent Sets

# This section always runs (it was once gated on argv[1] == "Sets").
if True:
    # Initialize
    fig_cs = plt.figure(figsize=(14, 8))
    ax3dcs = fig_cs.add_axes([0.2, 0.025, 0.8, 0.9], projection="3d")
    ax_spect = fig_cs.add_axes([0.03, 0.025, 0.3, 0.3])

    # Lower-case coordinates (full period, pressure scaled by k_p) for the
    # boundary calculation.
    xcs, ycs, zcs = cc.coord_trans(Lon, Lat, P, "None", proj="stereo",
                                   scaling=k_p)

    boundscs, hullscs = dd.io_bounds(
                path + "stereop", "$\\alpha$", xcs, ycs, zcs,
                alpha, False)

    # Hull method boundscs currently corresponds to; calc_cs only recomputes
    # the hull when the selection differs from this.
    bound_meth = "$\\alpha$"

    # D starts as three empty lists, is handed to pp.plot_spectra and then
    # read as D[0], D[1], D[2]; presumably plot_spectra fills it in place —
    # TODO confirm.
    D = [[], [], []]

    pp.plot_spectra(D, 3*np.sqrt(epsilon_opt), epsilon_opt, lon, lat, p, k_p,
                    ax_spect, boundscs, True, path + "spnalp" + str(alpha),
                    t_sel, "p")

    E = dd.io_eigen((path + "spnalp" + str(alpha) + str(t_sel[0]) +
                     str(t_sel[-1])),
                    lon.shape[0], t_sel[0], t_sel[-1], D[0], D[1], D[2], e,
                    boundscs, 20, lon, lat, p, True, k_p, "p",
                    force_calc=force_calc)

    # NOTE(review): the initial clustering uses N_k from the Initials, while
    # the N_k / N_v text boxes below start at "3" / "2" — confirm which is
    # intended.
    kclust = cc.kcluster_idx(E, N_k)

    # Upper-case coordinates (selected times, vertical scaling 1) for
    # visualization.
    Xcs, Ycs, Zcs = cc.coord_trans(lon, lat, p, "None", 1, proj="stereo")
    points = pp.plot_clustering(ax3dcs, Xcs, Ycs, Zcs, azim, elev, dist, t_i,
                                c=kclust.labels_, plot_0=False,
                                coord="stereo", mean_traj=True, block=None,
                                coast=None)
    ax3dcs.invert_zaxis()

    # Add widget items
    t_slider_ax2 = fig_cs.add_axes([0.06, 0.95, 0.2, 0.025])
    t_slider2 = Slider(t_slider_ax2, 'T', 0, lon.shape[1]-1, valinit=t_i,
                       valstep=1)
    dist_slider_ax2 = fig_cs.add_axes([0.06, 0.92, 0.2, 0.025])
    dist_slider2 = Slider(dist_slider_ax2, 'Zoom', 0, 10, valinit=dist)
    sc_ax = fig_cs.add_axes([0.02, 0.8, 0.1, 0.06])
    sc_check = RadioButtons(sc_ax, ["$k_p$", "$k_h$"], 0)
    sc_ax.set_title("Vertical Scaling")
    bound_ax = fig_cs.add_axes([0.02, 0.7, 0.1, 0.06])
    bound_check = RadioButtons(bound_ax, ["Concave", "Convex", "$\\alpha$"], 2)
    bound_ax.set_title("Hull Method")
    eps_ax = fig_cs.add_axes([0.02, 0.56, 0.1, 0.1])
    eps_check = RadioButtons(eps_ax, epsilon_opt, 2)
    # Backslash escaped explicitly: "\e" is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+). The title text is unchanged.
    eps_ax.set_title("$\\epsilon$")
    N_k_ax = fig_cs.add_axes([0.02, 0.5, 0.1, 0.05])
    N_k_box = TextBox(N_k_ax, "$N_k$")
    N_k_box.set_val("3")
    N_v_ax = fig_cs.add_axes([0.06, 0.5, 0.1, 0.05])
    N_v_box = TextBox(N_v_ax, "$N_v$")
    N_v_box.set_val("2")
    null_ax = fig_cs.add_axes([0.02, 0.43, 0.1, 0.05])
    null_check = CheckButtons(null_ax, ["Plot start?"], [False])
    geo_ax = fig_cs.add_axes([0.1, 0.43, 0.1, 0.05])
    geo_check = CheckButtons(geo_ax, ["Plot features?"], [False])
    run_ax = fig_cs.add_axes([0.02, 0.35, 0.1, 0.05])
    run_check2 = Button(run_ax, "Calculate")

    def _alpha_cs():
        """Return alpha for the coherent-set hulls.

        Uses the Boundary panel's alpha selection when that panel was created,
        otherwise falls back to the default alpha, so this section also works
        when the Boundary section did not run.
        """
        sel = globals().get("alph_check")
        return alpha if sel is None else float(sel.value_selected)

    # Make sliders functional
    def slider_changed_cs(val):
        """Redraw the clustering at the current time/zoom and resync colors."""
        global points, vals, inv
        azim_i = ax3dcs.azim
        elev_i = ax3dcs.elev
        ax3dcs.clear()

        if geo_check.get_status()[0]:
            block, coast = B, coasts
        else:
            block, coast = None, None

        # plot_0 is passed only when "Plot start?" is ticked.
        start = plot_0 if null_check.get_status()[0] else False
        points = pp.plot_clustering(ax3dcs, Xcs, Ycs, Zcs, azim_i, elev_i,
                                    dist_slider2.val, t_slider2.val,
                                    c=kclust.labels_, plot_0=start,
                                    coord="stereo", mean_traj=True,
                                    block=block, coast=coast)
        ax3dcs.invert_zaxis()

        # Keep the spaghetti colors in sync with the cluster colors.
        for lab in np.unique(kclust.labels_):
            colors[lab] = points.to_rgba(lab)

        keys, inv = np.unique(kclust.labels_, return_inverse=True)
        vals = np.array([colors[key] for key in keys])

        var_changed(None)

        fig_cs.canvas.draw_idle()

    dist_slider2.on_changed(slider_changed_cs)
    t_slider2.on_changed(slider_changed_cs)
    null_check.on_clicked(slider_changed_cs)
    geo_check.on_clicked(slider_changed_cs)

    # Make N_k box functional
    def N_k_changed(val):
        """Re-cluster with N_k = val and N_v = N_k - 1, then redraw."""
        global kclust
        try:
            n_k = int(val)
        except ValueError:
            print("N_k must be an integer, got {!r}".format(val))
            return
        # Update the N_v box without firing its submit callback; otherwise
        # N_v_changed would re-cluster (and redraw) a second time with the
        # same values.
        N_v_box.eventson = False
        try:
            N_v_box.set_val(str(n_k - 1))
        finally:
            N_v_box.eventson = True
        kclust = cc.kcluster_idx(E, n_k, n_k - 1)
        slider_changed_cs(None)

    N_k_box.on_submit(N_k_changed)

    # Make N_v box functional
    def N_v_changed(val):
        """Re-cluster with the current N_k and N_v = val, then redraw."""
        global kclust
        try:
            n_k, n_v = int(N_k_box.text), int(val)
        except ValueError:
            print("N_k and N_v must be integers, got {!r} and {!r}".format(
                N_k_box.text, val))
            return
        kclust = cc.kcluster_idx(E, n_k, n_v)
        slider_changed_cs(None)

    N_v_box.on_submit(N_v_changed)

    # Make Calculation Button functional
    def calc_cs(val):
        """Recompute hull (if the method changed), spectra and eigenvectors."""
        global boundscs, hullscs, E, bound_meth

        meth = bound_check.value_selected
        if meth != bound_meth:
            if meth == "Concave":
                boundscs, hullscs = dd.io_bounds(path + "stereop",
                                                 "opt. $\\alpha$",
                                                 xcs, ycs, zcs, None)
            elif meth == "Convex":
                boundscs, hullscs = dd.io_bounds(path + "stereop", "Convex",
                                                 xcs, ycs, zcs, None)
            elif meth == "$\\alpha$":
                boundscs, hullscs = dd.io_bounds(path + "stereop",
                                                 "$\\alpha$",
                                                 xcs, ycs, zcs, _alpha_cs())
            bound_meth = meth

        # RadioButtons.value_selected may be the raw label (np.int64) or its
        # text; normalize to int to match the initial call, which passes e.
        eps = int(eps_check.value_selected)

        if sc_check.value_selected == "$k_p$":
            prefix, vert, k, kind = "spn", p, k_p, "p"
        elif sc_check.value_selected == "$k_h$":
            prefix, vert, k, kind = "shn", H, k_h, "H"
        else:
            return

        if bound_meth == "Concave":
            d_path = path + prefix + "opt"
        elif bound_meth == "Convex":
            d_path = path + prefix + "cvx"
        elif bound_meth == "$\\alpha$":
            # NOTE(review): only the k_p path encodes the alpha value; the k_h
            # path ("shnalp") does not, so results stored under it may be
            # reused across alpha values — confirm.
            alpha_tag = str(_alpha_cs()) if kind == "p" else ""
            d_path = path + prefix + "alp" + alpha_tag

        pp.plot_spectra(D, 3*np.sqrt(epsilon_opt), epsilon_opt, lon, lat,
                        vert, k, ax_spect, boundscs, True, d_path, t_sel,
                        kind)
        E = dd.io_eigen(d_path + str(t_sel[0]) + str(t_sel[-1]),
                        lon.shape[0], t_sel[0], t_sel[-1],
                        D[0], D[1], D[2], eps,
                        boundscs, 20, lon, lat, vert, True, k, kind)

        N_k_changed(N_k_box.text)

    run_check2.on_clicked(calc_cs)

    # %% spaghetti

    alpha_spag = 0.1 + (n/100)

    fig_spag = plt.figure()
    ax_spag = fig_spag.add_subplot(111)
    fig_spag.subplots_adjust(left=0.2)
    var_ax = fig_spag.add_axes([0.025, 0.1, 0.1, 0.8])
    var_check = RadioButtons(var_ax, trajs.dtype.names, 6)
    var_ax.set_title("Variable")
    dat = trajs[var_check.value_selected][::n, t_sel]

    # One color per cluster label, taken from the clustering plot.
    colors = {lab: points.to_rgba(lab) for lab in np.unique(kclust.labels_)}

    keys, inv = np.unique(kclust.labels_, return_inverse=True)
    vals = np.array([colors[key] for key in keys])
    # To drop a Null-Level-Set cluster (e.g. label 5), filter dat and inv
    # with `inv != 5` before plotting.
    pp.plot_spaghetti(ax_spag, dat, alpha_spag, colors=colors, inv=inv)

    # Times are shown relative to step t_ref. The slider indexes the selected
    # window, so the absolute step is t_sel[0] + slider value.
    # NOTE(review): t_ref = 72 is presumably the trajectories' start step.
    t_ref = 72
    # Named v_line so the V wind array `v` is not overwritten.
    v_line = pp.draw_vert_t(ax_spag, t_sel[0] - t_ref, t_sel[-1] - t_ref,
                            t_sel[0] + t_slider2.val - t_ref)

    def var_changed(val):
        """Redraw the spaghetti plot for the selected variable."""
        dat_sel = trajs[var_check.value_selected][::n, t_sel]
        ax_spag.clear()
        pp.plot_spaghetti(ax_spag, dat_sel, alpha_spag, colors=colors, inv=inv)
        pp.draw_vert_t(ax_spag, t_sel[0] - t_ref, t_sel[-1] - t_ref,
                       t_sel[0] + t_slider2.val - t_ref)
        fig_spag.canvas.draw_idle()

    var_check.on_clicked(var_changed)

    # Enter the GUI event loop only now. With a blocking backend, callbacks
    # (e.g. slider_changed_cs -> var_changed) fire inside plt.show(), so
    # everything they use must already be defined and connected.
    plt.show()