import numpy as np
from matplotlib import pyplot as plt

# Render all figure text with LaTeX in a Helvetica sans-serif font
# (matplotlib's usetex mode needs a LaTeX installation on the machine).
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.sans-serif"] = "Helvetica"

from tqdm import tqdm

# a - Implementation based on Leimkuhler & Matthews
# Reduced units: gamma^-1 is the time unit, k_BT the energy unit, m the mass.
gamma = 1               # friction coefficient
omega_0 = 1/3 * gamma   # angular frequency of the harmonic trap (see force())
m = 1                   # particle mass
d = 3                   # number of spatial dimensions
k_BT = 1                # thermal energy
l = 1  # length unit gamma**-1 * np.sqrt(k_BT / m), which equals 1 for these constants
N = 100                 # number of independent particles
t_max = 1000            # end of the time grid (exclusive)
dt = 0.1                # integration time step
# Build the grid from an integer number of steps: np.arange with a float step
# can gain or lose a sample through rounding, silently changing len(t).
n_steps = int(round(t_max / dt))
t = np.arange(n_steps) * dt

def force(r: np.ndarray) -> np.ndarray:
    """Return the restoring force of the isotropic harmonic trap, -m * omega_0**2 * r."""
    stiffness = m * omega_0**2
    return -stiffness * r

def BAOAB(force, x0, p0, m, dt, N, *, n_steps=None, friction=None,
          thermal_energy=None, rng=None):
    """Integrate underdamped Langevin dynamics with the BAOAB splitting.

    Each step applies, in order:

    - B: a half kick with the force.
    - A: a half drift of the positions.
    - O: the exact Ornstein-Uhlenbeck update of the momenta (friction plus
      thermal noise).
    - A: a second half drift.
    - B: a second half kick.

    This follows Leimkuhler & Matthews.

    Parameters
    ----------
    force : callable
        Maps a position array shaped like ``x0`` to the force on it. A scalar
        return value (e.g. ``0`` for a free particle) is broadcast.
    x0, p0 : np.ndarray
        Initial positions and momenta; must have identical shapes.
    m : float
        Particle mass.
    dt : float
        Integration time step.
    N : int
        Number of particles. Kept for backward compatibility only; the noise
        is drawn with ``x0.shape`` so it always matches the state.
    n_steps : int, optional
        Number of stored states including the initial one. Defaults to
        ``len(t)`` of the module-level time grid.
    friction : float, optional
        Friction coefficient; defaults to the module-level ``gamma``.
    thermal_energy : float, optional
        k_B T; defaults to the module-level ``k_BT``.
    rng : np.random.Generator, optional
        Noise source. Defaults to the legacy global ``np.random`` state, so
        ``np.random.seed`` still controls the output as before.

    Returns
    -------
    x, p : np.ndarray
        Positions and momenta of shape ``(n_steps, *x0.shape)``; index 0
        holds the initial state.

    Raises
    ------
    ValueError
        If ``x0`` and ``p0`` differ in shape or ``n_steps`` is below 1.
    """
    if x0.shape != p0.shape:
        raise ValueError(
            f"x0 and p0 must have the same shape, got {x0.shape} and {p0.shape}"
        )
    if n_steps is None:
        n_steps = len(t)
    if n_steps < 1:
        raise ValueError(f"n_steps must be at least 1, got {n_steps}")
    # Fall back to the module-level constants only when not given explicitly.
    friction = gamma if friction is None else friction
    thermal_energy = k_BT if thermal_energy is None else thermal_energy
    normal = np.random.normal if rng is None else rng.normal

    x = np.zeros((n_steps, *x0.shape))
    p = np.zeros((n_steps, *p0.shape))
    x[0] = x0
    p[0] = p0

    half_dt = 0.5 * dt
    # O-step coefficients (loop invariant): exact solution of
    # dp = -friction * p dt + sqrt(2 friction m k_BT) dW over one step dt.
    decay = np.exp(-friction * dt)
    noise_scale = np.sqrt(thermal_energy * (1 - np.exp(-2 * friction * dt))) * np.sqrt(m)

    for i in range(1, n_steps):
        p_i = p[i - 1] + half_dt * force(x[i - 1])                     # B
        x_i = x[i - 1] + half_dt * p_i / m                             # A
        p_i = decay * p_i + noise_scale * normal(0.0, 1.0, x0.shape)   # O
        x[i] = x_i + half_dt * p_i / m                                 # A
        p[i] = p_i + half_dt * force(x[i])                             # B

    return x, p

# Every particle starts at rest at the origin.
r0 = np.zeros((N, d))
p0 = np.zeros_like(r0)

# Trajectory in the harmonic trap; r and p are indexed (step, particle, coordinate).
r, p = BAOAB(force, r0, p0, m, dt, N)

# b) Kinetic and potential energy per particle at every stored step.
#    For the harmonic trap both are expected to relax to (d/2) k_BT
#    (equipartition) once the initial transient has decayed.
E_kin = np.sum(p**2, axis=(1, 2)) / (2 * m * N)
# Sum the squared coordinates directly (equal to the squared Frobenius norm),
# so both energies are computed the same way and no sqrt round trip occurs.
E_pot = m * omega_0**2 * np.sum(r**2, axis=(1, 2)) / (2 * N)

# Energies vs. time. The dashed vertical line marks step 50; samples from
# step index 51 onward are treated as equilibrated below.
# plt.figure()
# plt.plot(t, E_kin, label=r'$E_{kin}$')
# plt.plot(t, E_pot, label=r'$E_{pot}$')
# plt.plot(t, np.full(len(t), np.average(E_kin)), label=r'$\langle E_{kin} \rangle$', linestyle='--')
# plt.plot(t, np.full(len(t), np.average(E_pot)), label=r'$\langle E_{pot} \rangle$', linestyle='--')
# plt.vlines(t[50], min(np.min(E_kin), np.min(E_pot)), max(np.max(E_kin), np.max(E_pot)), linestyles='dashed', colors='purple')
# plt.xlabel(r'$t / \gamma^{-1}$')
# plt.ylabel(r'$E / k_BT$')
# plt.grid(True)
# plt.legend(loc='lower right')
# plt.tight_layout()
# plt.savefig('10.1_b.png')

# Positions and momenta after the transient (step index 51 onward).
# r_x = r[51:,:,0].flatten()
# r_y = r[51:,:,1].flatten()
# r_z = r[51:,:,2].flatten()
# p_x = p[51:,:,0].flatten()
# p_y = p[51:,:,1].flatten()
# p_z = p[51:,:,2].flatten()


# 3D scatter of the equilibrated positions.
# fig = plt.figure()
# ax = fig.add_subplot(projection='3d')
# ax.scatter(r_x, r_y, r_z, s=10, facecolors='none', edgecolors='r')
# plt.savefig('norm_dist_x.png')

# Per-coordinate histograms, plotted against bin centres.
# coords = ['x','y','z']
# for i in range(3):
#     r_i = r[51:,:,i].flatten()
#     hist_r, r_edges = np.histogram(r_i, bins=100, density=True)
#     plt.figure()
#     plt.plot(0.5 * (r_edges[1:] + r_edges[:-1]), hist_r, label=f'r_{coords[i]}')
#     plt.legend()
#     plt.tight_layout()
#     plt.savefig(f'r_{coords[i]}_dist.png')

#     p_i = p[51:,:,i].flatten()  # momenta (previously read r by mistake)
#     hist_p, p_edges = np.histogram(p_i, bins=100, density=True)
#     plt.figure()
#     plt.plot(0.5 * (p_edges[1:] + p_edges[:-1]), hist_p, label=f'p_{coords[i]}')
#     plt.legend()
#     plt.tight_layout()
#     plt.savefig(f'p_{coords[i]}_dist.png')

# c) Mean-squared displacement in the harmonic trap vs. a free particle
#    (force identically zero), both started from the same initial state.
r_f, p_f = BAOAB(lambda r: 0, r0, p0, m, dt, N)


def _msd(traj: np.ndarray) -> np.ndarray:
    """Return the mean-squared displacement of ``traj`` for every lag.

    ``traj`` is indexed as (step, particle, coordinate), as returned by
    ``BAOAB``. Entry ``k`` of the result averages ``|traj[i + k] - traj[i]|**2``
    over all start steps ``i`` and all particles. Entry 0 is 0.
    """
    msd = np.zeros(len(traj))
    for lag in tqdm(range(1, len(traj))):
        disp = traj[lag:] - traj[:-lag]
        # Sum of squares directly: same as norm(...)**2 without the sqrt.
        msd[lag] = np.mean(np.sum(disp**2, axis=2))
    return msd


dr2 = _msd(r)
dr2_f = _msd(r_f)

# Lag 0 (t = 0) has zero displacement and cannot be drawn on log axes.
lag_t = t[1:]

plt.figure()  # start a fresh figure instead of drawing into a leftover one
plt.loglog(lag_t, dr2[1:], label="harmonic")
plt.loglog(lag_t, dr2_f[1:], label="free")
plt.loglog(lag_t, 3 * k_BT / m * lag_t**2, ls="--", label=r"$3 k_B T t^2 / m$")
plt.loglog(
    lag_t,
    6 * k_BT * lag_t / (m * gamma) - 6 * k_BT / (m * gamma**2) * (1 - np.exp(-gamma * lag_t)),
    ls="-.",
    label="free, analytic",
)
plt.legend()
# The axes are already log-scaled, so label the plotted quantities themselves.
plt.xlabel(r"$t / \gamma^{-1}$")
plt.ylabel(r"$\delta r^2 / l^2$")
plt.savefig("problem10.1c.png", dpi=300)
plt.show()