Skip to content
Snippets Groups Projects
Select Git revision
  • 1ceda29b67ac83ec27277f2437b3506f56a0ccd2
  • main default protected
  • nicoa96-main-patch-39695
  • nicoa96-main-patch-29647
  • jt/bhtfin
  • jt/bhtnew
  • revert-adeaff5e
  • nicoa96-main-patch-73354
  • nicoa96-main-patch-47348
  • jt/bht
  • jima1
  • jima
  • na/bhtalgorithmus
  • yuhe
  • jn/nasa-data
  • Nicola
  • kuba
17 results

bht_algorithm_2D.py

Blame
  • nguyed99's avatar
    nguyed99 authored
    43f2e8df
    History
    Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    bht_algorithm_2D.py 12.65 KiB
    """
    This module implements the Barnes Hut Tree (BHT) Algorithm for 2D data.
    """
    
    import numpy as np
    
    
    class MainApp:
        """Own a Barnes-Hut quadtree and rebuild it from a list of particles."""

        def __init__(self, x=-1, y=-1, width=2, height=2):
            """Initialize the MainApp with an empty root node.

            The defaults reproduce the original fixed domain [-1, 1) x [-1, 1).

            Args:
            x (float): x-coordinate of the root node's lower-left (minimum) corner.
            y (float): y-coordinate of the root node's lower-left (minimum) corner.
            width (float): Width of the root node.
            height (float): Height of the root node.
            """
            # Remember the domain so ResetTree() recreates an identical root.
            self.bounds = (x, y, width, height)
            self.ResetTree()

        def BuildTree(self, particles):
            """Build the Quadtree by inserting particles.

            Any previous tree is discarded first. The return value of each
            insert is not checked, so particles the root node rejects (e.g.
            ones lying outside its bounds) are simply not stored.

            Args:
            particles (list): A list of Particle objects to be inserted into the Quadtree.
            """
            self.ResetTree()  # Start from an empty tree
            for particle in particles:
                self.rootNode.insert(particle)

        def ResetTree(self):
            """Reset the Quadtree by replacing the root node with an empty one."""
            x, y, width, height = self.bounds
            self.rootNode = TreeNode(x=x, y=y, width=width, height=height)
    
    
    class Particle:
        """A point mass in 2D carrying position, velocity and accumulated force."""

        def __init__(self, x, y, mass):
            """Create a particle at (x, y) with the given mass, at rest and force-free.

            Args:
            x (float): x-coordinate of the particle.
            y (float): y-coordinate of the particle.
            mass (float): Mass of the particle.
            """
            self.x, self.y = x, y
            self.mass = mass
            # Velocity (vx, vy) and force (fx, fy) components all start at zero.
            self.vx = self.vy = 0.0
            self.fx = self.fy = 0.0
    
    
    class TreeNode:
        """One rectangular cell of the Barnes-Hut quadtree.

        A node is an empty leaf, a leaf holding exactly one particle, or an
        internal node whose four children are ordered (SW, SE, NW, NE). After
        ComputeMassDistribution() has run, each node also stores the total
        mass and center of mass of the particles in its subtree.
        """

        def __init__(self, x, y, width, height):
            """Initialize a TreeNode representing a quadrant in the Quadtree.

            The node covers the half-open region [x, x + width) x [y, y + height).

            Args:
            x (float): x-coordinate of the node's lower-left (minimum) corner.
            y (float): y-coordinate of the node's lower-left (minimum) corner.
            width (float): Width of the node.
            height (float): Height of the node.
            """
            self.x = x
            self.y = y
            self.width = width
            self.height = height
            self.particle = None  # Particle contained in this node (leaves only)
            # Placeholder until ComputeMassDistribution() runs: the geometric
            # center of the cell.
            self.center_of_mass = np.array([x + width / 2, y + height / 2])
            self.total_mass = 0
            self.children = np.empty(4, dtype=object)  # Children nodes (SW, SE, NW, NE)

        def contains(self, particle):
            """Check if the particle is within the bounds of this node.

            Args:
            particle (Particle): The particle to be checked.

            Returns:
            bool: True if the particle is within the node's half-open bounds,
            False otherwise.
            """
            return (self.x <= particle.x < self.x + self.width and self.y <= particle.y < self.y + self.height)

        def _is_leaf(self):
            """Return True if this node has not been subdivided."""
            return all(child is None for child in self.children)

        def insert(self, particle):
            """Insert a particle into the subtree rooted at this node.

            Args:
            particle (Particle): The particle to be inserted.

            Returns:
            bool: True if the particle was stored somewhere in this subtree,
            False if this node (or the child it was routed to) rejected it,
            e.g. because it lies outside the node's bounds.

            Raises:
            ValueError: If the particle has exactly the same position as a
            particle already stored in the tree. No amount of subdivision can
            separate such a pair, so insertion would otherwise recurse forever.
            """
            if not self.contains(particle):
                return False  # Particle doesn't belong in this node

            if self._is_leaf():
                if self.particle is None:
                    # Empty leaf: store the particle here.
                    self.particle = particle
                    return True
                if particle.x == self.particle.x and particle.y == self.particle.y:
                    raise ValueError(
                        f"cannot insert particle at ({particle.x}, {particle.y}): "
                        "another particle already occupies that position")
                # Occupied leaf: split it and push both particles one level down.
                existing = self.particle
                self.particle = None
                self.subdivide()
                self._insert_into_child(existing)
                return self._insert_into_child(particle)

            # Internal node: delegate to the child quadrant covering the particle.
            return self._insert_into_child(particle)

        def _insert_into_child(self, particle):
            """Insert `particle` into the child quadrant that covers its position."""
            quad_index = self.get_quadrant(particle)
            if self.children[quad_index] is None:
                # Create the child lazily if this quadrant does not exist yet.
                half_width = self.width / 2
                half_height = self.height / 2
                self.children[quad_index] = TreeNode(self.x + (quad_index % 2) * half_width,
                                                     self.y + (quad_index // 2) * half_height, half_width,
                                                     half_height)
            return self.children[quad_index].insert(particle)

        def subdivide(self):
            """Subdivide the node into four equally sized quadrants."""
            sub_width = self.width / 2
            sub_height = self.height / 2
            self.children[0] = TreeNode(self.x, self.y, sub_width, sub_height)  # SW
            self.children[1] = TreeNode(self.x + sub_width, self.y, sub_width, sub_height)  # SE
            self.children[2] = TreeNode(self.x, self.y + sub_height, sub_width, sub_height)  # NW
            self.children[3] = TreeNode(self.x + sub_width, self.y + sub_height, sub_width, sub_height)  # NE

        def get_quadrant(self, particle):
            """Determine the quadrant index for a particle based on its position.

            Args:
            particle (Particle): The particle to determine the quadrant index for.

            Returns:
            int: Quadrant index (0 for SW, 1 for SE, 2 for NW, 3 for NE).
            """
            mid_x = self.x + self.width / 2
            mid_y = self.y + self.height / 2
            # Bit 0 selects east (x >= mid), bit 1 selects north (y >= mid).
            quad_index = (particle.x >= mid_x) + 2 * (particle.y >= mid_y)
            return quad_index

        def print_tree(self, depth=0):
            """Print the structure of the Quadtree.

            Args:
            depth (int): Current depth in the tree (for indentation in print).
            """
            if self.particle is not None:
                print(
                    f"{' ' * depth}Particle at ({round(self.particle.x,2)}, {round(self.particle.y,2)}) in Node ({self.x}, {self.y}), size={self.width}"
                )
            else:
                print(f"{' ' * depth}Node ({self.x}, {self.y}) - Width: {self.width}, Height: {self.height}")
                for child in self.children:
                    if child:
                        child.print_tree(depth + 2)

        def ComputeMassDistribution(self):
            """Compute total mass and center of mass for every node in this subtree.

            Children are computed first; an internal node's values are then the
            mass-weighted combination of its children's values.

            Note:
            This method modifies the 'total_mass' and 'center_of_mass'
            attributes. Nodes whose subtree holds no mass keep their existing
            values (for a freshly built tree, the defaults from __init__).

            Returns:
            None
            """
            if self.particle is not None:
                # Leaf with one particle: the particle itself is the center of mass.
                self.center_of_mass = np.array([self.particle.x, self.particle.y])
                self.total_mass = self.particle.mass
                return

            total_mass = 0
            center_of_mass_accumulator = np.array([0.0, 0.0])
            for child in self.children:
                if child is not None:
                    # Recursively compute mass distribution for child nodes.
                    child.ComputeMassDistribution()
                    total_mass += child.total_mass
                    center_of_mass_accumulator += child.total_mass * child.center_of_mass

            if total_mass > 0:
                self.center_of_mass = center_of_mass_accumulator / total_mass
                self.total_mass = total_mass

        def CalculateForceFromTree(self, target_particle, theta=1.0):
            """Calculate the force on a target particle using the Barnes-Hut algorithm.

            Requires ComputeMassDistribution() to have been run on the tree.

            Args:
            target_particle (Particle): The particle for which the force is calculated.
            theta (float): Opening criterion. A node of size d whose center of
                mass lies at distance r from the target is treated as a single
                point mass when d / r < theta. The same theta is used at every
                level; theta=0 gives an exact pairwise sum.

            Returns:
            np.ndarray: The total force acting on the target particle.
            """
            total_force = np.array([0.0, 0.0])

            if self.particle is not None:
                # Leaf: direct interaction, skipping the target's self-interaction.
                if self.particle is not target_particle:
                    total_force += self.GravitationalForce(target_particle, self.particle)
                return total_force

            if self.total_mass == 0:
                return total_force  # Empty subtree contributes nothing

            r = np.linalg.norm(np.array([target_particle.x, target_particle.y]) - self.center_of_mass)
            d = max(self.width, self.height)

            # r == 0 means the target sits on the center of mass: never approximate then.
            if r > 0 and d / r < theta:
                # Far enough away: replace the whole cluster by one "node particle".
                node_particle = Particle(self.center_of_mass[0], self.center_of_mass[1], self.total_mass)
                total_force += self.GravitationalForce(target_particle, node_particle)
            else:
                for child in self.children:
                    if child is not None:
                        # Propagate theta so the whole traversal uses one criterion.
                        total_force += child.CalculateForceFromTree(target_particle, theta)

            return total_force

        def CalculateForce(self, target_particle, particle, theta=1.0):
            """Traverse the tree computing a force on `target_particle` (legacy variant).

            At a leaf, returns the force from the leaf's own particle (skipping
            the target itself). At an internal node, the opening test uses the
            distance between `target_particle` and `particle`: if it passes, the
            direct force from `particle` is returned; otherwise the results of
            the children are summed. Prints a debug line on every call.

            NOTE(review): appears superseded by CalculateForceFromTree — confirm
            before relying on it.

            Args:
            target_particle (Particle): The particle for which the force is calculated.
            particle (Particle): The particle used for the opening test and the direct force.
            theta (float): Opening criterion, propagated to every recursive call.

            Returns:
            np.ndarray: The force vector acting on the target particle.
            """
            force = np.array([0.0, 0.0])
            print('function CalculateForce is called')
            if self.particle is not None:
                # Node contains only one particle
                if self.particle is not target_particle:
                    force = self.GravitationalForce(target_particle, self.particle)
            elif target_particle is not None and particle is not None:
                r = np.linalg.norm(
                    np.array([target_particle.x, target_particle.y]) - np.array([particle.x, particle.y]))
                d = max(self.width, self.height)

                if r > 0 and d / r < theta:
                    force = self.GravitationalForce(target_particle, particle)
                else:
                    for child in self.children:
                        if child is not None:
                            force += child.CalculateForce(target_particle, particle, theta)
            return force

        def GravitationalForce(self, particle1, particle2, cutoff_radius=0.0, G=4 * np.pi**2):
            """Calculate the gravitational force exerted on particle1 by particle2.

            Args:
            particle1 (Particle): Particle the force acts on.
            particle2 (Particle): Particle exerting the force.
            cutoff_radius (float): Minimum distance used in the force law.
                Separations below it are clamped to it, which softens
                close encounters. The default of 0 means no clamping.
            G (float): Gravitational constant. The default 4*pi^2 is its value
                in AU^3 / (solar mass * yr^2).

            Returns:
            np.ndarray: Force vector on particle1, pointing toward particle2.
            Coincident particles yield a zero vector, because their direction
            is undefined.
            """
            dx = particle2.x - particle1.x
            dy = particle2.y - particle1.y
            separation = np.sqrt(dx**2 + dy**2)
            if separation == 0:
                # Avoid 0/0: returning zeros instead of inf/nan keeps force sums finite.
                return np.array([0.0, 0.0])
            r = max(separation, cutoff_radius)

            force_magnitude = G * particle1.mass * particle2.mass / (r**2)
            force_x = force_magnitude * (dx / r)
            force_y = force_magnitude * (dy / r)

            return np.array([force_x, force_y])

        def particles_in_subtree(self):
            """Retrieve all particles in the subtree rooted at this node.

            Returns:
            list: Particles in depth-first order, visiting children as (SW, SE, NW, NE).
            """
            if self.particle is not None:
                return [self.particle]
            particles = []
            for child in self.children:
                if child is not None:
                    particles.extend(child.particles_in_subtree())
            return particles

        def compute_center_of_mass(self):
            """Compute the center of mass for the node (legacy variant).

            Unlike ComputeMassDistribution, this stores the mass in the 'mass'
            attribute (not 'total_mass'). It resets empty subtrees to
            center (0, 0) with mass 0, and prints a debug line on every call.
            """
            print('Function compute_center_of_mass is called')
            if self.particle is not None:
                self.center_of_mass = np.array([self.particle.x, self.particle.y])
                self.mass = self.particle.mass
            else:
                total_mass = 0
                center_of_mass_accumulator = np.array([0.0, 0.0])

                for child in self.children:
                    if child is not None:
                        child.compute_center_of_mass()
                        total_mass += child.mass
                        center_of_mass_accumulator += child.mass * child.center_of_mass

                if total_mass > 0:
                    self.center_of_mass = center_of_mass_accumulator / total_mass
                    self.mass = total_mass
                else:
                    self.center_of_mass = np.array([0.0, 0.0])
                    self.mass = 0