diff --git a/jobs/src/bht_algorithm.py b/jobs/src/bht_algorithm.py new file mode 100644 index 0000000000000000000000000000000000000000..ea92574ad082ce79d544e9b1ba9f8e79fda061a5 --- /dev/null +++ b/jobs/src/bht_algorithm.py @@ -0,0 +1,321 @@ +""" +This module implements the Barnes Hut Tree Algorithm +""" + +import numpy as np + + +class MainApp: + + def __init__(self): + """Initialize the MainApp with a root node.""" + self.rootNode = TreeNode(x=-1, y=-1, width=2, height=2) + + def BuildTree(self, particles): + """Build the Quadtree by inserting particles. + + Args: + particles (list): A list of Particle objects to be inserted into the Quadtree. + """ + self.ResetTree() # Empty the tree + for particle in particles: + self.rootNode.insert(particle) + + def ResetTree(self): + """Reset the Quadtree by reinitializing the root node.""" + self.rootNode = TreeNode(x=-1, y=-1, width=2, height=2) + + +class Particle: + + def __init__(self, x, y, mass): + """Initialize a Particle with x and y coordinates and mass.""" + self.x = x + self.y = y + self.mass = mass + self.vx = 0.0 # Velocity component in x direction + self.vy = 0.0 # Velocity component in y direction + self.fx = 0.0 # Force component in x direction + self.fy = 0.0 # Force component in y direction + + +class TreeNode: + + def __init__(self, x, y, width, height): + """Initialize a TreeNode representing a quadrant in the Quadtree. + + Args: + x (float): x-coordinate of the node. + y (float): y-coordinate of the node. + width (float): Width of the node. + height (float): Height of the node. + """ + self.x = x + self.y = y + self.width = width + self.height = height + self.particle = None # Particle contained in this node + self.center_of_mass = np.array([(x + width) / 2, (y + width) / 2]) + self.total_mass = 0 + self.children = np.empty(4, dtype=object) # Children nodes (SW, SE, NW, NE) + + def contains(self, particle): + """Check if the particle is within the bounds of this node. + + Args: + particle (Particle): The particle to be checked. + + Returns: + bool: True if the particle is within the node's bounds, False otherwise. + """ + return (self.x <= particle.x < self.x + self.width and self.y <= particle.y < self.y + self.height) + + def insert(self, particle): + """Insert a particle into the Quadtree. + + Args: + particle (Particle): The particle to be inserted. + + Returns: + bool: True if the particle is inserted, False otherwise. + """ + if not self.contains(particle): + return False # Particle doesn't belong in this node + + if self.particle is None and all(child is None for child in self.children): + # If the node is empty and has no children, insert particle here. + self.particle = particle + #print(f'particle inserted: x={round(self.particle.x,2)}, y={round(self.particle.y,2)}') + return True # Particle inserted in an empty node + + if all(child is None for child in self.children): + # If no children exist, create and insert both particles + self.subdivide() + self.insert(self.particle) # Reinsert existing particle + self.insert(particle) # Insert new particle + self.particle = None # Clear particle from this node + else: + # If the node has children, insert particle in the child node. + quad_index = self.get_quadrant(particle) + if self.children[quad_index] is None: + # Create a child node if it doesn't exist + self.children[quad_index] = TreeNode(self.x + (quad_index % 2) * (self.width / 2), + self.y + (quad_index // 2) * (self.height / 2), self.width / 2, + self.height / 2) + self.children[quad_index].insert(particle) + + def subdivide(self): + """Subdivide the node into four quadrants.""" + sub_width = self.width / 2 + sub_height = self.height / 2 + self.children[0] = TreeNode(self.x, self.y, sub_width, sub_height) # SW + self.children[1] = TreeNode(self.x + sub_width, self.y, sub_width, sub_height) # SE + self.children[2] = TreeNode(self.x, self.y + sub_height, sub_width, sub_height) # NW + self.children[3] = TreeNode(self.x + sub_width, self.y + sub_height, sub_width, sub_height) # NE + + def get_quadrant(self, particle): + """Determine the quadrant index for a particle based on its position. + + Args: + particle (Particle): The particle to determine the quadrant index for. + + Returns: + int: Quadrant index (0 for SW, 1 for SE, 2 for NW, 3 for NE). + """ + mid_x = self.x + self.width / 2 + mid_y = self.y + self.height / 2 + quad_index = (particle.x >= mid_x) + 2 * (particle.y >= mid_y) + return quad_index + + def print_tree(self, depth=0): + """Print the structure of the Quadtree. + + Args: + depth (int): Current depth in the tree (for indentation in print). + """ + if self.particle: + print( + f"{' ' * depth}Particle at ({round(self.particle.x,2)}, {round(self.particle.y,2)}) in Node ({self.x}, {self.y}), size={self.width}" + ) + else: + print(f"{' ' * depth}Node ({self.x}, {self.y}) - Width: {self.width}, Height: {self.height}") + for child in self.children: + if child: + child.print_tree(depth + 2) + + def ComputeMassDistribution(self): + """Compute the mass distribution for the tree nodes. + + This function calculates the total mass and the center of mass + for each node in the Quadtree. It's a recursive function that + computes the mass distribution starting from the current node. + + Note: + This method modifies the 'mass' and 'center_of_mass' attributes + for each node in the Quadtree. + + Returns: + None + """ + if self.particle is not None: + # Node contains only one particle + self.center_of_mass = np.array([self.particle.x, self.particle.y]) + self.total_mass = self.particle.mass + else: + # Multiple particles in node + total_mass = 0 + center_of_mass_accumulator = np.array([0.0, 0.0]) + + for child in self.children: + if child is not None: + # Recursively compute mass distribution for child nodes + child.ComputeMassDistribution() + total_mass += child.total_mass + center_of_mass_accumulator += child.total_mass * child.center_of_mass + + if total_mass > 0: + self.center_of_mass = center_of_mass_accumulator / total_mass + self.total_mass = total_mass + else: + # If total mass is 0 or no child nodes have mass, leave values as default + pass + #self.center_of_mass = np.array([(x+width)/2, (y+width)/2]) + #self.total_mass = 0 + + def CalculateForceFromTree(self, target_particle, theta=1.0): + """Calculate the force on a target particle using the Barnes-Hut algorithm. + + Args: + target_particle (Particle): The particle for which the force is calculated. + theta (float): The Barnes-Hut criterion for force approximation. + + Returns: + np.ndarray: The total force acting on the target particle. + """ + + total_force = np.array([0.0, 0.0]) + + if self.particle is not None: + # Node contains only one particle + if self.particle != target_particle: + # Calculate gravitational force between target_particle and node's particle + force = self.GravitationalForce(target_particle, self.particle) + total_force += force + else: + if self.total_mass == 0: + return total_force + + r = np.linalg.norm(np.array([target_particle.x, target_particle.y]) - self.center_of_mass) + d = max(self.width, self.height) + + if d / r < theta: + # Calculate gravitational force between target_particle and "node particle" representing cluster + node_particle = Particle(self.center_of_mass[0], self.center_of_mass[1], self.total_mass) + force = self.GravitationalForce(target_particle, node_particle) + total_force += force + else: + for child in self.children: + if child is not None: + # Recursively calculate force from child nodes + if target_particle is not None: # Check if the target_particle is not None + force = child.CalculateForceFromTree(target_particle) + total_force += force + + return total_force + + def CalculateForce(self, target_particle, particle, theta=1.0): + """Calculate the gravitational force between two particles. + + Args: + target_particle (Particle): The particle for which the force is calculated. + particle (Particle): The particle exerting the force. + + Returns: + np.ndarray: The force vector acting on the target particle due to 'particle'. + """ + force = np.array([0.0, 0.0]) + print('function CalculateForce is called') + if self.particle is not None: + # Node contains only one particle + if self.particle != target_particle: + # Calculate gravitational force between target_particle and node's particle + force = self.GravitationalForce(target_particle, self.particle) + else: + if target_particle is not None and particle is not None: # Check if both particles are not None + r = np.linalg.norm( + np.array([target_particle.x, target_particle.y]) - np.array([particle.x, particle.y])) + d = max(self.width, self.height) + + if d / r < theta: + # Calculate gravitational force between target_particle and particle + force = self.GravitationalForce(target_particle, particle) + else: + for child in self.children: + if child is not None: + # Recursively calculate force from child nodes + force += child.CalculateForce(target_particle, particle) + return force + + def GravitationalForce(self, particle1, particle2): + """Calculate the gravitational force between two particles. + + Args: + particle1 (Particle): First particle. + particle2 (Particle): Second particle. + + Returns: + np.ndarray: The gravitational force vector between particle1 and particle2. + """ + #G = 6.674 * (10 ** -11) # Gravitational constant + #G = 1 + G = 4 * np.pi**2 # AU^3 / m / yr^2 + + dx = particle2.x - particle1.x + dy = particle2.y - particle1.y + cutoff_radius = 0 + r = max(np.sqrt(dx**2 + dy**2), cutoff_radius) + + force_magnitude = G * particle1.mass * particle2.mass / (r**2) + force_x = force_magnitude * (dx / r) + force_y = force_magnitude * (dy / r) + + return np.array([force_x, force_y]) + + # Helper method to retrieve all particles in the subtree + def particles_in_subtree(self): + """Retrieve all particles in the subtree rooted at this node. + + Returns: + list: A list of particles in the subtree rooted at this node. + """ + particles = [] + if self.particle is not None: + particles.append(self.particle) + else: + for child in self.children: + if child is not None: + particles.extend(child.particles_in_subtree()) + return particles + + def compute_center_of_mass(self): + """Compute the center of mass for the node.""" + print('Function compute_center_of_mass is called') + if self.particle is not None: + self.center_of_mass = np.array([self.particle.x, self.particle.y]) + self.mass = self.particle.mass + else: + total_mass = 0 + center_of_mass_accumulator = np.array([0.0, 0.0]) + + for child in self.children: + if child is not None: + child.compute_center_of_mass() + total_mass += child.mass + center_of_mass_accumulator += child.mass * child.center_of_mass + + if total_mass > 0: + self.center_of_mass = center_of_mass_accumulator / total_mass + self.mass = total_mass + else: + self.center_of_mass = np.array([0.0, 0.0]) + self.mass = 0 diff --git a/jobs/tests/test_bht_algorithm.py b/jobs/tests/test_bht_algorithm.py new file mode 100644 index 0000000000000000000000000000000000000000..5e5adadec9e73ab752db5bc8b13e9445948f1682 --- /dev/null +++ b/jobs/tests/test_bht_algorithm.py @@ -0,0 +1,153 @@ +""" +This unittest tests implementation of the BHT algorithm +for the restricted three-body (sun-earth-moon) problem. +The sun is fixed at the origin (center of mass). For +simplicity, the moon is assumed to be in the eliptic plane, +so that vectors can be treated as two-dimensional. +""" + +import logging +import unittest + +import matplotlib.pyplot as plt +import numpy as np + +from jobs.src.integrators import verlet +from jobs.src.system import GravitationalSystem +from jobs.src.bht_algorithm import MainApp, Particle + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +# system settings +R_SE = 1 # distance between Earth and the Sun, 1 astronomical unit (AU) +R_L = 2.57e-3 * R_SE # distance between Earth and the Moon + +M_S = 1 # solar mass +M_E = 3.00e-6 * M_S +M_L = 3.69e-8 * M_S + +T_E = 1 # earth yr +T_L = 27.3 / 365.3 * T_E + +G = 4 * np.pi**2 # AU^3 / m / yr^2 + +# simulation settings +T = 0.3 # yr +n_order = 6 +dt = 4**(-n_order) # yr + + +# force computation +def force(q: np.ndarray) -> np.ndarray: + r1 = q[0:2] # Earth coordinates + r2 = q[2:4] # Moon coordinates + + sun = Particle(0, 0, M_S) + earth = Particle(r1[0], r1[1], M_E) + moon = Particle(r2[0], r2[1], M_L) + + particles = [sun, earth, moon] + + barnes_hut = MainApp() # Initialize Barnes-Hut algorithm instance + barnes_hut.BuildTree(particles) # Build the Barnes-Hut tree with particles + barnes_hut.rootNode.ComputeMassDistribution() #Compute the center of mass of the tree nodes + + f_earth = barnes_hut.rootNode.CalculateForceFromTree(earth) + f_moon = barnes_hut.rootNode.CalculateForceFromTree(moon) + + return np.concatenate((f_earth, f_moon), axis=0) + + +class BarnesHutTest(unittest.TestCase): + + def test_bht(self): + """ + Test functionalities of velocity-Verlet algorithm using BHT tree + """ + # vector of r0 and v0 + x0 = np.array([ + R_SE, + 0, + R_SE + R_L, + 0, + 0, + R_SE * 2 * np.pi / T_E, + 0, + 1 / M_E * M_E * R_SE * 2 * np.pi / T_E + 1 * R_L * 2 * np.pi / T_L, + ]) + + system = GravitationalSystem(r0=x0[:4], + v0=x0[4:], + m=np.array([M_E, M_E, M_L, M_L]), + t=np.linspace(0, T, int(T // dt)), + force=force, + solver=verlet) + + t, p, q = system.direct_simulation() + + ## checking total energy conservation + H = np.linalg.norm(p[:,:2], axis=1)**2 / (2 * M_E) + np.linalg.norm(p[:,2:], axis=1)**2 / (2 * M_L) + \ + -G * M_S * M_E / np.linalg.norm(q[:,:2], axis=1) - G * M_S * M_L / np.linalg.norm(q[:,2:], axis=1) + \ + -G * M_E * M_L / np.linalg.norm(q[:,2:] - q[:,:2], axis=1) + + logger.info(f"{H=}") + self.assertTrue(np.greater(1e-7 + np.zeros(H.shape[0]), H - H[0]).all()) + + ## checking total linear momentum conservation + P = p[:, :2] + p[:, 2:] + + self.assertTrue(np.greater(1e-10 + np.zeros(P[0].shape), P - P[0]).all()) + + ## checking total angular momentum conservation + L = np.cross(q[:, :2], p[:, :2]) + np.cross(q[:, 2:], p[:, 2:]) + logger.info(f"{L=}") + self.assertTrue(np.greater(1e-8 + np.zeros(L.shape[0]), L - L[0]).all()) + + ## checking error + dts = [dt, 2 * dt, 4 * dt] + errors = [] + ts = [] + for i in dts: + system = GravitationalSystem(r0=x0[:4], + v0=x0[4:], + m=np.array([M_E, M_E, M_L, M_L]), + t=np.linspace(0, T, int(T // i)), + force=force, + solver=verlet) + + t, p_t, q_t = system.direct_simulation() + + + H = np.linalg.norm(p_t[:,:2], axis=1)**2 / (2 * M_E) + np.linalg.norm(p_t[:,2:], axis=1)**2 / (2 * M_L) + \ + -G * M_S * M_E / np.linalg.norm(q_t[:,:2], axis=1) - G * M_S * M_L / np.linalg.norm(q_t[:,2:], axis=1) + \ + -G * M_E * M_L / np.linalg.norm(q_t[:,2:] - q_t[:,:2], axis=1) + + errors.append((H - H[0]) / i**2) + ts.append(t) + + plt.figure() + plt.plot(ts[0], errors[0], label="dt") + plt.plot(ts[1], errors[1], linestyle='--', label="2*dt") + plt.plot(ts[2], errors[2], linestyle=':', label="4*dt") + plt.xlabel("$t$") + plt.ylabel("$\delta E(t)/(\Delta t)^2$") + plt.legend() + plt.show() + + ## checking time reversal: p -> -p + x0 = np.concatenate((q[-1, :], -1 * p[-1, :] / np.array([M_E, M_E, M_L, M_L])), axis=0) + system = GravitationalSystem(r0=x0[:4], + v0=x0[4:], + m=np.array([M_E, M_E, M_L, M_L]), + t=np.linspace(0, T, int(T // dt)), + force=force, + solver=verlet) + + t, p_reverse, q_reverse = system.direct_simulation() + self.assertTrue(np.greater(1e-10 + np.zeros(4), q_reverse[-1] - q[0]).all()) + self.assertTrue(np.greater(1e-10 + np.zeros(4), p_reverse[-1] - p[0]).all()) + + +if __name__ == '__main__': + unittest.main()