Created
November 29, 2020 00:02
-
-
Save sidd607/ba7c4d4b959a8454bd91281e37bc636f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Node:
    """A single k-d tree node holding a point and links to two children.

    ``info`` is the point stored in the node: an indexable sequence of
    coordinates (the tree reads ``info[dim]``).
    """

    def __init__(self, info):
        """Create a childless node that stores ``info``."""
        self.left = None
        self.right = None
        self.info = info

    def __repr__(self):
        # Makes debug output readable instead of showing an object address.
        return "Node({!r})".format(self.info)

    def addLeft(self, node):
        """Set (or clear, when ``node`` is None) the left child."""
        self.left = node

    def getLeft(self):
        """Return the left child, or None."""
        return self.left

    def addRight(self, node):
        """Set (or clear, when ``node`` is None) the right child."""
        self.right = node

    def getRight(self):
        """Return the right child, or None."""
        return self.right

    def getInfo(self):
        """Return the point stored in this node."""
        return self.info

    def setInfo(self, info):
        """Replace the point stored in this node."""
        self.info = info

    def hasInfo(self, info):
        """Return True when ``info`` has the same length and equal coordinates.

        The comparison is element-wise, so a list and a tuple holding the
        same coordinates are considered equal.
        """
        if len(self.info) != len(info):
            return False
        return all(not (mine != other) for mine, other in zip(self.info, info))
class KDTree:
    """A k-d tree built from node objects.

    Nodes must provide ``getInfo``/``setInfo``/``hasInfo`` and
    ``getLeft``/``addLeft``/``getRight``/``addRight``.  A node at depth
    ``depth`` splits on axis ``depth % dimensions``: points whose coordinate
    on that axis is strictly smaller go left, all others go right.
    """

    def __init__(self, dimensions, name):
        """Create an empty tree.

        dimensions -- number of coordinates per point (the ``k`` in k-d).
        name       -- label printed by ``print_tree``.
        """
        self.dimensions = dimensions
        self.name = name
        self.root = None

    def _add_node_rec(self, root, node, depth):
        """Insert ``node`` below ``root`` and return the subtree's root."""
        if root is None:
            return node
        dim = depth % self.dimensions
        if node.getInfo()[dim] < root.getInfo()[dim]:
            root.addLeft(self._add_node_rec(root.getLeft(), node, depth + 1))
        else:
            root.addRight(self._add_node_rec(root.getRight(), node, depth + 1))
        return root

    def add_node(self, node):
        """Insert ``node`` into the tree."""
        self.root = self._add_node_rec(self.root, node, 0)

    def print_inorder(self, node):
        """Print the subtree under ``node`` in-order, with each node's children."""
        if node is None:
            return
        self.print_inorder(node.getLeft())
        print("-> " + str(node.getInfo()))
        if node.getLeft() is not None:
            print("--> Left: " + str(node.getLeft().getInfo()))
        if node.getRight() is not None:
            print("--> Right: " + str(node.getRight().getInfo()))
        self.print_inorder(node.getRight())

    def _min_node(self, x, y, z, d):
        """Return whichever of ``x``, ``y``, ``z`` is smallest on axis ``d``.

        ``x`` must not be None; ``y`` and ``z`` may be.  Ties keep the
        earlier candidate.
        """
        res = x
        if y is not None and y.getInfo()[d] < res.getInfo()[d]:
            res = y
        if z is not None and z.getInfo()[d] < res.getInfo()[d]:
            res = z
        return res

    def _find_min_rec(self, root, d, depth):
        """Return the node with the smallest axis-``d`` coordinate under ``root``.

        ``depth`` is the depth of ``root`` in the whole tree; it determines
        which axis ``root`` splits on.
        """
        if root is None:
            return None
        dim = depth % self.dimensions
        if d == dim:
            # This level splits on d, so nothing to the right can be smaller.
            if root.getLeft() is None:
                return root
            return self._find_min_rec(root.getLeft(), d, depth + 1)
        # This level splits on another axis: the minimum may be on either side.
        return self._min_node(root,
                              self._find_min_rec(root.getLeft(), d, depth + 1),
                              self._find_min_rec(root.getRight(), d, depth + 1),
                              d)

    def find_min(self, root, d, depth=0):
        """Return the node with the smallest axis-``d`` coordinate under ``root``.

        ``depth`` is the depth of ``root`` within the whole tree (0 for the
        tree root).  It must be supplied when ``root`` is an inner subtree,
        otherwise the wrong splitting axis is assumed for each level.
        """
        return self._find_min_rec(root, d, depth)

    def _del_node_rec(self, root, point, depth):
        """Delete ``point`` from the subtree under ``root``.

        Returns the (possibly new) subtree root.  A point that is not in the
        subtree leaves it unchanged.
        """
        if root is None:
            return None
        dim = depth % self.dimensions
        if root.hasInfo(point):
            if root.getRight() is not None:
                # Replace with the axis-minimum of the right subtree, then
                # remove that minimum from where it used to live.
                min_node = self.find_min(root.getRight(), dim, depth + 1)
                root.setInfo(min_node.getInfo())
                root.addRight(self._del_node_rec(root.getRight(),
                                                 min_node.getInfo(), depth + 1))
            elif root.getLeft() is not None:
                # Only a left subtree: take its axis-minimum as the new value
                # and move what remains to the right, where every point is
                # >= the new value on this axis.
                min_node = self.find_min(root.getLeft(), dim, depth + 1)
                root.setInfo(min_node.getInfo())
                root.addRight(self._del_node_rec(root.getLeft(),
                                                 min_node.getInfo(), depth + 1))
                # The subtree was moved, not copied: detach the old link.
                root.addLeft(None)
            else:
                # Leaf: simply drop it.
                return None
            return root
        if point[dim] < root.getInfo()[dim]:
            root.addLeft(self._del_node_rec(root.getLeft(), point, depth + 1))
        else:
            root.addRight(self._del_node_rec(root.getRight(), point, depth + 1))
        return root

    def delete_Node(self, node):
        """Delete the point stored in ``node`` from the tree, if present."""
        self.root = self._del_node_rec(self.root, node.getInfo(), 0)

    def print_tree(self):
        """Print a header with the tree's name followed by the in-order dump."""
        print("#### InOrder Representation of the Tree : " + self.name + " ####")
        self.print_inorder(self.root)
if __name__ == '__main__':
    # Demo: build a 2-d tree, dump it, delete the point (70, 70), dump it again.
    tree = KDTree(2, 'example kdtree')
    for point in ([30, 40], [5, 25], [70, 70], [10, 12], [50, 30], [35, 45]):
        tree.add_node(Node(point))
    tree.print_tree()

    # Blank separator between the two dumps.
    print("\n" * 2)

    tree.delete_Node(Node([70, 70]))
    tree.print_tree()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment