Skip to content

Instantly share code, notes, and snippets.

@inducer
Last active October 21, 2019 23:22
Show Gist options
  • Save inducer/063f77ec77928c4783f366d9298158d1 to your computer and use it in GitHub Desktop.
Save inducer/063f77ec77928c4783f366d9298158d1 to your computer and use it in GitHub Desktop.
Simple numpy FMM Tree Build
import numpy as np
class Tree:
    """A box hierarchy stored as a struct of arrays.

    Box ``i`` is described by column ``i`` of :attr:`box_centers` and
    :attr:`box_children`, and by entry ``i`` of :attr:`box_parents` and
    :attr:`box_levels`. A child index of 0 means "no child in this slot".
    """

    def __init__(
            self, box_centers, root_box_extent,
            box_parents, box_children, box_levels):
        # (dim, nboxes) array of box center coordinates.
        self.box_centers = box_centers
        # Side length of the root box.
        self.root_box_extent = root_box_extent
        # Per-box index of the parent box.
        self.box_parents = box_parents
        # (2**dim, nboxes) array of child box indices; 0 means "no child".
        self.box_children = box_children
        # Per-box refinement level.
        self.box_levels = box_levels

    @property
    def dim(self):
        """Number of spatial dimensions (leading axis of box_centers)."""
        return self.box_centers.shape[0]

    @property
    def nboxes(self):
        """Number of boxes (second axis of box_centers)."""
        return self.box_centers.shape[1]

    def is_leaf(self):
        """Return a per-box boolean array that is True where a box has no children."""
        slot_is_empty = self.box_children == 0
        return slot_is_empty.all(axis=0)
def make_root(dim):
    """Return a :class:`Tree` consisting of the root box only.

    The root spans the unit box: its center is 0.5 along every axis, its
    extent is 1, its level is 0, all ``2**dim`` child slots are 0 (no
    children), and its parent entry is 0.
    """
    nchildren = 2**dim
    return Tree(
        box_centers=np.full((dim, 1), 0.5),
        root_box_extent=1,
        box_parents=np.zeros(1, dtype=np.intp),
        box_children=np.zeros((nchildren, 1), dtype=np.intp),
        box_levels=np.array([0]),
    )
def resized_array(ary, new_size):
    """Return a copy of *ary* grown along its last axis to *new_size* entries.

    The existing data fills the leading ``ary.shape[-1]`` slots of the last
    axis. The new trailing slots are filled with a value marking them as
    not yet set: 0 (``False``) for signed-integer, unsigned-integer and
    boolean arrays, NaN for floating-point and complex arrays.

    :arg ary: array with at least one dimension.
    :arg new_size: new length of the last axis; must be at least
        ``ary.shape[-1]``.
    :raises ValueError: if *new_size* is smaller than the current length.
    :raises TypeError: for dtypes with no sensible fill value (e.g. object
        or string arrays).
    """
    old_size = ary.shape[-1]
    if new_size < old_size:
        raise ValueError(
            f"cannot shrink array from {old_size} to {new_size} entries")

    kind = ary.dtype.kind
    if kind in ("i", "u", "b"):
        fill = 0
    elif kind in ("f", "c"):
        fill = np.nan
    else:
        raise TypeError("unexpected dtype")

    result = np.empty(ary.shape[:-1] + (new_size,), dtype=ary.dtype)
    result[..., :old_size] = ary
    result[..., old_size:] = fill
    return result
def vec_of_signs(dim, value):
    """Return a float64 vector of length *dim* of +1/-1 entries chosen by *value*'s bits.

    Component ``i`` is +1 when bit ``i`` of *value* is set and -1 otherwise,
    so letting *value* run over ``range(2**dim)`` yields every combination
    of signs.
    """
    signs = np.full(dim, -1.0)
    for axis in range(dim):
        if (value >> axis) & 1:
            signs[axis] = 1.0
    return signs
def refine_tree(tree, refine_flags):
    """Return a new :class:`Tree` in which every flagged leaf box is split.

    Each flagged box receives ``2**tree.dim`` children, appended after the
    existing boxes. The children of the k-th flagged box (flagged boxes
    taken in increasing index order) occupy the contiguous index range
    ``tree.nboxes + nchildren*k + j`` for ``j in range(nchildren)``; child
    slot ``j`` is displaced from the parent center in the direction
    ``vec_of_signs(dim, j)``.

    *tree*'s arrays are not modified; the result holds resized copies.

    :arg tree: the :class:`Tree` to refine.
    :arg refine_flags: array-like of length ``tree.nboxes``; a truthy entry
        requests that the corresponding box be split.
    :raises ValueError: if a box that already has children is flagged.
    """
    refine_flags = np.asarray(refine_flags)

    # Validate before allocating anything.
    if refine_flags[~tree.is_leaf()].any():
        raise ValueError("attempting to split non-leaf")

    dim = tree.dim
    nchildren = 2**dim
    nboxes_old = tree.nboxes

    refine_parents, = np.where(refine_flags)
    nboxes_new = nboxes_old + nchildren * len(refine_parents)

    # Index of the first new child box for each refined parent.
    child_box_starts = (
        nboxes_old + nchildren * np.arange(len(refine_parents)))

    box_parents = resized_array(tree.box_parents, nboxes_new)
    box_centers = resized_array(tree.box_centers, nboxes_new)
    box_children = resized_array(tree.box_children, nboxes_new)
    box_levels = resized_array(tree.box_levels, nboxes_new)

    # New box nboxes_old + nchildren*k + j is child j of refine_parents[k],
    # so each parent index must repeat nchildren times *consecutively*.
    # (Tiling [p0, p1, ..., p0, p1, ...] mismatches children and parents
    # whenever more than one box is refined.)
    box_parents[nboxes_old:] = np.repeat(refine_parents, nchildren)
    box_levels[nboxes_old:] = tree.box_levels[box_parents[nboxes_old:]] + 1

    # Row j of box_children holds child slot j for every box.
    box_children[:, refine_parents] = (
        child_box_starts
        + np.arange(nchildren).reshape(-1, 1))

    for ichild in range(nchildren):
        children = box_children[ichild, refine_parents]
        # A child at level L sits root_extent / 2**(L+1) from its parent's
        # center along each axis (a quarter of the parent's extent).
        # 0.5**(...) is evaluated in floating point, so deep levels cannot
        # overflow the way an integer 2**level would.
        offsets = (
            tree.root_box_extent
            * vec_of_signs(dim, ichild).reshape(-1, 1)
            * 0.5**(1 + box_levels[children]))
        box_centers[:, children] = box_centers[:, refine_parents] + offsets

    return Tree(
        box_centers=box_centers,
        root_box_extent=tree.root_box_extent,
        box_parents=box_parents,
        box_children=box_children,
        box_levels=box_levels)
def main(dim=3):
    """Time ten rounds of :func:`refine_tree` and report the result.

    Starting from the root, the first five rounds refine every leaf; the
    last five refine each leaf with probability 1/2. The final box count
    and the elapsed time are printed. For small 2D trees the box centers
    are also plotted.

    :arg dim: spatial dimension of the tree (default 3).
    """
    from time import perf_counter

    start = perf_counter()

    tree = make_root(dim=dim)
    for i in range(10):
        is_leaf = tree.is_leaf()
        # np.bool was deprecated in NumPy 1.20 and removed in 1.24; the
        # builtin bool is the supported dtype spelling.
        refine_flags = np.zeros(tree.nboxes, dtype=bool)
        if i < 5:
            refine_flags[is_leaf] = True
        else:
            refine_flags[is_leaf] = np.random.randint(
                0, 2, size=np.count_nonzero(is_leaf))
        tree = refine_tree(tree, refine_flags)

    print(f"Tree with {tree.nboxes/1e6:.2f} M boxes in {perf_counter()-start:.1f}s")

    if tree.nboxes < 50_000 and tree.dim == 2:
        import matplotlib.pyplot as plt
        plt.plot(tree.box_centers[0], tree.box_centers[1], "o")
        plt.show()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment