Skip to content

Instantly share code, notes, and snippets.

@simondlevy
Last active January 1, 2021 04:26
Show Gist options
  • Save simondlevy/54595acedad817305c0c0faaad23350d to your computer and use it in GitHub Desktop.
Save simondlevy/54595acedad817305c0c0faaad23350d to your computer and use it in GitHub Desktop.
Demo of the rtree library for k-Nearest Neighbors
#!/usr/bin/env python3
'''
Tests k-Nearest-Neighbors using Rtree library
Requires: numpy, matplotlib, rtree
Copyright (C) Simon D. Levy 2020
MIT License
'''
import argparse
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from rtree import index
def twod(n, k):
    '''
    Scatter-plot n random 2D points and highlight, in red, the points that an
    Rtree nearest-neighbor query for k neighbors returns around (0.5, 0.5).

    n: total number of points to generate
    k: number of neighbors requested from the index

    Draws on the current matplotlib axes; returns nothing.
    '''

    # Random points, each coordinate drawn by np.random.random
    points = np.random.random((n, 2))

    # Build the spatial index. Each point goes in as a zero-area box whose
    # min and max corners coincide.
    tree = index.Index()
    for point_id, (x, y) in enumerate(points):
        tree.insert(point_id, (x, y, x, y))

    # Query with a zero-area box at the center; the ids that come back are
    # row numbers into the full point array.
    center = (0.5, 0.5, 0.5, 0.5)
    nearest_ids = list(tree.nearest(center, k))
    neighbors = points[nearest_ids, :]

    # Everything first, then the neighbors drawn over the top in red
    plt.scatter(points[:, 0], points[:, 1], marker='.')
    plt.scatter(neighbors[:, 0], neighbors[:, 1], marker='.', color='r')

    # Equal scaling on both axes
    plt.axis('square')
def threed(n, k):
    '''
    Scatter-plot n random 3D points on a new figure and highlight, in red, the
    points that an Rtree nearest-neighbor query for k neighbors returns around
    (0.5, 0.5, 0.5).

    n: total number of points to generate
    k: number of neighbors requested from the index

    Creates a new matplotlib figure with 3D axes; returns nothing.
    '''

    # Configure a three-dimensional, non-interleaved index
    props = index.Property()
    props.dimension = 3
    tree = index.Index(properties=props, interleaved=False)

    # Random points in the unit cube
    points = np.random.random((n, 3))

    # With interleaved=False the coordinate order is
    # (xmin, xmax, ymin, ymax, zmin, zmax); a point is a zero-volume box.
    for point_id, (x, y, z) in enumerate(points):
        tree.insert(point_id, (x, x, y, y, z, z))

    # Ids of the neighbors of the center point
    center = (0.5, 0.5, 0.5, 0.5, 0.5, 0.5)
    neighbor_ids = list(tree.nearest(center, k))

    # Every other id, so the two groups are drawn as disjoint sets
    every_id = set(range(n))
    other_ids = list(every_id - set(neighbor_ids))

    # New figure with a 3D projection
    axes = plt.figure().add_subplot(111, projection='3d')

    # Non-neighbors in the default color
    others = points[other_ids]
    axes.scatter(others[:, 0], others[:, 1], others[:, 2], marker='.')

    # Neighbors in red
    chosen = points[neighbor_ids]
    axes.scatter(chosen[:, 0], chosen[:, 1], chosen[:, 2], marker='.', color='r')
def main():
    '''
    Command-line entry point: parse --n, --k, --seed and --3d, seed NumPy's
    random number generator, run the 2D or 3D demo, and show the plot.
    '''

    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Optional integer arguments: (flag, default, help text)
    int_options = (
        ('--n', 1000, 'Total number of points'),
        ('--k', 50, 'Number of neighbors'),
        ('--seed', None, 'Seed for random number generator'),
    )
    for flag, default, text in int_options:
        cli.add_argument(flag, type=int, default=default, help=text)

    # '3d' is not a valid attribute name, so the flag is stored as 'threed'
    cli.add_argument('--3d', dest='threed', action='store_true',
                     help='3D version')

    opts = cli.parse_args()

    # Seed the generator; the default of None is passed straight through
    np.random.seed(opts.seed)

    # Pick the demo matching the requested dimensionality and run it
    demo = threed if opts.threed else twod
    demo(opts.n, opts.k)

    plt.show()
# Run the demo only when this file is executed as a script
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment