N-Body

from pathlib import Path
import os
import functools
import itertools
import time

from IPython.display import HTML, Image
import matplotlib.pyplot as plt
from numpy import *
from celluloid import Camera
import matplotlib.patches as patches
from scipy.integrate import odeint, solve_ivp
from simple_pid import PID

# Output directory for the rendered GIFs. mkdir(exist_ok=True) is atomic with
# respect to "already exists", unlike a separate exists() check + makedirs().
ROOT = Path("./assets/img/")
ROOT.mkdir(parents=True, exist_ok=True)

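N-body problem.

Each body $j$ is accelerated by the gravity of every other body $i \neq j$:

\[\frac{d^2 \mathbf{r}_j}{dt^2} = -G \sum_{i \neq j} \frac{m_i}{|\mathbf{r}_j - \mathbf{r}_i|^3} (\mathbf{r}_j - \mathbf{r}_i)\]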
G = 1e-3  # gravitational constant used by motion_step (simulation units)
BODY_NUM = 10  # one heavy central body plus BODY_NUM - 1 light ring bodies
MASS = [1e4, *([1] * (BODY_NUM - 1))]  # per-body masses, central body first
def get_color(idx, count):
    """Return the colour of item ``idx`` from the "hsv" colormap sampled at ``count`` levels."""
    palette = plt.get_cmap("hsv", count)
    return palette(idx)
def get_bodies():

    rng = random.default_rng(seed=0)
    
    R = 1
    V = 1

    state_arr = [[0., 0., 0., 0., 0., 0.]]

    theta_arr = linspace(0, 2*pi, BODY_NUM)[:-1]
    
    for theta in theta_arr:
        x = R*cos(theta)
        y = R*sin(theta)
        z = 0

        vx = -V*sin(theta)
        vy = V*cos(theta)
        vz = 0

        state_arr.append([x, y, z, vx, vy, vz])
        
    state_arr = array(state_arr)
    state_arr[1:] = state_arr[1:] + rng.uniform(-0.1, 0.1, size=state_arr[1:].shape)
    return state_arr
def motion_step(s0, t):
    
    s0 = s0.reshape(-1, 6)
    r0, v0 = s0[:, :3], s0[:, 3:]

    N = r0.shape[0]
    
    mask = ones(shape=(N, N))
    mask = mask - eye(N)
    mask = mask[..., None]

    mass = array(MASS)
    mass = mass[None, :, None]

    
    r = r0[:, None, :] - r0[None, :, :]
    
    eps = 1e-7
    dist_sq = (r**2).sum(axis=-1) + eps**2   # (N, N)
    inv_r3  = 1.0 / (dist_sq * sqrt(dist_sq))
    r *= inv_r3[..., None]
    
    dv = -G*(mask*mass*r).sum(axis=1)
    dr = v0

    s = concat([dr, dv], axis=1)
    s = s.reshape(-1)
    return s
def sim_n_body(tail=False):
    """Integrate the direct-summation N-body system and save it as a GIF.

    Bodies come from ``get_bodies``. They are integrated with ``solve_ivp``
    and ``motion_step`` over t in [0, 5], sampled at 250 frames. Each frame
    is drawn as a 3-D scatter plot captured by celluloid's ``Camera``.

    Args:
        tail: if True, also draw each body's trajectory up to the current
            frame.

    Returns:
        IPython ``Image`` referencing ``ROOT / "nbody.gif"``.

    Raises:
        RuntimeError: if the ODE solver reports failure. Without this check,
            a failed integration silently produced a truncated animation.
    """
    T = 5
    t_eval = linspace(0, T, 250)

    s0 = get_bodies()
    sol = solve_ivp(lambda t, s: motion_step(s, t), (0, T), s0.reshape(-1), t_eval=t_eval)
    if not sol.success:
        raise RuntimeError(f"solve_ivp failed: {sol.message}")

    frames = sol.y.T    # (n_frames, BODY_NUM * 6)

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    for set_labels in (ax.set_xticklabels, ax.set_yticklabels, ax.set_zticklabels):
        set_labels([])

    colors = [get_color(idx, BODY_NUM) for idx in range(BODY_NUM)]
    markers = [250] + [25] * (BODY_NUM - 1)

    camera = Camera(fig)
    for idx, frame in enumerate(frames):
        r = frame.reshape(-1, 6)[:, :3]
        ax.scatter(r[:, 0], r[:, 1], r[:, 2], color=colors, s=markers)

        if tail:
            # history[b] is body b's (idx + 1, 6) trajectory so far
            history = frames[:idx + 1].reshape(idx + 1, -1, 6).transpose(1, 0, 2)
            for body_idx, path in enumerate(history):
                ax.plot(path[:, 0], path[:, 1], path[:, 2], color=colors[body_idx])

        camera.snap()

    anim = camera.animate()
    plt.close(fig)

    gif_path = ROOT / "nbody.gif"
    anim.save(gif_path, writer="pillow", fps=10)
    return Image(url=str(gif_path))

# Render the direct-summation simulation with trajectory tails. The returned
# IPython Image is presumably displayed as this notebook cell's output.
sim_n_body(tail=True)

Barnes-Hut

class Body:
    """A point mass tracked by the Barnes-Hut simulation.

    Stores the identifier ``id`` (compared to skip self-interaction), the
    position ``r``, the velocity ``v`` and the mass ``m`` exactly as given.
    """

    def __init__(self, id, r, v, m):
        self.id, self.r, self.v, self.m = id, r, v, m

class Node:
    """Cubic cell of the Barnes-Hut octree.

    A leaf holds at most one body. An internal node owns eight child octants.
    ``com`` and ``mass`` start empty and are filled in after the tree is
    built.
    """

    def __init__(self, center, half_size):
        self.center = array(center, dtype=float)  # 3-D centre of the cube
        self.half_size = half_size                 # half of the side length
        self.com = zeros(3)                        # centre of mass of contents
        self.mass = 0.0                            # total mass of contents
        self.children = None                       # 8 child octants once split
        self.body = None                           # the single body of a leaf
        self.is_leaf = True

    def within(self, body):
        """Return True if ``body.r`` lies inside this cube.

        The test is half-open on every axis:
        ``center - half_size <= r < center + half_size``. A point on a face
        shared by two sibling octants therefore belongs to exactly one of
        them. A strict test on both sides would reject it from both and drop
        the body from the tree.
        """
        lo = self.center - self.half_size
        hi = self.center + self.half_size
        return bool((body.r >= lo).all() and (body.r < hi).all())

def get_octants(center, half_size):
    """Split the cube (``center``, ``half_size``) into its eight child Nodes.

    Each child has half the parent's half-size and is offset by +/- that
    amount along every axis. The children are ordered by
    ``itertools.product`` over (-, +) for x, y and z.
    """
    child_half = half_size / 2.
    signs = (-child_half, child_half)
    return [
        Node(center + array([ox, oy, oz]), half_size=child_half)
        for ox, oy, oz in itertools.product(signs, repeat=3)
    ]
        

def _insert_into_children(node, body):
    """Offer ``body`` to ``node``'s children in order; stop at the first that accepts it."""
    for child in node.children:
        if insert_tree(child, body):
            return True
    return False


def insert_tree(tree: "Node", body: "Body") -> bool:
    """Insert ``body`` into the octree rooted at ``tree``.

    An empty leaf simply stores the body. An occupied leaf is split into
    eight octants, and both its resident body and the new body are pushed
    down. An internal node forwards the body to the child whose cube
    contains it.

    Returns:
        True if the body was stored in this subtree, False if its position
        lies outside ``tree``'s cube. Existing callers that ignore the return
        value are unaffected.

    Raises:
        ValueError: if the body has exactly the same position as an already
            stored body. No amount of subdivision can separate them; the
            previous implementation recursed until RecursionError.
    """
    if not tree.within(body):
        return False

    if not tree.is_leaf:
        return _insert_into_children(tree, body)

    if tree.body is None:
        tree.body = body
        return True

    resident = tree.body
    if array_equal(resident.r, body.r):
        raise ValueError(f"bodies {resident.id} and {body.id} share the same position")

    tree.is_leaf = False
    tree.children = get_octants(tree.center, tree.half_size)
    tree.body = None
    _insert_into_children(tree, resident)
    return _insert_into_children(tree, body)

def init_tree(bodies):
    """Create the empty root Node of an octree whose cube encloses every body.

    The cube is centred on the midpoint of the bodies' axis-aligned bounding
    box, with side equal to the largest per-axis extent plus a small margin.
    Centring on the mean position instead (as before) let bodies at the
    extremes fall outside the root. A zero margin put them exactly on the
    boundary, where ``Node.within`` rejects them. Either way they were
    silently dropped from the tree.

    Raises:
        ValueError: if ``bodies`` is empty.
    """
    positions = array([b.r for b in bodies], dtype=float)
    if len(positions) == 0:
        raise ValueError("init_tree needs at least one body")

    lo = positions.min(axis=0)
    hi = positions.max(axis=0)
    center = (lo + hi) / 2.
    extent = (hi - lo).max()

    # Relative padding keeps extreme bodies strictly inside the cube; the
    # floor of 1.0 also gives a single body (extent 0) a non-empty cube.
    scale = array([extent, abs(positions).max(), 1.0]).max()
    half_size = extent / 2. + 1e-6 * scale

    return Node(center, half_size)

def build_tree(bodies):
    """Return the root of an octree with every body in ``bodies`` inserted in order."""
    root = init_tree(bodies)
    for b in bodies:
        insert_tree(root, b)
    return root

def update_com(tree):
    """Recursively fill in ``mass`` and ``com`` for every node of ``tree``.

    A leaf takes the mass and a copy of the position of the body it holds;
    an empty leaf is left untouched. An internal node gets the total mass of
    its children and the mass-weighted mean of their centres of mass.

    The internal accumulators are rebuilt from scratch, so calling this twice
    on the same tree gives the same result. The previous version added onto
    the old values and doubled internal masses. The division is exact; the
    old ``+ eps`` in the denominator biased every centre of mass.
    """
    if tree.is_leaf:
        if tree.body is not None:
            # Copy so the node never aliases (and can never mutate) body.r.
            tree.com = array(tree.body.r, dtype=float)
            tree.mass = tree.body.m
        return

    weighted = zeros(3)
    total = 0.0
    for child in tree.children:
        update_com(child)
        weighted += child.mass * child.com
        total += child.mass

    tree.mass = total
    # A massless internal node (its bodies were all rejected) keeps com = 0.
    tree.com = weighted / total if total > 0 else weighted
G = 1e-3  # gravitational constant used by motion_step_bh
BH_BODY_NUM = 30  # one heavy central body plus BH_BODY_NUM - 1 ring bodies
BH_MASS = [1e4, *([1] * (BH_BODY_NUM - 1))]  # per-body masses, central first

def get_bodies_barnes_hut():
    """Return ``BH_BODY_NUM`` Body objects: a central body plus a perturbed ring.

    Body 0 is at rest at the origin. The rest lie evenly on the unit circle
    in z = 0 with unit counter-clockwise tangential speed. All their
    coordinates get uniform noise from a generator seeded with 0. Body ``k``
    has mass ``BH_MASS[k]``, and its ``r``/``v`` are slices of the
    underlying state row.
    """
    generator = random.default_rng(seed=0)
    radius, speed = 1, 1

    angles = linspace(0, 2*pi, BH_BODY_NUM)[:-1]
    ring = [
        [radius*cos(a), radius*sin(a), 0, -speed*sin(a), speed*cos(a), 0]
        for a in angles
    ]
    states = array([[0., 0., 0., 0., 0., 0.]] + ring)
    states[1:] += generator.uniform(-0.1, 0.1, size=states[1:].shape)

    return [
        Body(id=k, r=row[:3], v=row[3:], m=BH_MASS[k])
        for k, row in enumerate(states)
    ]
def state_from_bodies(bodies):
    r = []
    v = []

    for body in bodies:
        r.append(body.r)
        v.append(body.v)

    r, v = array(r), array(v)
    return concat([r, v], axis=1).flatten()

def update_bodies(bodies, s0):
    """Write the flat state ``s0`` back into ``bodies`` in place.

    Row k of ``s0.reshape(-1, 6)`` supplies ``bodies[k].r`` (first three
    entries) and ``bodies[k].v`` (last three). These are slices of the
    reshaped state; no copy is made here.
    """
    rows = s0.reshape(-1, 6)
    for k, row in enumerate(rows):
        target = bodies[k]
        target.r, target.v = row[:3], row[3:]

def get_acc(tree, body, theta):
    """Barnes-Hut estimate of ``sum_i m_i (r - r_i) / |r - r_i|^3`` for ``body``.

    The caller multiplies the result by ``-G`` to obtain the acceleration.
    A cell that does not contain ``body`` is replaced by a point mass at its
    centre of mass when ``side / distance < theta``; otherwise it is opened
    and its children are summed.

    Compared with the earlier version:
    - the criterion uses the distance, not its square;
    - a cell containing ``body`` is always opened, so the body's own mass is
      never folded into an approximation;
    - empty cells contribute a zero vector rather than the scalar 0.

    Args:
        tree: octree node exposing ``is_leaf``, ``body``, ``children``,
            ``com``, ``mass``, ``center`` and ``half_size``.
        body: object with ``id`` and position ``r``.
        theta: opening threshold (larger means coarser and faster).

    Returns:
        Array shaped like ``body.r``.
    """
    eps = 1e-7

    def pair_term(com, mass):
        # Softened point-mass term for a source of `mass` located at `com`.
        ri = body.r - com
        dist_sq = (ri ** 2).sum() + eps ** 2
        inv_r3 = 1.0 / (dist_sq * sqrt(dist_sq))
        return mass * inv_r3 * ri

    if tree.is_leaf:
        if tree.body is None or body.id == tree.body.id:
            return zeros_like(body.r, dtype=float)
        return pair_term(tree.body.r, tree.body.m)

    contains_body = (abs(body.r - tree.center) <= tree.half_size).all()
    if not contains_body:
        d = sqrt(((body.r - tree.com) ** 2).sum())
        if d > 0 and 2 * tree.half_size / d < theta:
            return pair_term(tree.com, tree.mass)

    acc = zeros_like(body.r, dtype=float)
    for child in tree.children:
        acc = acc + get_acc(child, body, theta)
    return acc
        

def motion_step_bh(s0, t, bodies, theta=0.3):
    """Right-hand side of the N-body ODE using a Barnes-Hut octree.

    Copies the state into ``bodies``, rebuilds the octree, fills in the
    centres of mass, and evaluates each body's acceleration with
    ``get_acc``.

    Args:
        s0: flat state of ``[x, y, z, vx, vy, vz]`` per body.
        t: time. Unused; kept for the ODE-callback signature.
        bodies: Body objects, updated in place to match ``s0``.
        theta: Barnes-Hut opening parameter forwarded to ``get_acc``.

    Returns:
        Flat derivative ``[vx, vy, vz, ax, ay, az]`` per body.
    """
    update_bodies(bodies, s0)
    tree = build_tree(bodies)
    update_com(tree)

    vel = s0.reshape(-1, 6)[:, 3:]

    # Preallocate so the result is always a well-formed (N, 3) array. A
    # scalar 0 from get_acc (no contributing bodies) broadcasts into its row
    # instead of producing a ragged list that concat cannot stack.
    acc = zeros_like(vel)
    for k, body in enumerate(bodies):
        acc[k] = -G * get_acc(tree, body, theta)

    return concat([vel, acc], axis=1).reshape(-1)
def sim_n_body_bh(tail=False, verbose=False, theta=1.0):
    """Integrate the Barnes-Hut N-body system and save it as a GIF.

    Bodies come from ``get_bodies_barnes_hut``. They are integrated with
    ``solve_ivp`` and ``motion_step_bh`` over t in [0, 5], sampled at 250
    frames, and drawn frame by frame with celluloid's ``Camera``. The axis
    limits are fixed to the initial bounding box of the bodies.

    Args:
        tail: if True, also draw each body's trajectory up to the current
            frame.
        verbose: if True, print the wall-clock time spent in the ODE solver.
        theta: Barnes-Hut opening parameter. The default matches the value
            previously hard-coded here.

    Returns:
        IPython ``Image`` referencing ``ROOT / "nbody_bh.gif"``.

    Raises:
        RuntimeError: if the ODE solver reports failure.
    """
    T = 5
    t_eval = linspace(0, T, 250)

    bodies = get_bodies_barnes_hut()
    s0 = state_from_bodies(bodies)

    start = time.perf_counter()
    sol = solve_ivp(lambda t, s: motion_step_bh(s, t, bodies, theta=theta),
                    (0, T), s0.reshape(-1), t_eval=t_eval)
    if verbose:
        print(f"ode solved: {time.perf_counter() - start}s")
    if not sol.success:
        raise RuntimeError(f"solve_ivp failed: {sol.message}")

    frames = sol.y.T    # (n_frames, BH_BODY_NUM * 6)

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    for set_labels in (ax.set_xticklabels, ax.set_yticklabels, ax.set_zticklabels):
        set_labels([])

    initial = s0.reshape(-1, 6)
    for set_lim, dim in zip((ax.set_xlim, ax.set_ylim, ax.set_zlim), range(3)):
        # Named lo/hi rather than min/max to avoid shadowing builtins.
        lo = initial[:, dim].min()
        hi = initial[:, dim].max()
        set_lim((lo, hi))

    colors = [get_color(idx, BH_BODY_NUM) for idx in range(BH_BODY_NUM)]
    markers = [250] + [25] * (BH_BODY_NUM - 1)

    camera = Camera(fig)
    for idx, frame in enumerate(frames):
        r = frame.reshape(-1, 6)[:, :3]
        ax.scatter(r[:, 0], r[:, 1], r[:, 2], color=colors, s=markers)

        if tail:
            # history[b] is body b's (idx + 1, 6) trajectory so far
            history = frames[:idx + 1].reshape(idx + 1, -1, 6).transpose(1, 0, 2)
            for body_idx, path in enumerate(history):
                ax.plot(path[:, 0], path[:, 1], path[:, 2], color=colors[body_idx])

        camera.snap()

    anim = camera.animate()
    plt.close(fig)

    gif_path = ROOT / "nbody_bh.gif"
    anim.save(gif_path, writer="pillow", fps=10)
    return Image(url=str(gif_path))

# Render the Barnes-Hut simulation with trajectory tails and print the ODE
# solve time. The returned IPython Image is presumably displayed as the
# cell's output.
sim_n_body_bh(tail=True, verbose=True)
ode solved: 19.222471714019775s