Created February 15, 2015 19:04
Save skariel/57f667992746fa20aaa5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module nbody | |
using OctTrees | |
import OctTrees: modify, stop_cond, getx, gety, getz | |
# A point mass: position (_x, _y, _z) and mass _m. Also used as the aggregate
# (total mass at centre of mass) stored in each OctTree node.
immutable Particle <: AbstractPoint3D
    _x::Float64
    _y::Float64
    _z::Float64
    _m::Float64
end

# Zero-mass particle at the origin.
Particle() = Particle(0.0, 0.0, 0.0, 0.0)

# Methods for the OctTrees coordinate accessors imported at the top of the module.
getx(p::Particle) = p._x
gety(p::Particle) = p._y
getz(p::Particle) = p._z
# Full simulation state: the particles, per-particle velocity and acceleration
# components, the OctTree built over them, and the tree-walk parameters.
# The last three parameters are stored squared (the constructors square them).
type World
    tree::OctTree{Particle}          # rebuilt from `particles` by buildtree
    particles::Vector{Particle}
    vx::Vector{Float64}              # velocity components, one entry per particle
    vy::Vector{Float64}
    vz::Vector{Float64}
    ax::Vector{Float64}              # acceleration components, written by the force routines
    ay::Vector{Float64}
    az::Vector{Float64}
    n::Int64                         # number of particles
    opening_alpha2::Float64          # squared threshold on (cell size / distance)
    opening_excluded_frac2::Float64  # squared fraction of cell size around the cell centre
                                     # inside which a cell is always opened
    smth2::Float64                   # squared gravitational softening length
end
# Build a World around `particles`.
# - Velocity and acceleration arrays start zeroed.
# - The OctTree is created with n = 4.1 × the particle count (presumably its node
#   capacity; confirm against OctTrees).
# - The opening and softening parameters are stored squared, as World expects.
function _assembleworld(particles, smth, opening_excluded_frac, opening_alpha)
    n = length(particles)
    return World(
        OctTree(Particle; n=trunc(Integer, 4.1*n)),
        particles,
        zeros(n), zeros(n), zeros(n),   # vx, vy, vz
        zeros(n), zeros(n), zeros(n),   # ax, ay, az
        n,
        opening_alpha^2,
        opening_excluded_frac^2,
        smth*smth
    )
end

# World of `n` equal-mass particles (total mass 1) whose coordinates are independent
# standard-normal draws.
# Keywords:
# - `smth`: softening length.
# - `opening_excluded_frac`: fraction of cell size around a cell centre inside which
#   a cell is always opened.
# - `opening_alpha`: opening threshold on cell size / distance.
# Throws ArgumentError if `n` is negative.
function worldnormal(n::Int64; smth=0.000001, opening_excluded_frac=0.0, opening_alpha=0.7)
    n >= 0 || throw(ArgumentError("number of particles must be non-negative, got $n"))
    particles = [Particle(randn(), randn(), randn(), 1.0/n) for i in 1:n]
    return _assembleworld(particles, smth, opening_excluded_frac, opening_alpha)
end

# World of `n` equal-mass particles (total mass 1) uniformly distributed inside the
# unit ball. Points are drawn by rejection sampling from the cube [-1, 1]^3.
# Keywords have the same meaning as in worldnormal, but different defaults.
# Throws ArgumentError if `n` is negative.
function worldspherical(n::Int64; smth=0.0, opening_excluded_frac=0.6, opening_alpha=0.7)
    n >= 0 || throw(ArgumentError("number of particles must be non-negative, got $n"))
    particles = Particle[]
    sizehint!(particles, n)
    mass = 1.0/n
    while length(particles) < n
        tx = rand()*2.0 - 1.0
        ty = rand()*2.0 - 1.0
        tz = rand()*2.0 - 1.0
        # keep only points strictly inside the unit sphere
        tx*tx + ty*ty + tz*tz < 1.0 && push!(particles, Particle(tx, ty, tz, mass))
    end
    return _assembleworld(particles, smth, opening_excluded_frac, opening_alpha)
end
# Merge particle `p` into the aggregate point of node `q`.
# The result holds the combined mass at the mass-weighted mean position, so the node
# ends up storing its total mass and centre of mass.
# Presumably called by OctTrees' insert! for each node on the insertion path, because
# buildtree passes the `Modify` flag (confirm against OctTrees).
@inline function modify(q::OctTreeNode{Particle}, p::Particle)
    agg = q.point
    mass = agg._m + p._m
    q.point = Particle((agg._x*agg._m + p._x*p._m)/mass,
                       (agg._y*agg._m + p._y*p._m)/mass,
                       (agg._z*agg._m + p._z*p._m)/mass,
                       mass)
end
# Rebuild w.tree from scratch for the current particle positions.
# - The root is a cube centred at (c, c, c), where c is the midpoint of the smallest
#   and largest coordinate found on ANY of the x, y, z axes. Using one extent for all
#   three axes keeps the root cell a cube.
# - Its half-width is 1.05 × half that range, so every particle lies strictly inside.
# - All particles are then inserted with the `Modify` flag.
function buildtree(w::World)
    clear!(w.tree)
    lo = Float64(1.e30)
    hi = Float64(-1.e30)
    for i in 1:w.n
        @inbounds p = w.particles[i]
        for c in (p._x, p._y, p._z)
            c < lo && (lo = c)
            c > hi && (hi = c)
        end
    end
    center = 0.5*(hi + lo)
    halfwidth = 0.5*(hi - lo)
    initnode!(w.tree.head, halfwidth*1.05, center, center, center)
    insert!(w.tree, w.particles, Modify)
end
# Mutable scratch state for one tree walk:
# - ax, ay, az: running acceleration sums;
# - px, py, pz: the target particle's position;
# - w: the World that supplies the softening and opening parameters.
type DataToCalculateAccelOnParticle
    ax::Float64
    ay::Float64
    az::Float64
    px::Float64
    py::Float64
    pz::Float64
    w::World
end
# Add to the running sums in `data` the softened pull
#     m * d / (r2 + smth2)^(3/2)    (G = 1)
# of a source of mass `m` at offset d = (dx, dy, dz) from the target, where r2 = |d|^2.
@inline function _add_pull!(data::DataToCalculateAccelOnParticle,
                            dx::Float64, dy::Float64, dz::Float64,
                            r2::Float64, m::Float64)
    soft2 = r2 + data.w.smth2
    denom = soft2*sqrt(soft2)/m
    data.ax += dx/denom
    data.ay += dy/denom
    data.az += dz/denom
    return nothing
end

# Tree-walk callback for the target particle at (data.px, data.py, data.pz).
# Returns true to stop at node `q` and false to open it and visit its children.
# - Empty leaf: nothing to add; stop.
# - Occupied leaf: add its particle's pull unless it sits exactly at the target
#   position (the target itself); stop.
# - Divided node: open it when either
#     * the target lies within sqrt(opening_excluded_frac2) cell sizes of the cell
#       centre on every axis, or
#     * (cell size)^2 / (distance to the node's centre of mass)^2 > opening_alpha2.
#   Otherwise add the node's aggregate pull and stop.
@inline function stop_cond(q::OctTreeNode{Particle}, data::DataToCalculateAccelOnParticle)
    isemptyleaf(q) && return true
    src = q.point
    if isleaf(q)
        if src._x == data.px && src._y == data.py && src._z == data.pz
            return true   # this is the target particle: no self-force
        end
        dx = src._x - data.px
        dy = src._y - data.py
        dz = src._z - data.pz
        _add_pull!(data, dx, dy, dz, dx*dx + dy*dy + dz*dz, src._m)
        return true
    end
    side = 2.0*q.r
    side2 = side*side
    near2 = data.w.opening_excluded_frac2*side2
    cx = q.midx - data.px
    cy = q.midy - data.py
    cz = q.midz - data.pz
    if cx*cx < near2 && cy*cy < near2 && cz*cz < near2
        return false      # target is too close to this cell: open it
    end
    dx = src._x - data.px
    dy = src._y - data.py
    dz = src._z - data.pz
    r2 = dx*dx + dy*dy + dz*dz
    side2/r2 > data.w.opening_alpha2 && return false   # cell looks too large: open it
    _add_pull!(data, dx, dy, dz, r2, src._m)
    return true
end
# Tree-code acceleration on the single particle w.particles[particle_ix].
# The result is stored in w.ax / w.ay / w.az at the same index.
# Assumes w.tree has already been built for the current positions (e.g. by buildtree).
# `particle_ix` is not bounds-checked.
@inline function calculate_accel_on_particle(w::World, particle_ix::Int64)
    @inbounds target = w.particles[particle_ix]
    acc = DataToCalculateAccelOnParticle(0.0, 0.0, 0.0, target._x, target._y, target._z, w)
    # presumably OctTrees' map drives stop_cond over the tree, filling acc's sums
    map(w.tree, acc)
    @inbounds w.ax[particle_ix] = acc.ax
    @inbounds w.ay[particle_ix] = acc.ay
    @inbounds w.az[particle_ix] = acc.az
end
# Rebuild the tree and compute tree-code accelerations for every particle.
# Results go into w.ax, w.ay, w.az.
# A single scratch accumulator is reused for all targets to avoid per-particle
# allocation.
function calc_accel(w::World)
    buildtree(w)
    acc = DataToCalculateAccelOnParticle(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, w)
    @inbounds for k in 1:w.n
        target = w.particles[k]
        acc.ax, acc.ay, acc.az = 0.0, 0.0, 0.0
        acc.px, acc.py, acc.pz = target._x, target._y, target._z
        map(w.tree, acc)
        w.ax[k] = acc.ax
        w.ay[k] = acc.ay
        w.az[k] = acc.az
    end
end
# Direct-summation reference for the tree code.
# For each particle i in `ixs` (default: all particles) it computes the softened
# acceleration (G = 1)
#     sum over j != i of  m_j * (r_j - r_i) / (|r_j - r_i|^2 + smth2)^(3/2)
# Cost is O(length(ixs) * w.n).
# Returns a tuple (ax, ay, az) of vectors aligned with `ixs`.
function calc_accel_brute_force(w::World, ixs=1:w.n)
    ax = zeros(w.n)
    ay = zeros(w.n)
    az = zeros(w.n)
    for i in ixs
        p_i = w.particles[i]          # bounds-checked: `ixs` comes from the caller
        sx = 0.0
        sy = 0.0
        sz = 0.0
        for j in 1:w.n
            i == j && continue
            @inbounds p_j = w.particles[j]
            dx = p_j._x - p_i._x
            dy = p_j._y - p_i._y
            dz = p_j._z - p_i._z
            soft2 = dx*dx + dy*dy + dz*dz + w.smth2
            # the pull on p_i scales with the *source* mass p_j._m (as in the tree code)
            denom = soft2*sqrt(soft2)/p_j._m
            sx += dx/denom
            sy += dy/denom
            sz += dz/denom
        end
        # assign rather than accumulate, so a repeated index in `ixs` is not double-counted
        ax[i] = sx
        ay[i] = sy
        az[i] = sz
    end
    return ax[ixs], ay[ixs], az[ixs]
end
using Base.Test | |
# The eight children of a divided OctTree node, in a fixed order.
_children(n) = (n.lxlylz, n.lxlyhz, n.lxhylz, n.lxhyhz,
                n.hxlylz, n.hxlyhz, n.hxhylz, n.hxhyhz)

# Structural checks on a divided node `n`:
# - each child has half the parent's radius;
# - the children's masses sum to the parent's mass;
# - the node is flagged empty and is not a leaf;
# - the children's mass-weighted positions reproduce the parent's centre of mass.
function _check_divided_node(n)
    kids = _children(n)
    for c in kids
        @test_approx_eq n.r/2 c.r
    end
    children_mass = 0.0
    mx = 0.0
    my = 0.0
    mz = 0.0
    for c in kids
        children_mass += c.point._m
        mx += c.point._x*c.point._m
        my += c.point._y*c.point._m
        mz += c.point._z*c.point._m
    end
    @test_approx_eq_eps n.point._m children_mass 1.e-4
    @test n.is_empty
    @test !isleaf(n)
    @test !isemptyleaf(n)
    @test !isfullleaf(n)
    @test_approx_eq_eps n.point._x mx/n.point._m 1.e-4
    @test_approx_eq_eps n.point._y my/n.point._m 1.e-4
    @test_approx_eq_eps n.point._z mz/n.point._m 1.e-4
end

# Self-test of tree construction on a 10000-particle Gaussian world (total mass 1).
# Each @test throws on failure, so reaching the final println means every check passed.
function test()
    w = worldnormal(10000)
    buildtree(w)
    head = w.tree.head
    # every particle lies strictly inside the root cell
    for p in w.particles
        @test p._x > head.midx - head.r
        @test p._x < head.midx + head.r
        @test p._y > head.midy - head.r
        @test p._y < head.midy + head.r
        @test p._z > head.midz - head.r
        @test p._z < head.midz + head.r
    end
    # one full (and massive) leaf per particle
    tot_not_empty = 0
    tot_massive_leafs = 0
    for n in w.tree.nodes
        if isfullleaf(n)
            tot_not_empty += 1
        end
        if n.point._m > 1.e-13 && isleaf(n)
            tot_massive_leafs += 1
        end
    end
    @test tot_not_empty == w.n
    @test tot_massive_leafs == w.n
    # the mass held by non-empty used nodes adds up to the total particle mass
    total_mass = 0.0
    for n in w.tree.nodes[1:w.tree.number_of_nodes_used]
        if !n.is_empty
            total_mass += n.point._m
        end
    end
    @test_approx_eq_eps total_mass 1.0 1.e-4
    # per-node structural invariants
    for n in w.tree.nodes
        if n.is_divided
            _check_divided_node(n)
        else
            @test isleaf(n)
            # a leaf is either empty or full, never both and never neither
            @test isemptyleaf(n) != isfullleaf(n)
        end
    end
    println("*** All tests passed! ***")
end

# Runs the self-test every time this module is loaded (it builds a 10000-particle tree).
# NOTE(review): consider moving this call into a separate test script.
test()
# Result of test_acc, comparing tree-code and brute-force ("bf") accelerations on a
# sample of particles.
type TestAcc
    ax_tree::Vector{Float64}   # x-acceleration from the tree code
    a_tree::Vector{Float64}    # acceleration magnitude from the tree code
    ax_bf::Vector{Float64}     # x-acceleration from brute force
    a_bf::Vector{Float64}      # acceleration magnitude from brute force
    ferr::Vector{Float64}      # relative error in percent, sorted ascending
    fbelow::Vector{Float64}    # cumulative percentage of samples, aligned with ferr
end
# Accuracy check of the tree code against direct summation.
# - Builds an `n`-particle uniform sphere.
# - Computes tree accelerations for `nout` distinct random particles.
# - Compares them with calc_accel_brute_force on the same particles.
# Returns a TestAcc where:
# - `ferr` holds each sample's relative error |a_tree - a_bf| / |a_bf| in percent,
#   sorted ascending;
# - `fbelow[k]` is the percentage of samples with error <= ferr[k].
function test_acc(n, nout)
    ixs = randperm(n)[1:nout]
    w = worldspherical(n)
    buildtree(w)
    for i in ixs
        calculate_accel_on_particle(w, i)
    end
    ax_tree = w.ax[ixs]
    ay_tree = w.ay[ixs]
    az_tree = w.az[ixs]
    # magnitude over all three components (x, y AND z)
    a_tree = sqrt(ax_tree.^2 + ay_tree.^2 + az_tree.^2)
    println("calculating BF...")
    ax_bf, ay_bf, az_bf = calc_accel_brute_force(w, ixs)
    println("Done!\n")
    a_bf = sqrt(ax_bf.^2 + ay_bf.^2 + az_bf.^2)
    dax = ax_tree - ax_bf
    day = ay_tree - ay_bf
    daz = az_tree - az_bf
    ferr = sqrt(dax.^2 + day.^2 + daz.^2) ./ a_bf .* 100.0
    sort!(ferr)
    # `[1:nout]` relies on deprecated range concatenation; collect explicitly
    fbelow = collect(1:nout) ./ nout .* 100.0
    return TestAcc(ax_tree, a_tree, ax_bf, a_bf, ferr, fbelow)
end
# Placeholder (presumably for a timestep computation): not implemented yet;
# does nothing and returns `nothing`.
calc_dt(w::World) = nothing
end # module |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.