Last active
January 17, 2020 01:05
-
-
Save ShahOdin/61e9761d1fc44e2401c21b1e5ab8f4a8 to your computer and use it in GitHub Desktop.
kdtree implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
object Demo {
  import scala.annotation.tailrec

  /**
   * Returns the upper median of `ints`: the element found at index `length / 2`
   * once the list is sorted. For an even-sized list this is the larger of the two
   * middle elements.
   *
   * @param ints values to take the median of; must be non-empty
   * @return the element at the middle index of the sorted list
   * @throws IndexOutOfBoundsException if `ints` is empty
   */
  def getMedian(ints: List[Int]): Int = ints.sorted.apply(ints.length / 2)

  // No shapeless here: the dimension is not tracked in the type. All coordinate lists are
  // expected to have the same length (two elements when working in two dimensions).
  case class Point(coordinates: List[Int])

  object Point {

    /** Builds a two-dimensional point from its `x` and `y` coordinates. */
    def in2D(x: Int, y: Int): Point = Point(List(x, y))

    /**
     * Squared Euclidean distance between `a` and `b`.
     *
     * The square root is skipped because it does not change which candidate is closest.
     * Accumulation is done in `Long` so that large `Int` coordinates do not overflow.
     * Coordinates are paired with `zip`, so any surplus coordinates of the longer point
     * are ignored.
     */
    private def squaredDistance(a: Point, b: Point): Long =
      a.coordinates.zip(b.coordinates).foldLeft(0L) { case (total, (from, to)) =>
        val delta = to.toLong - from.toLong
        total + delta * delta
      }

    /**
     * Picks the element of `points` that is closest to `point`.
     *
     * Closeness is measured with the squared Euclidean distance. (A signed sum of the
     * coordinate differences is not a distance: a far-away point can produce a large
     * negative sum and win over an exact match.) When several candidates are equally
     * close, the first one in `points` is returned.
     *
     * @param points candidate points; must be non-empty
     * @param point  the query point
     * @return the candidate with the smallest distance to `point`
     * @throws UnsupportedOperationException if `points` is empty
     */
    def estimateNearestPoint(points: List[Point], point: Point): Point =
      points.minBy(candidate => squaredDistance(candidate, point))
  }

  import Point._

  /**
   * A node of the tree: either a [[Node.Branch]] or a [[Node.Leaf]].
   *
   * The three members below exist so that the traversal can also be written without
   * pattern matching (see `Node.navigateNodeAndFindNearestPoint2`), which eases porting
   * to languages that lack it.
   */
  sealed trait Node {

    /** `true` for a branch, `false` for a leaf. */
    def isBranch: Boolean

    /** The points stored in this node; only leaves hold any. */
    def points: List[Point]

    /** The node a search for `point` should visit next. */
    def whichNodeToLookNext(point: Point): Node
  }

  object Node {

    /**
     * Walks down from `node`, at every branch following the side chosen by the branch's
     * `decide` function, and returns the closest point stored in the leaf that is reached.
     *
     * Only a single leaf is inspected and the search never backtracks, so the result is an
     * estimate: a closer point stored under a sibling subtree is not considered.
     *
     * @throws UnsupportedOperationException if the leaf that is reached holds no points
     */
    @tailrec
    def navigateNodeAndFindNearestPoint(node: Node, point: Point): Point = node match {
      case l: Leaf   => estimateNearestPoint(l.points, point)
      case b: Branch => navigateNodeAndFindNearestPoint(b.whichNodeToLookNext(point), point)
    }

    /**
     * Same traversal as [[navigateNodeAndFindNearestPoint]], expressed through the
     * `isBranch` / `whichNodeToLookNext` / `points` members instead of pattern matching.
     */
    @tailrec
    def navigateNodeAndFindNearestPoint2(node: Node, point: Point): Point =
      if (node.isBranch) {
        navigateNodeAndFindNearestPoint2(node.whichNodeToLookNext(point), point)
      } else {
        estimateNearestPoint(node.points, point)
      }

    /**
     * Inner node. `decide(point)` returning `true` sends the search to `right`,
     * `false` sends it to `left`.
     */
    case class Branch(right: Node, left: Node, decide: Point => Boolean) extends Node {
      def isBranch: Boolean = true
      def whichNodeToLookNext(point: Point): Node = if (decide(point)) right else left
      // A branch stores no points of its own.
      override def points: List[Point] = List()
    }

