Skip to content

Instantly share code, notes, and snippets.

@nsivabalan
Last active December 27, 2015 21:09
Show Gist options
  • Save nsivabalan/7390016 to your computer and use it in GitHub Desktop.
Save nsivabalan/7390016 to your computer and use it in GitHub Desktop.
import java.util.Comparator;
import java.util.Iterator;
import java.util.PriorityQueue;
import java.util.Queue;
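/**
KD-trees, used to find the K nearest neighbours of a query point among N stored points.
This is a 2D-tree implementation: the splitting comparison alternates by depth
(compareTo at even depths, compareY at odd depths).
Running time: the tree is never rebalanced, so building it by N successive inserts takes
O(N log N) on average but O(N^2) in the worst case (for example, points inserted in sorted order).
A K-nearest-neighbour query keeps a bounded heap of at most K candidates, costing O(log K) per
visited node. Pruning can skip whole subtrees, but in the worst case every one of the N nodes is
visited, i.e. O(N log K).
http://coursera.cs.princeton.edu/algs4/assignments/kdtree.html
See the second part of the assignment at this link: the first part covers range search in a
rectangle, the second part covers nearest-neighbour search.
*/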
public class KdTrees {

    /** Root of the tree, or {@code null} while the tree is empty. */
    Node root = null;

    /**
     * One tree node: a stored point plus its two children.
     *
     * <p>{@code radius} is the largest {@code getDistance} from this node's point to any point
     * inserted into its subtree. {@link #insert(Point2D)} maintains it and the k-nearest-neighbour
     * search uses it to skip subtrees that cannot contain a closer point. Linking nodes together
     * by hand bypasses that bookkeeping.
     */
    class Node {
        Point2D data;
        Node left;
        Node right;
        /** Largest distance from {@code data} to any point inserted below this node (0 for a leaf). */
        double radius;

        public Node(Point2D data) {
            this.data = data;
            this.left = null;
            this.right = null;
            this.radius = 0.0;
        }

        public Point2D getData() {
            return data;
        }

        public Node getLeft() {
            return left;
        }

        public Node getRight() {
            return right;
        }
    }

    /**
     * Adds a point to the tree. A point that compares equal to an already stored point under
     * both {@code compareTo} and {@code compareY} is treated as a duplicate and ignored.
     *
     * @param input the point to add
     * @throws IllegalArgumentException if {@code input} is null
     */
    public void insert(Point2D input) {
        if (input == null) throw new IllegalArgumentException("Argument cannot be null");
        root = insertNewNode(root, input, 0);
    }

    /**
     * Inserts {@code input} into the subtree rooted at {@code root} (which sits at depth
     * {@code level}) and returns the possibly new subtree root. Written iteratively so that a
     * degenerate tree, e.g. one built from sorted input, cannot overflow the call stack.
     */
    private Node insertNewNode(Node root, Point2D input, int level) {
        if (root == null) {
            return new Node(input);
        }
        Node current = root;
        int depth = level;
        while (!samePoint(current.getData(), input)) {
            // input will be stored somewhere below current, so current's subtree radius must cover it.
            current.radius = Math.max(current.radius, current.getData().getDistance(input));
            // Only the sign of the comparison matters; ties go right so points that share
            // a splitting coordinate are kept rather than dropped.
            boolean goLeft = compareAtLevel(current, input, depth) > 0;
            Node next = goLeft ? current.getLeft() : current.getRight();
            if (next == null) {
                if (goLeft) {
                    current.left = new Node(input);
                } else {
                    current.right = new Node(input);
                }
                break;
            }
            current = next;
            depth++;
        }
        return root;
    }

    /**
     * Compares the node's point with {@code input} using the splitting rule for {@code level}:
     * {@code compareTo} at even depths, {@code compareY} at odd depths. A positive result sends
     * {@code input} to the left subtree; zero or negative sends it to the right.
     */
    private int compareAtLevel(Node node, Point2D input, int level) {
        return level % 2 == 0 ? node.getData().compareTo(input) : node.getData().compareY(input);
    }

    /**
     * Returns true when the two points compare equal under both splitting comparisons.
     * Assumes {@code compareTo} and {@code compareY} together determine a point — TODO confirm
     * against Point2D.
     */
    private boolean samePoint(Point2D a, Point2D b) {
        return a.compareTo(b) == 0 && a.compareY(b) == 0;
    }

    /** Prints every stored point, one per line, in in-order sequence. */
    public void printTree() {
        printTree(root);
    }

    /** Prints the points of the subtree rooted at {@code root} in in-order sequence (left, node, right). */
    public void printTree(Node root) {
        if (root == null) return;
        printTree(root.getLeft());
        System.out.println(" " + root.getData());
        printTree(root.getRight());
    }

    /**
     * Returns whether the tree stores a point equal to {@code input}, i.e. one that compares
     * equal under both {@code compareTo} and {@code compareY}.
     *
     * @throws IllegalArgumentException if {@code input} is null
     */
    public boolean contains(Point2D input) {
        if (input == null) throw new IllegalArgumentException("Argument cannot be null");
        return contains(root, input, 0);
    }

    /**
     * Returns whether the subtree rooted at {@code root}, which sits at depth {@code level},
     * stores a point equal to {@code input}. Follows the same path {@link #insert} takes, so a
     * tie on one splitting coordinate keeps searching to the right instead of reporting a match.
     */
    public boolean contains(Node root, Point2D input, int level) {
        Node current = root;
        int depth = level;
        while (current != null) {
            if (samePoint(current.getData(), input)) return true;
            current = compareAtLevel(current, input, depth) > 0 ? current.getLeft() : current.getRight();
            depth++;
        }
        return false;
    }

    /**
     * Returns the (up to) {@code k} stored points nearest to {@code input}.
     *
     * @param input the query point
     * @param k     how many neighbours to return; must be positive
     * @return a queue whose head is the FARTHEST of the returned neighbours, so draining it
     *         yields neighbours from farthest to nearest; empty if the tree is empty
     * @throws IllegalArgumentException if {@code input} is null or {@code k < 1}
     */
    public Queue<PQNode> getKClosestPoints(Point2D input, int k) {
        if (input == null) throw new IllegalArgumentException("Argument cannot be null");
        if (k < 1) throw new IllegalArgumentException("k must be positive: " + k);
        Queue<PQNode> kClosest = new PriorityQueue<PQNode>(k, new PQNodeComparator());
        getKClosest(root, input, 0, kClosest, k);
        return kClosest;
    }

    /**
     * Adds to {@code kClosest} the points of the subtree rooted at {@code root} (at depth
     * {@code level}) that belong among the {@code k} nearest to {@code input}.
     *
     * <p>{@code kClosest} must keep the FARTHEST candidate at its head, as a
     * {@code PriorityQueue} built with {@link PQNodeComparator} does; it never grows beyond
     * {@code k} entries. Nothing is collected when {@code k < 1}.
     *
     * <p>A child subtree is skipped only when the triangle inequality proves it cannot hold a
     * closer point: every point in it is at least {@code dist(input, child) - child.radius}
     * away. (The classic splitting-line bound would need coordinate accessors on Point2D that
     * this class does not use.) Assumes {@code getDistance} is a metric, e.g. Euclidean —
     * TODO confirm against Point2D.
     */
    public void getKClosest(Node root, Point2D input, int level, Queue<PQNode> kClosest, int k) {
        if (root == null || k < 1) return;
        offerCandidate(kClosest, new PQNode(root, input), k);

        // Visit the child on the query's side of the split first: it tends to tighten the
        // k-th best distance sooner, which lets the other child be pruned more often.
        Node near;
        Node far;
        if (compareAtLevel(root, input, level) > 0) {
            near = root.getLeft();
            far = root.getRight();
        } else {
            near = root.getRight();
            far = root.getLeft();
        }
        if (mayContainCloser(near, input, kClosest, k)) {
            getKClosest(near, input, level + 1, kClosest, k);
        }
        // Re-check after the near side: its results may now rule the far side out.
        if (mayContainCloser(far, input, kClosest, k)) {
            getKClosest(far, input, level + 1, kClosest, k);
        }
    }

    /** Adds {@code candidate} if fewer than {@code k} are held, or swaps out the farthest if it is closer. */
    private void offerCandidate(Queue<PQNode> kClosest, PQNode candidate, int k) {
        if (kClosest.size() < k) {
            kClosest.add(candidate);
        } else if (candidate.distance < kClosest.peek().distance) {
            kClosest.remove();
            kClosest.add(candidate);
        }
    }

    /** Returns false only when {@code subtree} is empty or provably holds no point closer than the current k-th best. */
    private boolean mayContainCloser(Node subtree, Point2D input, Queue<PQNode> kClosest, int k) {
        if (subtree == null) return false;
        if (kClosest.size() < k) return true;
        double lowerBound = subtree.getData().getDistance(input) - subtree.radius;
        return lowerBound < kClosest.peek().distance;
    }

    /** A candidate neighbour: a tree node together with its distance from the query point. */
    class PQNode {
        Node data;
        double distance;

        public PQNode(Node data, Point2D basePoint) {
            this.data = data;
            this.distance = data.getData().getDistance(basePoint);
        }
    }

    /**
     * Orders candidates by DESCENDING distance, so a {@code PriorityQueue} using it keeps the
     * farthest candidate at its head, ready to be evicted.
     */
    class PQNodeComparator implements Comparator<PQNode> {
        @Override
        public int compare(PQNode node1, PQNode node2) {
            // Arguments swapped for descending order; Double.compare is a total order (NaN-safe).
            return Double.compare(node2.distance, node1.distance);
        }
    }

    /** Small demo: builds a five-point tree, prints it, then prints the 3 nearest neighbours of (5, 6). */
    public static void main(String args[]) {
        KdTrees obj = new KdTrees();
        Point2D point = new Point2D(7, 2);
        obj.insert(point);
        point = new Point2D(5, 4);
        obj.insert(point);
        point = new Point2D(2, 3);
        obj.insert(point);
        point = new Point2D(4, 7);
        obj.insert(point);
        point = new Point2D(9, 6);
        obj.insert(point);
        obj.printTree();
        point = new Point2D(5, 6);
        Queue<PQNode> kClosest = obj.getKClosestPoints(point, 3);
        System.out.println(" --------------------------------- ");
        while (!kClosest.isEmpty()) {
            PQNode temp = kClosest.remove();
            System.out.println(" " + temp.data.getData() + " " + temp.distance);
        }
        System.out.println(" --------------------------------- ");
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment