"""This module implements the Barnes-Hut tree (BHT) algorithm for 2D data."""

import numpy as np

class MainApp:
    """Own the root of the Barnes-Hut quadtree and rebuild it from particles."""

    # Default root cell as (x, y, width, height): the square [-1, 1) x [-1, 1).
    root_bounds = (-1, -1, 2, 2)

    def __init__(self, bounds=None):
        """Initialize the MainApp with an empty root node.

        Args:
            bounds (tuple, optional): (x, y, width, height) of the root cell.
                When omitted, the class default ``root_bounds`` is used.
        """
        if bounds is not None:
            x, y, width, height = bounds
            self.root_bounds = (x, y, width, height)
        self.ResetTree()

    def BuildTree(self, particles):
        """Rebuild the quadtree from scratch by inserting every particle.

        Particles lying outside the root cell are rejected by
        ``TreeNode.insert`` (it returns False) and are not stored.
        Force queries on internal nodes read ``total_mass`` and
        ``center_of_mass``, so callers presumably need to run
        ``self.rootNode.ComputeMassDistribution()`` afterwards — TODO confirm
        against the caller.

        Args:
            particles (list): A list of Particle objects to be inserted.
        """
        self.ResetTree()  # Empty the tree
        for particle in particles:
            self.rootNode.insert(particle)

    def ResetTree(self):
        """Reset the quadtree by replacing the root with an empty node."""
        x, y, width, height = self.root_bounds
        self.rootNode = TreeNode(x=x, y=y, width=width, height=height)

class Particle:
    """A point mass in the plane with position, velocity and accumulated force."""

    def __init__(self, x, y, mass):
        """Create a particle at (x, y) with the given mass.

        Velocity and force components both start at zero.

        Args:
            x (float): x-coordinate of the particle.
            y (float): y-coordinate of the particle.
            mass (float): Mass of the particle.
        """
        # Position and mass exactly as supplied by the caller.
        self.x, self.y, self.mass = x, y, mass
        # Velocity components (vx, vy): the particle starts at rest.
        self.vx = self.vy = 0.0
        # Force components (fx, fy): nothing accumulated yet.
        self.fx = self.fy = 0.0

class TreeNode:
    """One rectangular cell of the Barnes-Hut quadtree.

    A node is in one of three states: an empty leaf, a leaf holding exactly
    one particle (``self.particle``), or an internal node whose particles are
    stored in its ``children`` (SW, SE, NW, NE). Particles are any objects
    exposing ``x``, ``y`` and ``mass`` attributes.
    """

    # G expressed in AU^3 / (M_sun * yr^2): distances in AU, time in years,
    # masses in solar masses.
    GRAVITATIONAL_CONSTANT = 4 * np.pi ** 2

    def __init__(self, x, y, width, height):
        """Create an empty node covering [x, x + width) x [y, y + height).

        Args:
            x (float): x-coordinate of the node's lower-left corner.
            y (float): y-coordinate of the node's lower-left corner.
            width (float): Width of the node.
            height (float): Height of the node.
        """
        self.x = x
        self.y = y
        self.width = width
        self.height = height
        self.particle = None  # Particle stored here while this node is a leaf
        # Until masses are computed, default to the geometric center of the cell.
        self.center_of_mass = np.array([x + width / 2, y + height / 2])
        self.total_mass = 0
        # Mass written by the legacy compute_center_of_mass(); initialised so
        # reading it on a node that was never processed cannot fail.
        self.mass = 0
        self.children = np.empty(4, dtype=object)  # Children nodes (SW, SE, NW, NE)

    def contains(self, particle):
        """Check whether the particle lies within this node's bounds.

        Bounds are half-open: the lower/left edges are inside, the upper/right
        edges are not.

        Args:
            particle (Particle): The particle to be checked.

        Returns:
            bool: True if the particle is within the node's bounds.
        """
        return (self.x <= particle.x < self.x + self.width
                and self.y <= particle.y < self.y + self.height)

    def insert(self, particle):
        """Insert a particle into the subtree rooted at this node.

        Args:
            particle (Particle): The particle to be inserted.

        Returns:
            bool: True if the particle was stored in this subtree, False if it
            was rejected (e.g. it lies outside this node's bounds).

        Raises:
            ValueError: If a particle already occupies exactly the same
                position; subdividing could never separate the two.
        """
        if not self.contains(particle):
            return False  # Particle doesn't belong in this node

        is_leaf = all(child is None for child in self.children)
        if is_leaf and self.particle is None:
            # Empty leaf: store the particle here.
            self.particle = particle
            return True

        if is_leaf:
            # Occupied leaf: split it and push both particles down a level.
            existing = self.particle
            if existing.x == particle.x and existing.y == particle.y:
                raise ValueError(
                    f"cannot insert particle at ({particle.x}, {particle.y}): "
                    "position already occupied"
                )
            self.particle = None
            self.subdivide()
            self._insert_into_child(existing)
            return self._insert_into_child(particle)

        # Internal node: delegate to the matching quadrant.
        return self._insert_into_child(particle)

    def _insert_into_child(self, particle):
        """Insert the particle into the child quadrant it falls in, creating it if needed."""
        quad_index = self.get_quadrant(particle)
        child = self.children[quad_index]
        if child is None:
            half_w = self.width / 2
            half_h = self.height / 2
            child = TreeNode(self.x + (quad_index % 2) * half_w,
                             self.y + (quad_index // 2) * half_h,
                             half_w, half_h)
            self.children[quad_index] = child
        return child.insert(particle)

    def subdivide(self):
        """Create the four child quadrants of this node."""
        sub_width = self.width / 2
        sub_height = self.height / 2
        self.children[0] = TreeNode(self.x, self.y, sub_width, sub_height)  # SW
        self.children[1] = TreeNode(self.x + sub_width, self.y, sub_width, sub_height)  # SE
        self.children[2] = TreeNode(self.x, self.y + sub_height, sub_width, sub_height)  # NW
        self.children[3] = TreeNode(self.x + sub_width, self.y + sub_height, sub_width, sub_height)  # NE

    def get_quadrant(self, particle):
        """Determine the quadrant index for a particle based on its position.

        Args:
            particle (Particle): The particle to classify.

        Returns:
            int: Quadrant index (0 for SW, 1 for SE, 2 for NW, 3 for NE).
        """
        mid_x = self.x + self.width / 2
        mid_y = self.y + self.height / 2
        # Bit 0 selects east/west, bit 1 selects north/south.
        return int(particle.x >= mid_x) + 2 * int(particle.y >= mid_y)

    def print_tree(self, depth=0):
        """Print the structure of the subtree rooted at this node.

        Args:
            depth (int): Current depth in the tree (number of indent spaces).
        """
        indent = ' ' * depth
        if self.particle is not None:
            print(
                f"{indent}Particle at ({round(self.particle.x, 2)}, {round(self.particle.y, 2)}) "
                f"in Node ({self.x}, {self.y}), size={self.width}"
            )
        else:
            print(f"{indent}Node ({self.x}, {self.y}) - Width: {self.width}, Height: {self.height}")
            for child in self.children:
                if child is not None:
                    child.print_tree(depth + 2)

    def ComputeMassDistribution(self):
        """Compute total mass and center of mass for every node in this subtree.

        Children are processed before their parent, so a single call on the
        root fills in ``total_mass`` and ``center_of_mass`` for the whole tree.
        Nodes with no mass keep their geometric-center default.
        """
        if self.particle is not None:
            # Leaf with one particle: its mass summary is the particle itself.
            self.center_of_mass = np.array([self.particle.x, self.particle.y])
            self.total_mass = self.particle.mass
            return

        total_mass = 0
        weighted_position = np.array([0.0, 0.0])
        for child in self.children:
            if child is None:
                continue
            child.ComputeMassDistribution()  # Recurse before reading child values
            total_mass += child.total_mass
            weighted_position += child.total_mass * child.center_of_mass

        self.total_mass = total_mass
        if total_mass > 0:
            self.center_of_mass = weighted_position / total_mass

    def CalculateForceFromTree(self, target_particle, theta=1.0):
        """Calculate the force on a target particle using the Barnes-Hut algorithm.

        Reads ``total_mass`` / ``center_of_mass``, so ComputeMassDistribution()
        must have been run first; otherwise internal nodes report zero mass
        and contribute no force.

        Args:
            target_particle (Particle): The particle for which the force is calculated.
            theta (float): Opening angle; a node of size d at distance r is
                approximated by its center of mass when d / r < theta.

        Returns:
            np.ndarray: The total force (fx, fy) acting on the target particle.
        """
        total_force = np.array([0.0, 0.0])

        if self.particle is not None:
            # Leaf: interact directly, skipping the target itself.
            if self.particle is not target_particle:
                total_force += self.GravitationalForce(target_particle, self.particle)
            return total_force

        if self.total_mass == 0:
            return total_force  # Empty subtree exerts no force

        target_pos = np.array([target_particle.x, target_particle.y])
        r = np.linalg.norm(target_pos - self.center_of_mass)
        d = max(self.width, self.height)

        # d / r < theta written as d < theta * r so r == 0 cannot divide by
        # zero. A node containing the target is always opened, otherwise the
        # target's own mass would be folded into the approximation.
        if d < theta * r and not self.contains(target_particle):
            total_force += self._point_mass_force(
                target_particle, self.center_of_mass[0], self.center_of_mass[1], self.total_mass
            )
            return total_force

        for child in self.children:
            if child is not None:
                total_force += child.CalculateForceFromTree(target_particle, theta)
        return total_force

    def CalculateForce(self, target_particle, particle, theta=1.0):
        """Legacy pairwise force helper driven by the tree's opening criterion.

        At a leaf, returns the direct force from the stored particle (zero if it
        is the target). At an internal node, returns the direct force from
        ``particle`` when the node is small relative to the target-particle
        distance, otherwise the sum over children.

        Args:
            target_particle (Particle): The particle for which the force is calculated.
            particle (Particle): The particle exerting the force.
            theta (float): Opening-angle threshold, forwarded to children.

        Returns:
            np.ndarray: The force vector acting on the target particle.
        """
        force = np.array([0.0, 0.0])
        print('function CalculateForce is called')
        if self.particle is not None:
            if self.particle is not target_particle:
                force = self.GravitationalForce(target_particle, self.particle)
            return force

        if target_particle is None or particle is None:
            return force

        r = np.linalg.norm(
            np.array([target_particle.x, target_particle.y]) - np.array([particle.x, particle.y]))
        d = max(self.width, self.height)
        if d < theta * r:  # Same test as d / r < theta, safe when r == 0
            return self.GravitationalForce(target_particle, particle)

        for child in self.children:
            if child is not None:
                force += child.CalculateForce(target_particle, particle, theta)
        return force

    def GravitationalForce(self, particle1, particle2, cutoff_radius=0.0):
        """Calculate the gravitational force exerted on particle1 by particle2.

        Args:
            particle1 (Particle): Particle the force acts on.
            particle2 (Particle): Particle exerting the force.
            cutoff_radius (float): Lower bound applied to the separation.

        Returns:
            np.ndarray: Force vector (fx, fy) on particle1; zero if the two
            particles coincide and no cutoff is set.
        """
        return self._point_mass_force(
            particle1, particle2.x, particle2.y, particle2.mass, cutoff_radius
        )

    def _point_mass_force(self, particle, x, y, mass, cutoff_radius=0.0):
        """Force on ``particle`` from a point of mass ``mass`` located at (x, y)."""
        dx = x - particle.x
        dy = y - particle.y
        r = max(np.hypot(dx, dy), cutoff_radius)
        if r == 0:
            # Coincident points: the direction is undefined, so report no force
            # instead of producing nan/inf.
            return np.array([0.0, 0.0])
        force_magnitude = self.GRAVITATIONAL_CONSTANT * particle.mass * mass / r ** 2
        return np.array([force_magnitude * dx / r, force_magnitude * dy / r])

    def particles_in_subtree(self):
        """Retrieve all particles in the subtree rooted at this node.

        Returns:
            list: The particles stored in this subtree.
        """
        if self.particle is not None:
            return [self.particle]
        particles = []
        for child in self.children:
            if child is not None:
                particles.extend(child.particles_in_subtree())
        return particles

    def compute_center_of_mass(self):
        """Legacy: compute ``center_of_mass`` and ``mass`` for this subtree.

        Unlike ComputeMassDistribution(), empty nodes get a center of (0, 0)
        and the result is written to ``mass`` rather than ``total_mass``.
        """
        print('Function compute_center_of_mass is called')
        self._compute_center_of_mass_recursive()

    def _compute_center_of_mass_recursive(self):
        """Post-order worker for compute_center_of_mass (no logging)."""
        if self.particle is not None:
            self.center_of_mass = np.array([self.particle.x, self.particle.y])
            self.mass = self.particle.mass
            return

        total_mass = 0
        center_of_mass_accumulator = np.array([0.0, 0.0])
        for child in self.children:
            if child is not None:
                child._compute_center_of_mass_recursive()  # Ensures child.mass is current
                total_mass += child.mass
                center_of_mass_accumulator += child.mass * child.center_of_mass

        if total_mass > 0:
            self.center_of_mass = center_of_mass_accumulator / total_mass
            self.mass = total_mass
        else:
            self.center_of_mass = np.array([0.0, 0.0])
            self.mass = 0