Skip to content

Instantly share code, notes, and snippets.

@schaunwheeler
Last active April 23, 2019 13:43
Show Gist options
  • Save schaunwheeler/2f918b0a6b4d50e3721269505fb47105 to your computer and use it in GitHub Desktop.
Save schaunwheeler/2f918b0a6b4d50e3721269505fb47105 to your computer and use it in GitHub Desktop.
An example of using Scala to call the predict function from a Scikit-Learn RandomForestRegressor
import rapture.json.jsonBackends.jawn._
import rapture.json.Json
import scala.annotation.tailrec
/**
 * A single decision tree stored as parallel arrays indexed by node id.
 *
 * Node 0 is the root. A node whose entry in `features` equals `undefinedIndex`
 * is treated as a leaf; every other node is a split on the input feature with
 * that index.
 *
 * @param treeId         identifier of this tree (not used by the traversal itself)
 * @param undefinedIndex sentinel feature index that marks a leaf node
 * @param features       per-node index into the input vector, or `undefinedIndex` for leaves
 * @param thresholds     per-node split threshold (only read for non-leaf nodes)
 * @param childrenLeft   per-node id of the child taken when input <= threshold
 * @param childrenRight  per-node id of the child taken otherwise
 * @param values         per-node output value; only the reached leaf's entry is returned
 */
case class RandomForestTree(
  treeId: Int,
  undefinedIndex: Int,
  features: Array[Int],
  thresholds: Array[Double],
  childrenLeft: Array[Int],
  childrenRight: Array[Int],
  values: Array[Double]
) extends Serializable {

  /**
   * Routes `inputs` from the root down to a leaf and returns that leaf's value.
   *
   * @param inputs feature vector, indexed by the feature indices stored in `features`
   * @return the entry of `values` belonging to the leaf that was reached
   */
  def walkTree(inputs: Array[Double]): Double = {
    // Follows split decisions until a leaf is hit and yields the leaf's node id.
    @tailrec
    def descend(current: Int): Int = {
      val splitFeature = features(current)
      if (splitFeature == undefinedIndex) current
      else {
        val cutoff = thresholds(current)
        // Values equal to the cutoff go left; a NaN input fails the comparison and goes right.
        val goLeft = inputs(splitFeature) <= cutoff
        descend(if (goLeft) childrenLeft(current) else childrenRight(current))
      }
    }

    values(descend(0))
  }
}
/**
 * An ensemble of [[RandomForestTree]]s whose prediction is the mean of the
 * individual tree outputs.
 *
 * @param name  label for this model (not used by `transform`)
 * @param trees the trees that make up the forest
 */
case class RandomForestRegressor(name: String, trees: Array[RandomForestTree]) {

  /**
   * Scores one feature vector with every tree and averages the results.
   *
   * @param input feature vector handed unchanged to each tree's `walkTree`
   * @return arithmetic mean of the per-tree outputs; NaN when `trees` is empty
   *         (the sum is 0.0 and it is divided by a length of 0)
   */
  def transform(input: Array[Double]): Double = {
    val predictions: Array[Double] =
      for (tree <- trees) yield tree.walkTree(input)
    predictions.sum / trees.length
  }
}
object RandomForestTree {

  /**
   * Builds a [[RandomForestTree]] from a JSON object.
   *
   * The object is read through the keys `i`, `tree_undefined`, `features`,
   * `thresholds`, `children_left`, `children_right` and `values`, in that order.
   * NOTE(review): behaviour on a missing or mistyped key is whatever rapture's
   * `as` does (presumably an exception) — confirm against the rapture docs.
   *
   * @param json JSON object describing a single tree
   * @return the tree populated from the extracted fields
   */
  def fromJSON(json: Json): RandomForestTree = {
    val id = json.i.as[Int]
    val leafMarker = json.tree_undefined.as[Int]
    val splitFeatures = json.features.as[Array[Int]]
    val splitThresholds = json.thresholds.as[Array[Double]]
    val leftChildren = json.children_left.as[Array[Int]]
    val rightChildren = json.children_right.as[Array[Int]]
    val nodeValues = json.values.as[Array[Double]]

    RandomForestTree(
      treeId = id,
      undefinedIndex = leafMarker,
      features = splitFeatures,
      thresholds = splitThresholds,
      childrenLeft = leftChildren,
      childrenRight = rightChildren,
      values = nodeValues
    )
  }
}
object RandomForestRegressor {

  /**
   * Builds a [[RandomForestRegressor]] from a JSON object.
   *
   * Reads the model name from the `name` key and the array under the `trees`
   * key, converting each element with `RandomForestTree.fromJSON`.
   * NOTE(review): behaviour on a missing or mistyped key is whatever rapture's
   * `as` does (presumably an exception) — confirm against the rapture docs.
   *
   * @param json JSON object describing the whole forest
   * @return the regressor holding every parsed tree
   */
  def fromJSON(json: Json): RandomForestRegressor = {
    import rapture.json.jsonBackends.jawn._

    val modelName = json.name.as[String]
    val treeNodes = json.trees.as[Array[Json]]
    val parsedTrees = treeNodes.map(node => RandomForestTree.fromJSON(node))

    RandomForestRegressor(name = modelName, trees = parsedTrees)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment