Skip to content
Snippets Groups Projects
visualization.py 915 B
Newer Older
nguyed99's avatar
nguyed99 committed
import numpy as np
import matplotlib.pyplot as plt


def _marker_sizes(mass, n_bodies: int) -> np.ndarray:
    """Return one scatter marker size per body, scaled so the lightest mass gets size 10."""
    masses = np.asarray(mass, dtype=float)
    if masses.size < n_bodies:
        raise ValueError(f"expected at least {n_bodies} masses, got {masses.size}")
    if n_bodies == 0:
        return masses[:0]
    # Normalise by the smallest mass in the *whole* list (matches the original
    # min(mass) behaviour even when extra masses are supplied).
    smallest = masses.min()
    if not smallest > 0:  # also rejects NaN
        raise ValueError(f"masses must be positive, smallest is {smallest}")
    return 10.0 * masses[:n_bodies] / smallest


def _axis_limits(positions: np.ndarray, margin: float = 0.05):
    """Return (3, 2) array of fixed (low, high) limits per axis over all frames, or None."""
    if positions.size == 0:
        return None
    low = positions.min(axis=(0, 1))
    high = positions.max(axis=(0, 1))
    if not (np.all(np.isfinite(low)) and np.all(np.isfinite(high))):
        # Non-finite data cannot be used as axis limits; fall back to autoscaling.
        return None
    span = high - low
    # A constant coordinate (e.g. planar motion) would give identical limits;
    # use a unit-wide window around it instead.
    pad = np.where(span > 0, margin * span, 0.5)
    return np.stack([low - pad, high + pad], axis=1)


def plot_trajectory(
    traj: np.ndarray,
    mass: list,
    update_interval: int = 10,
    *,
    pause: float = 0.01,
    fixed_limits: bool = True,
) -> None:
    """
    Animate a 3-D N-body trajectory by redrawing a scatter plot every few time steps.

    Args:
    - traj: trajectory with shape (no_time_steps, no_bodies * 3); columns
      3*j .. 3*j+2 hold the x, y, z position of body j
    - mass: masses of bodies (at least no_bodies entries, all positive);
      marker area is 10 * mass / min(mass)
    - update_interval: draw every `update_interval`-th time step (>= 1)
    - pause: seconds passed to plt.pause between frames
    - fixed_limits: keep the same axis limits (covering the whole trajectory)
      on every frame so motion is visible; if False, each frame autoscales

    Raises:
    - ValueError: if traj is not 2-D, its width is not a multiple of 3,
      too few masses are given, the smallest mass is not positive, or
      update_interval < 1
    """
    traj = np.asarray(traj)
    if traj.ndim != 2:
        raise ValueError(f"traj must be 2-D, got shape {traj.shape}")
    n_steps, width = traj.shape
    if width % 3:
        raise ValueError(f"traj width {width} is not a multiple of 3 (x, y, z per body)")
    n_bodies = width // 3
    if update_interval < 1:
        raise ValueError(f"update_interval must be >= 1, got {update_interval}")

    sizes = _marker_sizes(mass, n_bodies)
    positions = traj.reshape(n_steps, n_bodies, 3)
    limits = _axis_limits(positions) if fixed_limits else None

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')

    for i in range(0, n_steps, update_interval):
        # clear() also resets labels, ticks and limits, so they are re-applied below.
        ax.clear()
        for body, (x, y, z) in enumerate(positions[i]):
            ax.scatter(x, y, z, s=sizes[body])

        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_zlabel("Z")
        ax.set_title(f"Time step: {i}")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])
        if limits is not None:
            # Set after scatter so per-frame autoscaling cannot override them.
            ax.set_xlim(*limits[0])
            ax.set_ylim(*limits[1])
            ax.set_zlim(*limits[2])

        plt.pause(pause)