    /** Terminal node holding the points that were not split any further. */
    case class Leaf(points: List[Point]) extends Node {
      def isBranch: Boolean = false // fake impl
      def whichNodeToLookNext(point: Point): Node = this // fake impl
    }

    /** Number of coordinates every point passed to `trainData` must have. */
    val K = 2

    /** A set of points is split only while it holds at least `2^K` elements. */
    def canSplitFurther(points: List[Point]): Boolean =
      points.length >= (1 << K)

    /**
     * Builds a tree over `points` by repeatedly splitting on the median of one coordinate,
     * cycling through the `K` coordinates from one level to the next.
     *
     * Points whose coordinate is `>=` the median go to the right subtree, the others to the
     * left. Splitting stops once a subset is too small (see [[canSplitFurther]]) or once no
     * coordinate is able to separate it (for example when it consists of duplicates); such
     * a subset becomes a [[Leaf]]. Consequently no branch ever has an empty side.
     *
     * A trace of every split is printed to standard output.
     *
     * @param points the points to index; each must have exactly `K` coordinates
     * @return the root of the resulting tree
     */
    def trainData(points: List[Point]): Node = {
      assert(points.forall(_.coordinates.length == K))

      def splitBasedOnCoordinate(coordinateIndex: Int, points: List[Point]): Node = {
        assert(coordinateIndex >= 0 && coordinateIndex < K)
        if (!canSplitFurther(points)) {
          Leaf(points)
        } else {
          // Try the requested axis first and then the remaining ones. An axis is usable only
          // if it leaves points on both sides of the median; the right side always contains
          // the median element itself, so only the left side has to be checked.
          val usableSplit = (0 until K).iterator
            .map { offset =>
              val axis = (coordinateIndex + offset) % K
              val median = getMedian(points.map(_.coordinates(axis)))
              val (rightPoints, leftPoints) = points.partition(_.coordinates(axis) >= median)
              (axis, median, rightPoints, leftPoints)
            }
            .find { case (_, _, _, leftPoints) => leftPoints.nonEmpty }

          usableSplit match {
            case None =>
              // No axis separates these points; recursing further would never terminate.
              Leaf(points)

            case Some((axis, median, rightPoints, leftPoints)) =>
              println(s"branching at index: $axis")
              println(s"right has: ${rightPoints.length} elements: ${rightPoints}")
              println(s"left has: ${leftPoints.length} elements: ${leftPoints}")
              // Wrap around so that deep trees keep cycling through the K axes.
              val nextAxis = (axis + 1) % K
              Branch(
                right = splitBasedOnCoordinate(nextAxis, rightPoints),
                left = splitBasedOnCoordinate(nextAxis, leftPoints),
                decide = _.coordinates(axis) >= median
              )
          }
        }
      }

      splitBasedOnCoordinate(0, points)
    }
  }
}
/**
 * Small driver: trains a tree on ten 2D points, prints it, and then checks that both
 * traversal variants return one of the training points when queried with that very point.
 */
object DemoApp extends App {
  import Demo.{Node, Point}

  // Tree built from the sample data set; `trainData` prints its own split trace while building.
  val trainedData: Node = Node.trainData(
    List(
      (1, 9),
      (2, 3),
      (4, 1),
      (3, 7),
      (5, 4),
      (6, 8),
      (7, 2),
      (8, 8),
      (7, 9),
      (9, 6)
    ).map { case (x, y) => Point.in2D(x, y) }
  )
  println(trainedData)

  // Query with a point that is part of the training set.
  val point = Point.in2D(2, 3)

  // Pattern-matching traversal.
  val canTheTrainedDataFindOriginalPoints: Boolean =
    Node.navigateNodeAndFindNearestPoint(trainedData, point) == point

  // Traversal written without pattern matching.
  val canTheTrainedDataFindOriginalPoints2: Boolean =
    Node.navigateNodeAndFindNearestPoint2(trainedData, point) == point

  println(canTheTrainedDataFindOriginalPoints)
  println(canTheTrainedDataFindOriginalPoints2)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment