Newer
Older
import numpy as np
import matplotlib.pyplot as plt
def plot_trajectory(traj: np.ndarray, mass: list, update_interval=10):
"""
Args:
- traj: trajectory with shape (no_time_steps, no_bodies * dim)
- mass: masses of bodies
"""
n_time_step, n_bodies = traj.shape #
n_bodies //= 3
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
for i in range(0, n_time_step, update_interval):
ax.clear()
for j in range(n_bodies):
start_idx = j * 3
end_idx = (j + 1) * 3
body_traj = traj[i, start_idx:end_idx]
ax.scatter(body_traj[0], body_traj[1], body_traj[2], s=10 * mass[j] / min(mass))
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
ax.set_title(f"Time step: {i}")
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
plt.pause(0.01)