Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
system.py 1.11 KiB
"""
This module contains a base class for an N-body gravitational system.
"""

from collections.abc import Callable
from dataclasses import dataclass

import numpy as np


@dataclass
class GravitationalSystem:
    """Initial state of an N-body system together with its integrator.

    ``simulation`` advances the phase-space state ``(p, q)`` over the time
    grid ``t`` by repeatedly calling ``solver``.
    """

    r0: np.ndarray[float]  # initial positions, shape = (N, d)
    v0: np.ndarray[float]  # initial velocities, shape = (N, d)
    m: np.ndarray[float]  # masses, shape = N
    t: np.ndarray[float]  # time grid, shape = n_time_steps
    force: Callable  # passed through unchanged to ``solver``
    solver: Callable  # integrator; its __name__ must be "verlet" or "bht"

    # Unannotated, so ``dataclass`` does not turn it into a constructor field.
    _SOLVER_NAMES = ("verlet", "bht")

    def __post_init__(self):
        """Check that the inputs have mutually consistent shapes.

        Raises:
            AssertionError: if ``r0`` and ``v0`` differ in shape, if ``m`` does
                not have one entry per body, or if ``solver`` is not one of the
                supported integrators. These checks use ``assert`` and are
                therefore skipped when Python runs with ``-O``.
        """
        assert self.r0.shape == self.v0.shape, (
            f"r0 and v0 must have the same shape, got {self.r0.shape} "
            f"and {self.v0.shape}"
        )
        assert self.m.shape[0] == self.r0.shape[0], (
            f"m must have one mass per body ({self.r0.shape[0]}), "
            f"got shape {self.m.shape}"
        )
        # getattr: callables such as functools.partial have no __name__; without
        # it they would fail with an AttributeError instead of this check.
        solver_name = getattr(self.solver, "__name__", None)
        assert solver_name in self._SOLVER_NAMES, (
            f"solver must be one of {self._SOLVER_NAMES}, got {solver_name!r}"
        )

    def simulation(self):
        """Integrate the system over the time grid ``t``.

        Positions start at ``q = r0`` and momenta at ``p = m * v0``, with each
        body's velocity scaled by that body's own mass. For every consecutive
        pair of times, ``solver(force, q, p, m, dt)`` is called with
        ``dt = t[i] - t[i - 1]`` and must return the new ``(p, q)``.

        Returns:
            tuple: ``(t, p, q)``, where ``p`` has shape ``(len(t), *v0.shape)``
            and ``q`` has shape ``(len(t), *r0.shape)``.
        """
        n_steps = len(self.t)
        p = np.zeros((n_steps, *self.v0.shape))
        q = np.zeros((n_steps, *self.r0.shape))
        q[0] = self.r0

        # Give the masses shape (N, 1, ...) so each mass scales its own body's
        # velocity row. A bare (N,) array would broadcast against the LAST axis
        # of v0 instead: that errors when N != d and silently mixes bodies with
        # coordinates when N == d.
        mass_shape = (self.m.shape[0],) + (1,) * (self.v0.ndim - 1)
        p[0] = np.reshape(self.m, mass_shape) * self.v0

        for i in range(1, n_steps):
            dt = self.t[i] - self.t[i - 1]
            p[i], q[i] = self.solver(self.force, q[i - 1], p[i - 1], self.m, dt)

        return self.t, p, q