Skip to content

Instantly share code, notes, and snippets.

@ShahOdin
Last active January 17, 2020 01:05
Show Gist options
  • Save ShahOdin/61e9761d1fc44e2401c21b1e5ab8f4a8 to your computer and use it in GitHub Desktop.
Save ShahOdin/61e9761d1fc44e2401c21b1e5ab8f4a8 to your computer and use it in GitHub Desktop.
k-d tree implementation
object Demo {

  /** Returns the element at index `length / 2` of the sorted input (the upper median for even lengths).
    *
    * @param ints values to take the median of; must be non-empty
    * @return the median value
    * @throws IndexOutOfBoundsException if `ints` is empty
    */
  def getMedian(ints: List[Int]): Int = ints.sorted.apply(ints.length / 2)

  /** A point in integer space, stored as a plain coordinate list (no shapeless-style sized types).
    *
    * All points in one tree must have the same number of coordinates; `Node.trainData`
    * asserts that this number equals `Node.K` (two in the 2D demo).
    */
  case class Point(coordinates: List[Int])

  object Point {

    /** Convenience constructor for a two-dimensional point. */
    def in2D(x: Int, y: Int): Point = Point(List(x, y))

    /** Returns the point in `points` closest to `point` by Euclidean distance.
      *
      * This is an exhaustive linear scan. Ties resolve to the earliest candidate in `points`.
      *
      * @throws UnsupportedOperationException if `points` is empty
      */
    def estimateNearestPoint(points: List[Point], point: Point): Point =
      // Squared distance orders candidates exactly like Euclidean distance, without a sqrt.
      points.minBy(candidate => squaredDistance(candidate, point))

    /** Squared Euclidean distance between two points with the same number of coordinates.
      *
      * The arithmetic is widened to Long because squaring an Int difference overflows Int
      * once the difference exceeds about 46341.
      */
    private def squaredDistance(a: Point, b: Point): Long =
      a.coordinates
        .zip(b.coordinates)
        .map { case (x, y) =>
          val delta = x.toLong - y.toLong
          delta * delta
        }
        .sum
  }

  import Point._

  /** A node of the k-d tree: either a [[Node.Branch]] or a [[Node.Leaf]].
    *
    * The members below duplicate what a pattern match gives for free. They exist so the
    * traversal (`navigateNodeAndFindNearestPoint2`) can be ported to languages without
    * pattern matching.
    */
  sealed trait Node {

    /** `true` for a [[Node.Branch]], `false` for a [[Node.Leaf]]. */
    def isBranch: Boolean

    /** The points stored in a leaf; always empty for a branch. */
    def points: List[Point]

    /** The child a query for `point` should descend into; a leaf returns itself. */
    def whichNodeToLookNext(point: Point): Node
  }

  object Node {

    /** Descends greedily from `node` to a single leaf and returns the nearest point in that leaf.
      *
      * The traversal never backtracks into sibling subtrees, so the result is an estimate.
      * The true nearest neighbour may sit in a leaf that was never visited.
      */
    @scala.annotation.tailrec
    def navigateNodeAndFindNearestPoint(node: Node, point: Point): Point = node match {
      case l: Leaf   => estimateNearestPoint(l.points, point)
      case b: Branch => navigateNodeAndFindNearestPoint(b.whichNodeToLookNext(point), point)
    }

    /** Same greedy traversal as `navigateNodeAndFindNearestPoint`, but driven by `isBranch`
      * instead of a pattern match.
      */
    @scala.annotation.tailrec
    def navigateNodeAndFindNearestPoint2(node: Node, point: Point): Point =
      if (node.isBranch) navigateNodeAndFindNearestPoint2(node.whichNodeToLookNext(point), point)
      else estimateNearestPoint(node.points, point)

    /** Inner node: queries for which `decide` is true go to `right`, all others to `left`. */
    case class Branch(right: Node, left: Node, decide: Point => Boolean) extends Node {
      def isBranch: Boolean = true
      def whichNodeToLookNext(point: Point): Node = if (decide(point)) right else left
      override def points: List[Point] = List()
    }

    /** Terminal node holding the training points of one cell. */
    case class Leaf(points: List[Point]) extends Node {
      def isBranch: Boolean = false
      // A leaf has no children; returning itself keeps the trait's method total.
      def whichNodeToLookNext(point: Point): Node = this
    }

    /** Number of coordinates every point in the tree must have. */
    val K = 2

    /** A set of points is split only when it holds at least 2^K points; smaller sets become leaves. */
    def canSplitFurther(points: List[Point]): Boolean =
      points.length >= (1 << K)

    /** Builds a k-d tree over `points`.
      *
      * Each level splits on the median of one coordinate, cycling through the axes
      * 0, 1, ..., K-1, 0, ... Points whose coordinate is >= the median go right; the rest go left.
      * A set becomes a leaf when it holds fewer than 2^K points, or when no axis divides it
      * (for example, many identical points). Every split is traced to stdout.
      *
      * @param points training points; each must have exactly `K` coordinates
      * @return the root of the tree
      */
    def trainData(points: List[Point]): Node = {
      assert(points.forall(_.coordinates.length == K))

      // `stalledAxes` counts consecutive axes whose median split left the left side empty.
      def splitBasedOnCoordinate(coordinateIndex: Int, points: List[Point], stalledAxes: Int): Node = {
        assert(coordinateIndex >= 0 && coordinateIndex < K)
        // Wrap around so trees deeper than K levels keep cycling through the axes.
        val nextIndex = (coordinateIndex + 1) % K
        if (!canSplitFurther(points) || stalledAxes >= K) Leaf(points)
        else {
          val median = getMedian(points.map(_.coordinates(coordinateIndex)))
          val (rightPoints, leftPoints) = points.partition(_.coordinates(coordinateIndex) >= median)
          if (leftPoints.isEmpty) {
            // Every point is >= the median on this axis. Branching here would produce an empty
            // left leaf (which a query could land in) and an unchanged right side, so try the
            // next axis instead.
            splitBasedOnCoordinate(nextIndex, points, stalledAxes + 1)
          } else {
            println(s"branching at index: $coordinateIndex")
            println(s"right has: ${rightPoints.length} elements: ${rightPoints}")
            println(s"left has: ${leftPoints.length} elements: ${leftPoints}")
            // Both sides are non-empty and strictly smaller than `points`, so the recursion terminates.
            Branch(
              right = splitBasedOnCoordinate(nextIndex, rightPoints, 0),
              left = splitBasedOnCoordinate(nextIndex, leftPoints, 0),
              decide = _.coordinates(coordinateIndex) >= median
            )
          }
        }
      }

      splitBasedOnCoordinate(0, points, 0)
    }
  }
}
object DemoApp extends App {
  import Demo.{Node, Point}

  // Sample 2D training set as (x, y) pairs, in the order they are handed to the tree builder.
  private val samplePoints: List[Point] = List(
    (1, 9), (2, 3), (4, 1), (3, 7), (5, 4),
    (6, 8), (7, 2), (8, 8), (7, 9), (9, 6)
  ).map { case (x, y) => Point.in2D(x, y) }

  val trainedData: Node = Node.trainData(samplePoints)
  println(trainedData)

  // Query with a point taken from the training set and report whether each traversal returns it.
  val point = Point.in2D(2, 3)
  val canTheTrainedDataFindOriginalPoints: Boolean =
    Node.navigateNodeAndFindNearestPoint(trainedData, point) == point
  val canTheTrainedDataFindOriginalPoints2: Boolean =
    Node.navigateNodeAndFindNearestPoint2(trainedData, point) == point

  Seq(canTheTrainedDataFindOriginalPoints, canTheTrainedDataFindOriginalPoints2).foreach(println)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment