Created
December 8, 2020 13:04
-
-
Save ivan-pi/661c0884069baced35c71e6d5b6fe3ce to your computer and use it in GitHub Desktop.
Ball tree in Fortran following the one by Jake Vanderplas (see https://gist.github.com/jakevdp/5216193). Construction appears to work, querying is broken!
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module nheap_mod
  implicit none
  private
  public :: nheap

  integer, parameter :: wp = kind(1.0d0)

  !> A set of bounded max-heaps, one per row, holding the smallest distances
  !> offered so far together with their integer labels.
  !>
  !> Layout per row: column 0 is the heap root (the largest retained
  !> distance); the children of column i are columns 2*i+1 and 2*i+2.
  type :: nheap
    real(wp), allocatable :: distances(:,:)   ! (0:n_pts-1, 0:n_nbrs-1)
    integer, allocatable :: indices(:,:)      ! (0:n_pts-1, 0:n_nbrs-1)
  contains
    procedure :: init => nheap_init
    procedure :: largest => nheap_largest
    procedure :: push => nheap_push
    procedure :: get_arrays => nheap_get_arrays
  end type nheap

contains

  !> Allocate n_pts heaps, each with capacity n_nbrs.
  !>
  !> Distances start at huge() so that any finite distance is accepted by
  !> push; indices start at 0. Calling init on an already-initialised heap
  !> discards its previous contents instead of failing on allocate.
  subroutine nheap_init(self, n_pts, n_nbrs)
    class(nheap), intent(inout) :: self
    integer, intent(in) :: n_pts, n_nbrs

    if (allocated(self%distances)) deallocate(self%distances)
    if (allocated(self%indices)) deallocate(self%indices)

    allocate(self%distances(0:n_pts-1, 0:n_nbrs-1))
    self%distances = huge(self%distances)
    allocate(self%indices(0:n_pts-1, 0:n_nbrs-1))
    self%indices = 0
  end subroutine nheap_init

  !> Largest distance currently retained in heap `row` (the heap root).
  function nheap_largest(self, row) result(res)
    class(nheap), intent(in) :: self
    integer, intent(in) :: row
    real(wp) :: res

    res = self%distances(row, 0)
  end function nheap_largest

  !> Offer the pair (val, i_val) to heap `row`.
  !>
  !> If val is larger than the current root, it is ignored. Otherwise it
  !> replaces the root, evicting the previous maximum, and is sifted down
  !> until the max-heap property (parent >= children) holds again.
  subroutine nheap_push(self, row, val, i_val)
    class(nheap), intent(inout) :: self
    integer, intent(in) :: row, i_val
    real(wp), intent(in) :: val
    integer :: sz, i, ic1, ic2, i_swap

    sz = size(self%distances, dim=2)

    ! check if val should be in heap
    if (val > self%distances(row, 0)) return

    ! insert val at position zero
    self%distances(row, 0) = val
    self%indices(row, 0) = i_val

    ! descend the heap, moving larger children up until the max-heap
    ! criterion is met; val/i_val are written once at the final slot
    i = 0
    do
      ic1 = 2*i + 1
      ic2 = ic1 + 1
      if (ic1 >= sz) then
        exit
      else if (ic2 >= sz) then
        ! only a left child exists
        if (self%distances(row, ic1) > val) then
          i_swap = ic1
        else
          exit
        end if
      else if (self%distances(row, ic1) >= self%distances(row, ic2)) then
        if (val < self%distances(row, ic1)) then
          i_swap = ic1
        else
          exit
        end if
      else
        if (val < self%distances(row, ic2)) then
          i_swap = ic2
        else
          exit
        end if
      end if

      self%distances(row, i) = self%distances(row, i_swap)
      self%indices(row, i) = self%indices(row, i_swap)
      i = i_swap
    end do

    self%distances(row, i) = val
    self%indices(row, i) = i_val
  end subroutine nheap_push

  !> Copy the heap contents out.
  !>
  !> When sort is .true., each row is sorted by ascending distance, with
  !> the indices permuted alongside. Otherwise the raw heap order is copied.
  subroutine nheap_get_arrays(self, sort, distances, indices)
    class(nheap), intent(in) :: self
    logical, intent(in) :: sort
    real(wp), intent(out) :: distances(0:size(self%distances,1)-1, 0:size(self%distances,2)-1)
    integer, intent(out) :: indices(0:size(self%indices,1)-1, 0:size(self%indices,2)-1)
    real(wp), allocatable :: d(:)
    integer, allocatable :: j(:)
    integer :: i

    if (.not. sort) then
      distances = self%distances
      indices = self%indices
      return
    end if

    ! per-row scratch buffers: quicksort works on 1-based assumed-shape arrays
    allocate(d(size(self%distances, 2)), j(size(self%distances, 2)))
    do i = 0, size(self%distances, 1) - 1
      d = self%distances(i, :)
      j = self%indices(i, :)
      call quicksort(d, j)
      distances(i, :) = d
      indices(i, :) = j
    end do
  end subroutine nheap_get_arrays

  ! quicksort.f -*-f90-*-
  ! Author: t-nissie, some tweaks by 1AdAstra1
  ! License: GPLv3
  ! Gist: https://gist.github.com/t-nissie/479f0f16966925fa29ea
  !!
  !> In-place ascending quicksort of a, applying the same swaps to perm.
  !> Arrays of size 0 or 1 are returned unchanged.
  recursive subroutine quicksort(a, perm)
    real(wp), intent(inout) :: a(:)
    integer, intent(inout) :: perm(size(a))
    ! A named constant, not an initialised variable: `integer :: first = 1`
    ! would carry an implicit SAVE attribute across the recursive calls.
    integer, parameter :: first = 1
    real(wp) :: x, t
    integer :: last, i, j, ti

    last = size(a, 1)
    if (last <= first) return

    x = a(first + (last - first)/2)   ! midpoint without (first+last) overflow
    i = first
    j = last
    do
      do while (a(i) < x)
        i = i + 1
      end do
      do while (x < a(j))
        j = j - 1
      end do
      if (i >= j) exit
      t = a(i);  a(i) = a(j);  a(j) = t
      ti = perm(i);  perm(i) = perm(j);  perm(j) = ti
      i = i + 1
      j = j - 1
    end do

    if (first < i - 1) call quicksort(a(first:i-1), perm(first:i-1))
    if (j + 1 < last) call quicksort(a(j+1:last), perm(j+1:last))
  end subroutine quicksort

end module nheap_mod
module btree_mod
  use, intrinsic :: iso_fortran_env, only: error_unit
  use nheap_mod, only: nheap
  implicit none
  private

  public :: wp
  public :: btree, btree_init, btree_query

  integer, parameter :: wp = kind(1.0d0)

  !> Compile-time switch for the tracing output used while developing the
  !> tree. It is off so that library calls do not flood standard output.
  logical, parameter :: debug = .false.

  !> Ball tree over the rows of a (n_samples, n_features) data matrix.
  !>
  !> Nodes are stored implicitly as a complete binary tree: the children of
  !> node i are 2*i+1 and 2*i+2. Node i covers the sample indices
  !> idx_array(node_idx_start(i) : node_idx_end(i)-1).
  type :: btree
    real(wp), allocatable :: data(:,:)            ! (0:n_samples-1, 0:n_features-1)
    integer :: leaf_size
    integer :: n_samples, n_features
    integer :: n_levels, n_nodes
    integer, allocatable :: idx_array(:)          ! (0:n_samples-1)
    real(wp), allocatable :: node_radius(:)       ! (0:n_nodes-1)
    integer, allocatable :: node_idx_start(:)     ! (0:n_nodes-1)
    integer, allocatable :: node_idx_end(:)       ! (0:n_nodes-1), exclusive
    logical, allocatable :: node_is_leaf(:)       ! (0:n_nodes-1)
    real(wp), allocatable :: node_centroids(:,:)  ! (0:n_nodes-1, 0:n_features-1)
  contains
    procedure :: rdist
    procedure :: min_rdist
    procedure :: recursive_build
    procedure :: query_recursive
  end type btree

contains

  !> Build a ball tree over the rows of data.
  !>
  !> @param data       (n_samples, n_features) training points, one per row
  !> @param leaf_size  target leaf size, 40 by default; leaves end up with
  !>                   roughly leaf_size to 2*leaf_size points
  !>
  !> Stops with an error if leaf_size < 1 or data has no rows.
  function btree_init(data, leaf_size) result(self)
    real(wp), intent(in) :: data(0:,0:)
    integer, intent(in), optional :: leaf_size
    type(btree) :: self
    integer :: i, t

    ! SOURCE= copies the 0-based bounds of the dummy argument
    allocate(self%data, source=data)
    if (debug) print *, lbound(self%data, dim=1), ubound(self%data, dim=1)

    self%leaf_size = 40
    if (present(leaf_size)) self%leaf_size = leaf_size
    if (self%leaf_size < 1) then
      write(error_unit,*) "leaf_size must be a positive integer"
      error stop 1
    end if
    if (debug) print *, "[btree_init] leaf_size = ", self%leaf_size

    self%n_samples = size(data, dim=1)
    self%n_features = size(data, dim=2)
    if (self%n_samples < 1) then
      write(error_unit,*) "training data must contain at least one sample"
      error stop 1
    end if
    if (debug) print *, "[btree_init] n_samples = ", self%n_samples
    if (debug) print *, "[btree_init] n_features = ", self%n_features

    ! n_levels = 1 + floor(log2(max(1, (n_samples-1)/leaf_size))).
    ! Exact integer halving avoids a floating-point log2 landing just
    ! below an integer and truncating to one level too few.
    t = max(1, (self%n_samples - 1)/self%leaf_size)
    self%n_levels = 1
    do while (t > 1)
      t = t/2
      self%n_levels = self%n_levels + 1
    end do
    self%n_nodes = 2**self%n_levels - 1
    if (debug) print *, "[btree_init] n_levels = ", self%n_levels
    if (debug) print *, "[btree_init] n_nodes = ", self%n_nodes

    ! allocate arrays for storage
    allocate(self%idx_array(0:self%n_samples-1))
    do i = 0, self%n_samples - 1
      self%idx_array(i) = i
    end do
    if (debug) print *, self%idx_array

    allocate(self%node_radius(0:self%n_nodes-1))
    self%node_radius = 0.0_wp
    allocate(self%node_idx_start(0:self%n_nodes-1))
    self%node_idx_start = 0
    allocate(self%node_idx_end(0:self%n_nodes-1))
    self%node_idx_end = 0
    allocate(self%node_is_leaf(0:self%n_nodes-1))
    self%node_is_leaf = .false.
    allocate(self%node_centroids(0:self%n_nodes-1, 0:self%n_features-1))
    self%node_centroids = 0.0_wp

    call self%recursive_build(0, 0, self%n_samples)
  end function btree_init

  !> Initialise node i_node over idx_array(idx_start:idx_end-1).
  !> If the node still has child slots and at least two points, the points
  !> are split at the median and both children are built recursively.
  recursive subroutine recursive_build(self, i_node, idx_start, idx_end)
    class(btree), intent(inout) :: self
    integer, intent(in) :: i_node
    integer, intent(in) :: idx_start, idx_end
    integer :: n_mid

    if (debug) print *, "i_node,idx_start,idx_end", i_node, idx_start, idx_end
    call init_node(self, i_node, idx_start, idx_end)

    if ((2*i_node + 1) >= self%n_nodes) then
      ! no child slots left in the implicit tree: this is a leaf
      self%node_is_leaf(i_node) = .true.
      if ((idx_end - idx_start) > 2*self%leaf_size) then
        write(error_unit,*) "Internal: memory layout is flawed: not enough nodes allocated"
      end if
    else if ((idx_end - idx_start) < 2) then
      write(error_unit,*) "Internal: memory layout is flawed: too many nodes allocated"
      self%node_is_leaf(i_node) = .true.
    else
      ! split node and recursively construct child nodes
      self%node_is_leaf(i_node) = .false.
      n_mid = (idx_end + idx_start)/2
      call partition_indices(self%data, self%idx_array, idx_start, idx_end, n_mid)
      call self%recursive_build(2*i_node + 1, idx_start, n_mid)
      call self%recursive_build(2*i_node + 2, n_mid, idx_end)
    end if
  end subroutine recursive_build

  !> Compute the centroid, radius and index range of node i_node, which
  !> covers idx_array(idx_start:idx_end-1). The radius is the largest
  !> distance from the centroid to any covered point.
  subroutine init_node(self, i_node, idx_start, idx_end)
    type(btree), intent(inout) :: self
    integer, intent(in) :: i_node, idx_start, idx_end
    integer :: i, j
    real(wp) :: sq_radius

    ! determine node centroid
    do j = 0, self%n_features - 1
      self%node_centroids(i_node,j) = 0.0_wp
      do i = idx_start, idx_end - 1
        self%node_centroids(i_node,j) = self%node_centroids(i_node,j) &
                                        + self%data(self%idx_array(i),j)
      end do
      self%node_centroids(i_node,j) = self%node_centroids(i_node,j)/real(idx_end - idx_start, wp)
    end do
    if (debug) print *, "node_centroid = ", self%node_centroids(i_node,:)

    ! determine node radius
    sq_radius = 0.0_wp
    do i = idx_start, idx_end - 1
      sq_radius = max(sq_radius, &
                      self%rdist(self%node_centroids, i_node, self%data, self%idx_array(i)))
    end do
    if (debug) print *, "sq_radius", sq_radius

    self%node_radius(i_node) = sqrt(sq_radius)
    self%node_idx_start(i_node) = idx_start
    self%node_idx_end(i_node) = idx_end
    if (debug) print *, "node_radius", self%node_radius(i_node)
    if (debug) print *, "node_idx_start", self%node_idx_start(i_node)
    if (debug) print *, "node_idx_end", self%node_idx_end(i_node)
  end subroutine init_node

  !> Squared Euclidean distance between row i1 of x1 and row i2 of x2,
  !> taken over the first n_features columns.
  !>
  !> Both dummies are assumed-shape with lower bound 0, so callers can pass
  !> node_centroids, the training data or query points for either one.
  !> Do not give x1 an explicit shape such as (0:n_nodes-1, 0:n_features-1).
  !> Sequence association would then reinterpret a data array with a
  !> different row count using the wrong leading dimension, silently
  !> reading the wrong coordinates.
  function rdist(self, x1, i1, x2, i2) result(d)
    class(btree), intent(in) :: self
    real(wp), intent(in) :: x1(0:,0:)
    real(wp), intent(in) :: x2(0:,0:)
    integer, intent(in) :: i1, i2
    real(wp) :: d, tmp
    integer :: k

    d = 0.0_wp
    do k = 0, self%n_features - 1
      tmp = x1(i1,k) - x2(i2,k)
      d = d + tmp*tmp
    end do
  end function rdist

  !> Lower bound on the squared distance from row j of x to any point
  !> inside node i_node's ball: (max(0, |x_j - centroid| - radius))**2.
  function min_rdist(self, i_node, x, j) result(res)
    class(btree), intent(in) :: self
    integer, intent(in) :: i_node
    real(wp), intent(in) :: x(0:,0:)
    integer, intent(in) :: j
    real(wp) :: d, res

    d = self%rdist(self%node_centroids, i_node, x, j)
    res = (max(0.0_wp, sqrt(d) - self%node_radius(i_node)))**2
  end function min_rdist

  !> Reorder idx_array(idx_start:idx_end-1) with a quickselect along the
  !> feature of largest spread. On return, entries before split_index have
  !> values in that feature no greater than the value at split_index, and
  !> entries after it have values no smaller.
  subroutine partition_indices(data, idx_array, idx_start, idx_end, split_index)
    real(wp), intent(in) :: data(0:,0:)
    integer, intent(inout) :: idx_array(0:size(data,dim=1)-1)
    integer, intent(in) :: idx_start, idx_end, split_index
    integer :: n_features, split_dim
    real(wp) :: max_spread, max_val, min_val, val, pivot
    integer :: i, j, left, right, midindex, tmp

    ! find the split dimension
    n_features = size(data, dim=2)
    split_dim = 0
    max_spread = 0.0_wp
    do j = 0, n_features - 1
      max_val = -huge(data)
      min_val = huge(data)
      do i = idx_start, idx_end - 1
        val = data(idx_array(i),j)
        max_val = max(max_val, val)
        min_val = min(min_val, val)
      end do
      if ((max_val - min_val) > max_spread) then
        max_spread = max_val - min_val
        split_dim = j
      end if
    end do

    ! Lomuto-style quickselect on the split dimension; the pivot is the
    ! element at `right`, which the inner loop never moves.
    left = idx_start
    right = idx_end - 1
    do
      pivot = data(idx_array(right), split_dim)
      midindex = left
      do i = left, right - 1
        if (data(idx_array(i), split_dim) < pivot) then
          tmp = idx_array(i)
          idx_array(i) = idx_array(midindex)
          idx_array(midindex) = tmp
          midindex = midindex + 1
        end if
      end do
      tmp = idx_array(midindex)
      idx_array(midindex) = idx_array(right)
      idx_array(right) = tmp

      if (midindex == split_index) then
        exit
      else if (midindex < split_index) then
        left = midindex + 1
      else
        right = midindex - 1
      end if
    end do
  end subroutine partition_indices

  !> Find the k nearest training points for every row of x.
  !>
  !> @param x             (n_queries, n_features) query points
  !> @param k             number of neighbours, with 1 <= k <= n_samples
  !> @param sort_results  sort each row by ascending distance (default .true.)
  !> @param distances     (n_queries, k) Euclidean distances
  !> @param indices       (n_queries, k) 0-based row indices into the training data
  subroutine btree_query(self, x, k, sort_results, distances, indices)
    type(btree), intent(in) :: self
    real(wp), intent(in) :: x(0:,0:)
    integer, intent(in) :: k
    logical, intent(in), optional :: sort_results
    real(wp), intent(out) :: distances(0:size(x,1)-1,0:k-1)
    integer, intent(out) :: indices(0:size(x,1)-1,0:k-1)
    logical :: sort_results_
    type(nheap) :: heap
    integer :: i
    real(wp) :: sq_dist_LB

    if (size(x,2) /= self%n_features) then
      write(error_unit,*) "query data dimension must match training data dimension"
      error stop 1
    end if
    if (k < 1) then
      write(error_unit,*) "k must be a positive integer"
      error stop 1
    end if
    if (size(self%data,1) < k) then
      write(error_unit,*) "k must be less than or equal to the number of training points"
      error stop 1
    end if

    sort_results_ = .true.
    if (present(sort_results)) sort_results_ = sort_results

    ! the heap holds squared distances; the square root is taken at the end
    call heap%init(size(x,1), k)
    if (debug) print *, shape(heap%distances)
    if (debug) print *, shape(heap%indices)

    do i = 0, size(x,1) - 1
      sq_dist_LB = self%min_rdist(0, x, i)
      if (debug) write(*,'(A,I0,A,F8.4)') "sq_dist_LB(", i, ") = ", sq_dist_LB
      call self%query_recursive(0, x, i, heap, sq_dist_LB)
    end do

    if (debug) then
      do i = 0, size(x,1) - 1
        print *, heap%indices(i,:)
      end do
    end if

    call heap%get_arrays(sort_results_, distances, indices)
    distances = sqrt(distances)
  end subroutine btree_query

  !> Depth-first k-NN search for query row i_pt, starting at node i_node.
  !> sq_dist_LB is a lower bound on the squared distance from the query
  !> point to any point in i_node, as returned by min_rdist.
  recursive subroutine query_recursive(self, i_node, x, i_pt, heap, sq_dist_LB)
    class(btree), intent(in) :: self
    integer, intent(in) :: i_node
    real(wp), intent(in) :: x(0:,0:)
    integer, intent(in) :: i_pt
    type(nheap), intent(inout) :: heap
    real(wp), intent(in) :: sq_dist_LB
    real(wp) :: dist_pt, sq_dist_LB_1, sq_dist_LB_2
    integer :: i, i1, i2

    if (sq_dist_LB > heap%largest(i_pt)) then
      ! Case 1: the whole ball is farther away than the current k-th
      ! neighbour, so prune it.
      if (debug) print *, "i_node ", i_node, "Case 1"

    else if (self%node_is_leaf(i_node)) then
      ! Case 2: leaf node - test every point it owns.
      if (debug) print *, "i_node ", i_node, "Case 2"
      do i = self%node_idx_start(i_node), self%node_idx_end(i_node) - 1
        dist_pt = self%rdist(self%data, self%idx_array(i), x, i_pt)
        if (dist_pt < heap%largest(i_pt)) then
          call heap%push(i_pt, dist_pt, self%idx_array(i))
        end if
      end do

    else
      ! Case 3: internal node - descend into the closer child first so the
      ! heap tightens early and the second child is more likely pruned.
      if (debug) print *, "i_node ", i_node, "Case 3"
      i1 = 2*i_node + 1
      i2 = i1 + 1
      sq_dist_LB_1 = self%min_rdist(i1, x, i_pt)
      sq_dist_LB_2 = self%min_rdist(i2, x, i_pt)
      if (debug) print *, sq_dist_LB_1, sq_dist_LB_2

      if (sq_dist_LB_1 <= sq_dist_LB_2) then
        call self%query_recursive(i1, x, i_pt, heap, sq_dist_LB_1)
        call self%query_recursive(i2, x, i_pt, heap, sq_dist_LB_2)
      else
        call self%query_recursive(i2, x, i_pt, heap, sq_dist_LB_2)
        call self%query_recursive(i1, x, i_pt, heap, sq_dist_LB_1)
      end if
    end if
  end subroutine query_recursive

end module btree_mod
!> Demo driver: build a ball tree over a noisy parabola, dump the points and
!> node layout to ball.txt / ball_nodes.txt, run a k-NN query of every
!> point against the tree and cross-check the result with brute force.
program main
  use btree_mod, only: btree, btree_init, wp, btree_query
  implicit none

  integer, parameter :: n = 100              ! number of points (training set = query set)
  integer, parameter :: d = 2                ! number of features
  integer, parameter :: ls = 15              ! leaf size passed to btree_init
  integer, parameter :: k = 5                ! neighbours requested per query point
  real(wp), parameter :: tol = 1.0e-12_wp    ! tolerance for the brute-force check

  real(wp) :: x(0:n-1,0:d-1)
  type(btree) :: bt
  integer :: i, j, m, funit, is_leaf, n_bad
  real(wp) :: r(0:n-1,0:k-1)
  integer :: ri(0:n-1,0:k-1)
  real(wp) :: dist_all(0:n-1), brute(0:k-1)
  logical :: taken(0:n-1)

  ! x(:,0) uniform in [-1,1); x(:,1) = x(:,0)**2 plus noise in [-0.1,0.1)
  call random_number(x)
  x = 2.0_wp*x - 1.0_wp
  x(:,1) = x(:,1)*0.1_wp
  x(:,1) = x(:,1) + x(:,0)**2

  bt = btree_init(x, ls)
  print *, "leaf_size: ", bt%leaf_size
  print *, "nsamples, nfeatures: ", bt%n_samples, bt%n_features
  print *, "nlevels, nnodes: ", bt%n_levels, bt%n_nodes

  open(newunit=funit, file="ball.txt", status="replace", action="write")
  do i = 0, n-1
    write(funit,*) x(i,:), bt%idx_array(i)
  end do
  close(funit)

  ! one line per node: id, is_leaf, idx_start, idx_end, radius, centroid(:)
  open(newunit=funit, file="ball_nodes.txt", status="replace", action="write")
  do i = 0, bt%n_nodes-1
    if (bt%node_is_leaf(i)) then
      is_leaf = 1
    else
      is_leaf = 0
    end if
    write(funit,'(i0,1x,i0,1x,i0,1x,i0,1x,f16.10,1x,*(f16.10,1x))') i, is_leaf, &
      bt%node_idx_start(i), bt%node_idx_end(i), bt%node_radius(i), bt%node_centroids(i,:)
  end do
  close(funit)

  call btree_query(bt, x, k, sort_results=.true., &
                   distances=r, &
                   indices=ri)
  do i = 0, n-1
    print *, ri(i,:)
  end do

  ! Cross-check against a brute-force k-NN search. Distances are compared
  ! rather than indices so that ties cannot cause spurious failures.
  n_bad = 0
  do i = 0, n-1
    do j = 0, n-1
      dist_all(j) = sqrt(sum((x(j,:) - x(i,:))**2))
    end do
    taken = .false.
    do j = 0, k-1
      ! minloc returns a 1-based position regardless of the array's lbound
      m = minloc(dist_all, dim=1, mask=.not. taken) - 1
      brute(j) = dist_all(m)
      taken(m) = .true.
    end do
    if (any(abs(r(i,:) - brute) > tol)) n_bad = n_bad + 1
  end do
  print *, "rows disagreeing with brute force: ", n_bad
  if (n_bad > 0) error stop "btree_query disagrees with brute-force search"
end program main
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment