Skip to content
Snippets Groups Projects
Commit 8595d3e5 authored by jakut77's avatar jakut77
Browse files

BHT test update

parent 2a20043f
No related branches found
No related tags found
No related merge requests found
"""
This module implements the Barnes Hut Tree Algorithm
"""
import numpy as np
class MainApp:
def __init__(self):
"""Initialize the MainApp with a root node."""
self.rootNode = TreeNode(x=-1, y=-1, width=2, height=2)
def BuildTree(self, particles):
"""Build the Quadtree by inserting particles.
Args:
particles (list): A list of Particle objects to be inserted into the Quadtree.
"""
self.ResetTree() # Empty the tree
for particle in particles:
self.rootNode.insert(particle)
def ResetTree(self):
"""Reset the Quadtree by reinitializing the root node."""
self.rootNode = TreeNode(x=-1, y=-1, width=2, height=2)
class Particle:
def __init__(self, x, y, mass):
"""Initialize a Particle with x and y coordinates and mass."""
self.x = x
self.y = y
self.mass = mass
self.vx = 0.0 # Velocity component in x direction
self.vy = 0.0 # Velocity component in y direction
self.fx = 0.0 # Force component in x direction
self.fy = 0.0 # Force component in y direction
class TreeNode:
def __init__(self, x, y, width, height):
"""Initialize a TreeNode representing a quadrant in the Quadtree.
Args:
x (float): x-coordinate of the node.
y (float): y-coordinate of the node.
width (float): Width of the node.
height (float): Height of the node.
"""
self.x = x
self.y = y
self.width = width
self.height = height
self.particle = None # Particle contained in this node
self.center_of_mass = np.array([(x + width) / 2, (y + width) / 2])
self.total_mass = 0
self.children = np.empty(4, dtype=object) # Children nodes (SW, SE, NW, NE)
def contains(self, particle):
"""Check if the particle is within the bounds of this node.
Args:
particle (Particle): The particle to be checked.
Returns:
bool: True if the particle is within the node's bounds, False otherwise.
"""
return (self.x <= particle.x < self.x + self.width and self.y <= particle.y < self.y + self.height)
def insert(self, particle):
"""Insert a particle into the Quadtree.
Args:
particle (Particle): The particle to be inserted.
Returns:
bool: True if the particle is inserted, False otherwise.
"""
if not self.contains(particle):
return False # Particle doesn't belong in this node
if self.particle is None and all(child is None for child in self.children):
# If the node is empty and has no children, insert particle here.
self.particle = particle
#print(f'particle inserted: x={round(self.particle.x,2)}, y={round(self.particle.y,2)}')
return True # Particle inserted in an empty node
if all(child is None for child in self.children):
# If no children exist, create and insert both particles
self.subdivide()
self.insert(self.particle) # Reinsert existing particle
self.insert(particle) # Insert new particle
self.particle = None # Clear particle from this node
else:
# If the node has children, insert particle in the child node.
quad_index = self.get_quadrant(particle)
if self.children[quad_index] is None:
# Create a child node if it doesn't exist
self.children[quad_index] = TreeNode(self.x + (quad_index % 2) * (self.width / 2),
self.y + (quad_index // 2) * (self.height / 2), self.width / 2,
self.height / 2)
self.children[quad_index].insert(particle)
def subdivide(self):
"""Subdivide the node into four quadrants."""
sub_width = self.width / 2
sub_height = self.height / 2
self.children[0] = TreeNode(self.x, self.y, sub_width, sub_height) # SW
self.children[1] = TreeNode(self.x + sub_width, self.y, sub_width, sub_height) # SE
self.children[2] = TreeNode(self.x, self.y + sub_height, sub_width, sub_height) # NW
self.children[3] = TreeNode(self.x + sub_width, self.y + sub_height, sub_width, sub_height) # NE
def get_quadrant(self, particle):
"""Determine the quadrant index for a particle based on its position.
Args:
particle (Particle): The particle to determine the quadrant index for.
Returns:
int: Quadrant index (0 for SW, 1 for SE, 2 for NW, 3 for NE).
"""
mid_x = self.x + self.width / 2
mid_y = self.y + self.height / 2
quad_index = (particle.x >= mid_x) + 2 * (particle.y >= mid_y)
return quad_index
def print_tree(self, depth=0):
"""Print the structure of the Quadtree.
Args:
depth (int): Current depth in the tree (for indentation in print).
"""
if self.particle:
print(
f"{' ' * depth}Particle at ({round(self.particle.x,2)}, {round(self.particle.y,2)}) in Node ({self.x}, {self.y}), size={self.width}"
)
else:
print(f"{' ' * depth}Node ({self.x}, {self.y}) - Width: {self.width}, Height: {self.height}")
for child in self.children:
if child:
child.print_tree(depth + 2)
def ComputeMassDistribution(self):
"""Compute the mass distribution for the tree nodes.
This function calculates the total mass and the center of mass
for each node in the Quadtree. It's a recursive function that
computes the mass distribution starting from the current node.
Note:
This method modifies the 'mass' and 'center_of_mass' attributes
for each node in the Quadtree.
Returns:
None
"""
if self.particle is not None:
# Node contains only one particle
self.center_of_mass = np.array([self.particle.x, self.particle.y])
self.total_mass = self.particle.mass
else:
# Multiple particles in node
total_mass = 0
center_of_mass_accumulator = np.array([0.0, 0.0])
for child in self.children:
if child is not None:
# Recursively compute mass distribution for child nodes
child.ComputeMassDistribution()
total_mass += child.total_mass
center_of_mass_accumulator += child.total_mass * child.center_of_mass
if total_mass > 0:
self.center_of_mass = center_of_mass_accumulator / total_mass
self.total_mass = total_mass
else:
# If total mass is 0 or no child nodes have mass, leave values as default
pass
#self.center_of_mass = np.array([(x+width)/2, (y+width)/2])
#self.total_mass = 0
def CalculateForceFromTree(self, target_particle, theta=1.0):
"""Calculate the force on a target particle using the Barnes-Hut algorithm.
Args:
target_particle (Particle): The particle for which the force is calculated.
theta (float): The Barnes-Hut criterion for force approximation.
Returns:
np.ndarray: The total force acting on the target particle.
"""
total_force = np.array([0.0, 0.0])
if self.particle is not None:
# Node contains only one particle
if self.particle != target_particle:
# Calculate gravitational force between target_particle and node's particle
force = self.GravitationalForce(target_particle, self.particle)
total_force += force
else:
if self.total_mass == 0:
return total_force
r = np.linalg.norm(np.array([target_particle.x, target_particle.y]) - self.center_of_mass)
d = max(self.width, self.height)
if d / r < theta:
# Calculate gravitational force between target_particle and "node particle" representing cluster
node_particle = Particle(self.center_of_mass[0], self.center_of_mass[1], self.total_mass)
force = self.GravitationalForce(target_particle, node_particle)
total_force += force
else:
for child in self.children:
if child is not None:
# Recursively calculate force from child nodes
if target_particle is not None: # Check if the target_particle is not None
force = child.CalculateForceFromTree(target_particle)
total_force += force
return total_force
def CalculateForce(self, target_particle, particle, theta=1.0):
"""Calculate the gravitational force between two particles.
Args:
target_particle (Particle): The particle for which the force is calculated.
particle (Particle): The particle exerting the force.
Returns:
np.ndarray: The force vector acting on the target particle due to 'particle'.
"""
force = np.array([0.0, 0.0])
print('function CalculateForce is called')
if self.particle is not None:
# Node contains only one particle
if self.particle != target_particle:
# Calculate gravitational force between target_particle and node's particle
force = self.GravitationalForce(target_particle, self.particle)
else:
if target_particle is not None and particle is not None: # Check if both particles are not None
r = np.linalg.norm(
np.array([target_particle.x, target_particle.y]) - np.array([particle.x, particle.y]))
d = max(self.width, self.height)
if d / r < theta:
# Calculate gravitational force between target_particle and particle
force = self.GravitationalForce(target_particle, particle)
else:
for child in self.children:
if child is not None:
# Recursively calculate force from child nodes
force += child.CalculateForce(target_particle, particle)
return force
def GravitationalForce(self, particle1, particle2):
"""Calculate the gravitational force between two particles.
Args:
particle1 (Particle): First particle.
particle2 (Particle): Second particle.
Returns:
np.ndarray: The gravitational force vector between particle1 and particle2.
"""
#G = 6.674 * (10 ** -11) # Gravitational constant
#G = 1
G = 4 * np.pi**2 # AU^3 / m / yr^2
dx = particle2.x - particle1.x
dy = particle2.y - particle1.y
cutoff_radius = 0
r = max(np.sqrt(dx**2 + dy**2), cutoff_radius)
force_magnitude = G * particle1.mass * particle2.mass / (r**2)
force_x = force_magnitude * (dx / r)
force_y = force_magnitude * (dy / r)
return np.array([force_x, force_y])
# Helper method to retrieve all particles in the subtree
def particles_in_subtree(self):
"""Retrieve all particles in the subtree rooted at this node.
Returns:
list: A list of particles in the subtree rooted at this node.
"""
particles = []
if self.particle is not None:
particles.append(self.particle)
else:
for child in self.children:
if child is not None:
particles.extend(child.particles_in_subtree())
return particles
def compute_center_of_mass(self):
"""Compute the center of mass for the node."""
print('Function compute_center_of_mass is called')
if self.particle is not None:
self.center_of_mass = np.array([self.particle.x, self.particle.y])
self.mass = self.particle.mass
else:
total_mass = 0
center_of_mass_accumulator = np.array([0.0, 0.0])
for child in self.children:
if child is not None:
child.compute_center_of_mass()
total_mass += child.mass
center_of_mass_accumulator += child.mass * child.center_of_mass
if total_mass > 0:
self.center_of_mass = center_of_mass_accumulator / total_mass
self.mass = total_mass
else:
self.center_of_mass = np.array([0.0, 0.0])
self.mass = 0
"""
This unittest tests implementation of the BHT algorithm
for the restricted three-body (sun-earth-moon) problem.
The sun is fixed at the origin (center of mass). For
simplicity, the moon is assumed to be in the eliptic plane,
so that vectors can be treated as two-dimensional.
"""
import logging
import unittest
import matplotlib.pyplot as plt
import numpy as np
from jobs.src.integrators import verlet
from jobs.src.system import GravitationalSystem
from jobs.src.bht_algorithm import MainApp, Particle
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# system settings
R_SE = 1 # distance between Earth and the Sun, 1 astronomical unit (AU)
R_L = 2.57e-3 * R_SE # distance between Earth and the Moon
M_S = 1 # solar mass
M_E = 3.00e-6 * M_S
M_L = 3.69e-8 * M_S
T_E = 1 # earth yr
T_L = 27.3 / 365.3 * T_E
G = 4 * np.pi**2 # AU^3 / m / yr^2
# simulation settings
T = 0.3 # yr
n_order = 6
dt = 4**(-n_order) # yr
# force computation
def force(q: np.ndarray) -> np.ndarray:
r1 = q[0:2] # Earth coordinates
r2 = q[2:4] # Moon coordinates
sun = Particle(0, 0, M_S)
earth = Particle(r1[0], r1[1], M_E)
moon = Particle(r2[0], r2[1], M_L)
particles = [sun, earth, moon]
barnes_hut = MainApp() # Initialize Barnes-Hut algorithm instance
barnes_hut.BuildTree(particles) # Build the Barnes-Hut tree with particles
barnes_hut.rootNode.ComputeMassDistribution() #Compute the center of mass of the tree nodes
f_earth = barnes_hut.rootNode.CalculateForceFromTree(earth)
f_moon = barnes_hut.rootNode.CalculateForceFromTree(moon)
return np.concatenate((f_earth, f_moon), axis=0)
class BarnesHutTest(unittest.TestCase):
def test_bht(self):
"""
Test functionalities of velocity-Verlet algorithm using BHT tree
"""
# vector of r0 and v0
x0 = np.array([
R_SE,
0,
R_SE + R_L,
0,
0,
R_SE * 2 * np.pi / T_E,
0,
1 / M_E * M_E * R_SE * 2 * np.pi / T_E + 1 * R_L * 2 * np.pi / T_L,
])
system = GravitationalSystem(r0=x0[:4],
v0=x0[4:],
m=np.array([M_E, M_E, M_L, M_L]),
t=np.linspace(0, T, int(T // dt)),
force=force,
solver=verlet)
t, p, q = system.direct_simulation()
## checking total energy conservation
H = np.linalg.norm(p[:,:2], axis=1)**2 / (2 * M_E) + np.linalg.norm(p[:,2:], axis=1)**2 / (2 * M_L) + \
-G * M_S * M_E / np.linalg.norm(q[:,:2], axis=1) - G * M_S * M_L / np.linalg.norm(q[:,2:], axis=1) + \
-G * M_E * M_L / np.linalg.norm(q[:,2:] - q[:,:2], axis=1)
logger.info(f"{H=}")
self.assertTrue(np.greater(1e-7 + np.zeros(H.shape[0]), H - H[0]).all())
## checking total linear momentum conservation
P = p[:, :2] + p[:, 2:]
self.assertTrue(np.greater(1e-10 + np.zeros(P[0].shape), P - P[0]).all())
## checking total angular momentum conservation
L = np.cross(q[:, :2], p[:, :2]) + np.cross(q[:, 2:], p[:, 2:])
logger.info(f"{L=}")
self.assertTrue(np.greater(1e-8 + np.zeros(L.shape[0]), L - L[0]).all())
## checking error
dts = [dt, 2 * dt, 4 * dt]
errors = []
ts = []
for i in dts:
system = GravitationalSystem(r0=x0[:4],
v0=x0[4:],
m=np.array([M_E, M_E, M_L, M_L]),
t=np.linspace(0, T, int(T // i)),
force=force,
solver=verlet)
t, p_t, q_t = system.direct_simulation()
H = np.linalg.norm(p_t[:,:2], axis=1)**2 / (2 * M_E) + np.linalg.norm(p_t[:,2:], axis=1)**2 / (2 * M_L) + \
-G * M_S * M_E / np.linalg.norm(q_t[:,:2], axis=1) - G * M_S * M_L / np.linalg.norm(q_t[:,2:], axis=1) + \
-G * M_E * M_L / np.linalg.norm(q_t[:,2:] - q_t[:,:2], axis=1)
errors.append((H - H[0]) / i**2)
ts.append(t)
plt.figure()
plt.plot(ts[0], errors[0], label="dt")
plt.plot(ts[1], errors[1], linestyle='--', label="2*dt")
plt.plot(ts[2], errors[2], linestyle=':', label="4*dt")
plt.xlabel("$t$")
plt.ylabel("$\delta E(t)/(\Delta t)^2$")
plt.legend()
plt.show()
## checking time reversal: p -> -p
x0 = np.concatenate((q[-1, :], -1 * p[-1, :] / np.array([M_E, M_E, M_L, M_L])), axis=0)
system = GravitationalSystem(r0=x0[:4],
v0=x0[4:],
m=np.array([M_E, M_E, M_L, M_L]),
t=np.linspace(0, T, int(T // dt)),
force=force,
solver=verlet)
t, p_reverse, q_reverse = system.direct_simulation()
self.assertTrue(np.greater(1e-10 + np.zeros(4), q_reverse[-1] - q[0]).all())
self.assertTrue(np.greater(1e-10 + np.zeros(4), p_reverse[-1] - p[0]).all())
if __name__ == '__main__':
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment