Last active: April 5, 2020 02:04
Save ngshaohui/988c6538774d38cd0463cdfb45ab7aaa to your computer and use it in GitHub Desktop.
Balanced KDTree implementation in TypeScript
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { KDTree, DistanceFunc } from './KDTree' | |
import { expect } from 'chai' | |
import 'mocha' | |
/** Two-dimensional point used as the element type in these KDTree specs. */
type Point = {
  x: number
  y: number
}
// Fixture coordinates as (x, y) pairs, in the order the tree receives them.
const pointCoords: [number, number][] = [
  [7, 2],
  [5, 4],
  [9, 6],
  [4, 7],
  [8, 1],
  [2, 3],
]
/** Fixture points shared by the KDTree specs below. */
const pointList: Point[] = pointCoords.map(([x, y]) => ({ x, y }))
/**
 * Euclidean distance between two points.
 */
function distanceFormula(point1: Point, point2: Point): number {
  const dx = point1.x - point2.x
  const dy = point1.y - point2.y
  return Math.sqrt(Math.pow(dx, 2) + Math.pow(dy, 2))
}
/** Distance function passed to KDTree.getNearest by the nearest-neighbour specs. */
const compFunc: DistanceFunc<Point> = distanceFormula
describe('invalid KDTree', () => {
  it('should not allow tree to be created', () => {
    // Objects without any keys leave the tree nothing to split on.
    const keylessObjects = [{}, {}, {}]
    const construct = (): void => {
      new KDTree(keylessObjects)
    }
    expect(construct).to.throw(Error, 'Object should not have depth of 0')
  })
})
describe('valid 2 dimensional KDTree', () => {
  it('should allow empty list', () => {
    const emptyTree: KDTree<Point> = new KDTree([])
    const probe: Point = { x: 5, y: 5 }
    expect(emptyTree.getNearest(probe, compFunc)).to.equal(null)
  })

  describe('for KDTree built using pointList', () => {
    const tree: KDTree<Point> = new KDTree(pointList, ['x', 'y'])

    it('should find nearest neighbour', () => {
      // Each entry is [query point, expected nearest stored point].
      const cases: Array<[Point, Point]> = [
        [{ x: 5, y: 5 }, { x: 5, y: 4 }],
        [{ x: 2, y: 7 }, { x: 4, y: 7 }],
        [{ x: 10, y: 10 }, { x: 9, y: 6 }],
        [{ x: 8, y: 1 }, { x: 8, y: 1 }],
      ]
      for (const [query, expected] of cases) {
        expect(tree.getNearest(query, compFunc)).to.eql(expected)
      }
    })

    it('should find nn for queryPoint directly on point', () => {
      const onStoredPoint: Point = { x: 8, y: 1 }
      expect(tree.getNearest(onStoredPoint, compFunc)).to.eql({ x: 8, y: 1 })
    })
  })

  it('should still work without specifying list of keys', () => {
    const tree: KDTree<Point> = new KDTree(pointList)
    expect(new Set(tree.getKeys())).to.eql(new Set(['x', 'y']))
  })
})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Binary-tree node: one stored value plus left and right child links.
 * Both children start as null and are attached with setLeft / setRight.
 */
class TreeNode<T> {
  private left: TreeNode<T> = null
  private right: TreeNode<T> = null

  constructor(private data: T) {}

  /** Replaces the left child. */
  setLeft(left: TreeNode<T>): void {
    this.left = left
  }

  /** Replaces the right child. */
  setRight(right: TreeNode<T>): void {
    this.right = right
  }

  /** Replaces the stored value. */
  setData(data: T): void {
    this.data = data
  }

  /** Left child, or null when none has been attached. */
  getLeft(): TreeNode<T> {
    return this.left
  }

  /** Right child, or null when none has been attached. */
  getRight(): TreeNode<T> {
    return this.right
  }

  /** The stored value. */
  getData(): T {
    return this.data
  }
}
/**
 * Structural equality, used by KDTree to recognise duplicate points.
 *
 * Two non-null objects (arrays included) are equal when:
 * - they have the same number of enumerable keys (own and inherited, as
 *   `for...in` enumerates them), and
 * - every key of each is present in the other with a recursively equal value.
 *
 * Any other pair is compared with `===`, except that NaN equals NaN. This
 * keeps a value containing NaN deepEqual to itself.
 */
function deepEqual<T>(a: T, b: T): boolean {
  // Same reference or same primitive: equal without walking the structure.
  if (a === b) {
    return true
  }
  if (
    typeof a === 'object' &&
    a !== null &&
    typeof b === 'object' &&
    b !== null
  ) {
    let countA = 0
    let countB = 0
    for (const key in a) countA++
    for (const key in b) countB++
    // A differing key count settles the answer before any recursion.
    if (countA !== countB) {
      return false
    }
    for (const key in a) {
      if (!(key in b) || !deepEqual(a[key], b[key])) {
        return false
      }
    }
    for (const key in b) {
      if (!(key in a) || !deepEqual(b[key], a[key])) {
        return false
      }
    }
    return true
  }
  // NaN is the only value that is not === to itself; treat two NaNs as equal.
  return (
    typeof a === 'number' &&
    typeof b === 'number' &&
    Number.isNaN(a) &&
    Number.isNaN(b)
  )
}
/** Best candidate found so far during a nearest-neighbour search. */
type Champion<T> = {
  // distance from the query point to `data`, per the caller's distance function
  distance: number
  // the stored point that achieves `distance`
  data: T
}
/**
 * Distance between two points of type T. KDTree also calls it on the query
 * point and a copy with one coordinate moved onto a split plane.
 */
type DistanceFunc<T> = (arg0: T, arg1: T) => number
/**
 * Balanced k-d tree over objects whose coordinates are read from named
 * properties. The tree is built once from a list of points and answers
 * nearest-neighbour queries with a caller-supplied distance function.
 */
class KDTree<T extends {}> {
  private root: TreeNode<T>
  private keys: string[]

  /**
   * Builds the tree.
   *
   * @param ls points to index. Each per-key ordering is taken from a copy,
   *   so the caller's array is not reordered. A point deepEqual to one
   *   already chosen as a node is not stored again.
   * @param useKeys property names used as split axes, cycled by depth. When
   *   omitted or empty, `Object.keys(ls[0])` is used.
   * @throws Error ('Object should not have depth of 0') when `ls` is
   *   non-empty and there are no keys to split on.
   */
  constructor(ls: T[], useKeys?: string[]) {
    if (ls.length === 0) {
      // Nothing to index. Still assign `keys` so getKeys() returns an array
      // (the requested keys, or none) instead of undefined.
      this.root = null
      this.keys = useKeys ?? []
      return
    }
    // Use the specified keys if provided; default to all object keys otherwise.
    this.setKeys(useKeys?.length ? useKeys : Object.keys(ls[0]))
    // One copy of the input per key, each sorted ascending by that key.
    // buildTree only filters these lists, so each stays sorted and the median
    // along any axis is simply its middle element.
    const sorted = this.keys.map((key) =>
      ls.slice(0).sort((a, b) => {
        if (a[key] > b[key]) {
          return 1
        } else if (a[key] < b[key]) {
          return -1
        }
        return 0
      }),
    )
    this.setRoot(this.buildTree(sorted, this.keys, 0))
  }

  /**
   * Recursively builds the subtree for one partition of the points.
   *
   * @param ls the same set of points once per key; `ls[i]` is sorted by `keys[i]`
   * @param keys split axes, cycled by depth
   * @param depth depth of the node being built; selects the split key
   * @returns the subtree root, or null for an empty partition
   */
  private buildTree(ls: T[][], keys: string[], depth: number): TreeNode<T> {
    // Every list in `ls` holds the same points, so ls[0] gives the size.
    if (ls[0].length === 0) {
      return null
    }
    if (ls[0].length === 1) {
      return new TreeNode(ls[0][0])
    }
    const axis = depth % keys.length
    const key = keys[axis]
    const currentList = ls[axis]
    // Median along the split key (the list is already sorted by it).
    const currentPoint = currentList[Math.floor(currentList.length / 2)]
    const currentNode = new TreeNode(currentPoint)
    // Strictly smaller coordinates go left; filtering keeps each list sorted.
    const left = ls.map((xs) =>
      xs.filter((point) => point[key] < currentPoint[key]),
    )
    // Equal-or-greater coordinates go right, except the median itself and
    // points deepEqual to it.
    //
    // The median is excluded by identity as well, so every recursive call is
    // guaranteed to shrink without relying on deepEqual recognising a point
    // as equal to itself. With plain `===` on leaf values, a NaN-valued
    // property breaks that.
    //
    // Without the identity check, such a median would be re-inserted into its
    // own subtree: it would be duplicated, and recursion would never end when
    // the points also tie on every split key.
    const right = ls.map((xs) =>
      xs.filter(
        (point) =>
          point !== currentPoint &&
          point[key] >= currentPoint[key] &&
          !deepEqual(point, currentPoint),
      ),
    )
    currentNode.setLeft(this.buildTree(left, keys, depth + 1))
    currentNode.setRight(this.buildTree(right, keys, depth + 1))
    return currentNode
  }

  private setRoot(root: TreeNode<T>): void {
    this.root = root
  }

  /** Stores the split keys; throws when there are none to split on. */
  private setKeys(keys: string[]): void {
    if (keys.length === 0) {
      throw new Error('Object should not have depth of 0')
    }
    this.keys = keys
  }

  /** The property names used as split axes, in depth order. */
  getKeys(): string[] {
    return this.keys
  }

  /**
   * Finds the stored point closest to `queryPoint`.
   *
   * @param queryPoint point to search around; read through the tree's keys
   * @param getDistance distance between two points. It is also called on
   *   `queryPoint` and a copy with one key moved onto a split plane. Pruning
   *   is exact only when that yields the distance to the plane, as it does
   *   for the Euclidean metric.
   * @returns the nearest stored point, or null if the tree is empty
   */
  getNearest(queryPoint: T, getDistance: DistanceFunc<T>): T | null {
    if (!this.root) {
      // tree is empty
      return null
    }
    return this.getNearestH(
      this.root,
      null,
      queryPoint,
      this.keys,
      getDistance,
      0,
    ).data
  }

  /**
   * One step of the nearest-neighbour search.
   *
   * Updates the champion with `curNode`, then descends into the child on the
   * query's side of the split. It visits the other child only if the best
   * distance so far reaches past the split plane.
   */
  private getNearestH(
    curNode: TreeNode<T>,
    champion: Champion<T>,
    queryPoint: T,
    keys: string[],
    getDistance: DistanceFunc<T>,
    depth: number,
  ): Champion<T> {
    if (!curNode) {
      return champion
    }
    const nodePoint = curNode.getData()
    const curDistance = getDistance(queryPoint, nodePoint)
    // Keep whichever of the old champion and this node is strictly closer
    // (ties go to this node).
    let curChampion: Champion<T> =
      !!champion && champion.distance < curDistance
        ? champion
        : { distance: curDistance, data: nodePoint }
    const key = keys[depth % keys.length]
    // Shortest path from the query to this node's split plane: move only the
    // split coordinate onto the plane and measure how far that moved it.
    const borderPoint = {
      ...queryPoint,
      [key]: nodePoint[key],
    }
    const borderDistance = getDistance(borderPoint, queryPoint)
    // Points with a smaller split coordinate were stored on the left.
    const goLeft = queryPoint[key] < nodePoint[key]
    const nearSide = goLeft ? curNode.getLeft() : curNode.getRight()
    const farSide = goLeft ? curNode.getRight() : curNode.getLeft()
    curChampion = this.getNearestH(
      nearSide,
      curChampion,
      queryPoint,
      keys,
      getDistance,
      depth + 1,
    )
    // If the hypersphere around the query still crosses the plane, the far
    // side may hold a closer point.
    if (curChampion.distance > borderDistance) {
      curChampion = this.getNearestH(
        farSide,
        curChampion,
        queryPoint,
        keys,
        getDistance,
        depth + 1,
      )
    }
    return curChampion
  }
}
export { KDTree, DistanceFunc } |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.