Last active
December 27, 2015 21:09
-
-
Save nsivabalan/7390016 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.Comparator; | |
import java.util.Iterator; | |
import java.util.PriorityQueue; | |
import java.util.Queue; | |
/** | |
KDTrees. Used to find KNearestNeighbour among N Points in a KD Plane in O(logK) time. | |
Here is the implementation of 2D tress to find K nearest neighbours for any point among N given points. | |
Run time complexity : O((N+k)logK). | |
O(NlogK) to generate the initial tree. | |
O(KlogK) to getKnearestNeighbours. | |
http://coursera.cs.princeton.edu/algs4/assignments/kdtree.html | |
Check out the 2nd part in this link. First part is regarding range search for a rectangle. Second part is nearest neighbourhood search. | |
*/ | |
public class KdTrees { | |
Node root = null; | |
class Node{ | |
Point2D data; | |
Node left; | |
Node right; | |
public Node(Point2D data) | |
{ | |
this.data = data; | |
this.left = null; | |
this.right = null; | |
} | |
public Point2D getData() | |
{ | |
return data; | |
} | |
public Node getLeft() | |
{ | |
return left; | |
} | |
public Node getRight() | |
{ | |
return right; | |
} | |
} | |
public void insert(Point2D input) | |
{ | |
if(input == null) throw new IllegalArgumentException("Argument cannot be null"); | |
if(root == null) | |
{ | |
root = new Node(input); | |
return; | |
} | |
else{ | |
insertNewNode(root, input, 0); | |
} | |
} | |
private Node insertNewNode(Node root, Point2D input, int level) | |
{ | |
if(root == null) | |
{ | |
Node temp = new Node(input); | |
return temp; | |
} | |
if(level%2 == 0) | |
{ | |
int value = root.getData().compareTo(input); | |
if(value == -1) | |
root.right = insertNewNode(root.getRight(), input, level+1); | |
else if(value == 1) | |
root.left = insertNewNode(root.getLeft(), input, level +1); | |
return root; | |
} | |
else{ | |
int value = root.getData().compareY(input); | |
if(value == -1) | |
root.right = insertNewNode(root.getRight(), input, level+1); | |
else if(value == 1) | |
root.left = insertNewNode(root.getLeft(), input, level +1); | |
return root; | |
} | |
} | |
public void printTree() | |
{ | |
printTree(root); | |
} | |
public void printTree(Node root) | |
{ | |
if(root == null ) return; | |
if(root.left != null ) printTree(root.getLeft()); | |
System.out.println(" "+root.getData()); | |
if(root.right != null) printTree(root.getRight()); | |
} | |
public boolean contains(Point2D input) | |
{ | |
return contains(root, input, 0); | |
} | |
public boolean contains(Node root, Point2D input, int level) | |
{ | |
if(root == null) return false; | |
if(level%2 == 0) | |
{ | |
int value = root.getData().compareTo(input); | |
if(value == -1) | |
return contains(root.getRight(), input, level+1); | |
else if(value == 1) | |
return contains(root.getLeft(), input, level +1); | |
else return true; | |
} | |
else{ | |
int value = root.getData().compareY(input); | |
if(value == -1) | |
return contains(root.getRight(), input, level+1); | |
else if(value == 1) | |
return contains(root.getLeft(), input, level +1); | |
else | |
return true; | |
} | |
} | |
public Queue<PQNode> getKClosestPoints(Point2D input, int k) | |
{ | |
Queue<PQNode> kClosest = new PriorityQueue<PQNode>(k, new PQNodeComparator()); | |
getKClosest(root, input, 0, kClosest , k); | |
return kClosest; | |
} | |
public void getKClosest(Node root, Point2D input, int level, Queue<PQNode> kClosest, int k) | |
{ | |
if(root == null) return; | |
int pqSize = kClosest.size(); | |
double farthestDist = 0.0; | |
if(pqSize >= k){ | |
farthestDist = kClosest.peek().distance; | |
double newDist = root.data.getDistance(input); | |
if(newDist > farthestDist) | |
return; | |
else{ | |
kClosest.remove(); | |
kClosest.add(new PQNode(root, input)); | |
} | |
} | |
else{ | |
kClosest.add(new PQNode(root, input)); | |
} | |
PQNode peek = kClosest.peek(); | |
farthestDist = peek.distance; | |
//System.out.println(" "+root.data+", peek "+peek.data.getData()+" dist "+farthestDist); | |
double leftDist = 0; | |
double rightDist = 0; | |
if(root.getLeft() != null && root.getRight() != null){ | |
leftDist = root.getLeft().data.getDistance(input); | |
rightDist = root.getRight().data.getDistance(input); | |
if(leftDist < rightDist) | |
{ | |
farthestDist = kClosest.peek().distance; | |
if(leftDist < farthestDist) | |
getKClosest(root.getLeft(), input, level+1, kClosest, k); | |
farthestDist = kClosest.peek().distance; | |
if(rightDist < farthestDist) | |
getKClosest(root.getRight(), input, level+1, kClosest, k); | |
} | |
else{ | |
farthestDist = kClosest.peek().distance; | |
if(rightDist < farthestDist) | |
getKClosest(root.getRight(), input, level+1, kClosest, k); | |
farthestDist = kClosest.peek().distance; | |
if(leftDist < farthestDist) | |
getKClosest(root.getLeft(), input, level+1, kClosest, k); | |
} | |
} | |
else if(root.getLeft()!= null) | |
getKClosest(root.getLeft(), input, level+1, kClosest, k); | |
else | |
getKClosest(root.getRight(), input, level+1, kClosest, k); | |
} | |
class PQNode{ | |
Node data; | |
double distance; | |
public PQNode(Node data, Point2D basePoint) | |
{ | |
this.data = data; | |
this.distance = data.getData().getDistance(basePoint); | |
} | |
} | |
class PQNodeComparator implements Comparator<PQNode>{ | |
public int compare(PQNode node1, PQNode node2) | |
{ | |
if(node1.distance > node2.distance) return -1; | |
else if(node1.distance < node2.distance) return 1; | |
else return 0; | |
} | |
} | |
public static void main(String args[]) | |
{ | |
KdTrees obj = new KdTrees(); | |
Point2D point = new Point2D(7, 2); | |
obj.insert(point); | |
point = new Point2D(5, 4); | |
obj.insert(point); | |
point = new Point2D(2, 3); | |
obj.insert(point); | |
point = new Point2D(4, 7); | |
obj.insert(point); | |
point = new Point2D(9, 6); | |
obj.insert(point); | |
obj.printTree(); | |
point = new Point2D(5, 6); | |
Queue<PQNode> kClosest = obj.getKClosestPoints(point, 3); | |
System.out.println(" --------------------------------- "); | |
while(!kClosest.isEmpty()){ | |
PQNode temp = kClosest.remove(); | |
System.out.println(" "+temp.data.getData()+" "+temp.distance); | |
} | |
System.out.println(" --------------------------------- "); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment