import numpy as np
import matplotlib.pyplot as plt


def plot_trajectory(traj: np.ndarray, mass: list, update_interval=10):
    """Animate a 3-D N-body trajectory by redrawing a scatter plot per frame.

    Args:
        traj: trajectory with shape (no_time_steps, no_bodies * 3). Row ``i``
            holds the concatenated (x, y, z) coordinates of every body at
            time step ``i``: body ``j`` occupies columns ``3*j .. 3*j + 2``.
        mass: masses of bodies, one per body, in the same order as the body
            columns of ``traj``. Marker areas are drawn as
            ``10 * mass[j] / min(mass)``.
        update_interval: draw only every ``update_interval``-th time step.

    Raises:
        ValueError: if ``traj`` is not 2-D, its column count is not a
            multiple of 3, ``len(mass)`` differs from the number of bodies,
            any mass is non-positive, or ``update_interval`` is not positive.
    """
    dim = 3  # the plot is 3-D, so each body occupies three columns

    traj = np.asarray(traj)
    if traj.ndim != 2:
        raise ValueError(f"traj must be 2-D, got shape {traj.shape}")
    n_time_step, n_columns = traj.shape
    if n_columns % dim:
        raise ValueError(
            f"traj has {n_columns} columns, which is not a multiple of {dim}"
        )
    # The column count is no_bodies * dim; the original code used the raw
    # column count as the body count and indexed past the end of each row.
    n_bodies = n_columns // dim
    if len(mass) != n_bodies:
        raise ValueError(
            f"got {len(mass)} masses for {n_bodies} bodies"
        )
    if update_interval < 1:
        raise ValueError(
            f"update_interval must be positive, got {update_interval}"
        )

    masses = np.asarray(mass, dtype=float)
    if masses.size and masses.min() <= 0:
        raise ValueError("all masses must be positive")
    # Marker sizes do not change between frames, so compute them once.
    sizes = 10 * masses / masses.min() if masses.size else masses

    # (time, body, xyz) view of the trajectory.
    positions = traj.reshape(n_time_step, n_bodies, dim)

    # ax.clear() resets the axis limits every frame. Pinning them to the
    # bounds of the whole trajectory keeps the view stable so motion shows.
    bounds = None
    if positions.size:
        bounds = (positions.min(axis=(0, 1)), positions.max(axis=(0, 1)))

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')

    for i in range(0, n_time_step, update_interval):
        ax.clear()
        # One scatter call per body so each body keeps its own colour from
        # the axes' property cycle (reset by ax.clear()).
        for j, (x, y, z) in enumerate(positions[i]):
            ax.scatter(x, y, z, s=sizes[j])

        if bounds is not None:
            lower, upper = bounds
            ax.set_xlim(lower[0], upper[0])
            ax.set_ylim(lower[1], upper[1])
            ax.set_zlim(lower[2], upper[2])

        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_zlabel("Z")
        ax.set_title(f"Time step: {i}")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])
        plt.pause(0.01)