Last active
April 23, 2019 13:43
-
-
Save schaunwheeler/2f918b0a6b4d50e3721269505fb47105 to your computer and use it in GitHub Desktop.
An example of using Scala to reproduce the predict function of a scikit-learn RandomForestRegressor from its trees, serialized as JSON.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import rapture.json.jsonBackends.jawn._ | |
import rapture.json.Json | |
import scala.annotation.tailrec | |
/** A single regression tree stored as parallel per-node arrays.
  *
  * For node index `k`, `features(k)` is the input feature tested at that node,
  * `thresholds(k)` is the split value, `childrenLeft(k)` / `childrenRight(k)` are the
  * child node indices, and `values(k)` is the prediction returned if the walk stops there.
  * A node whose feature equals `undefinedIndex` is a leaf; node 0 is the root.
  *
  * NOTE(review): the layout appears to mirror scikit-learn's `tree_` arrays (the JSON keys
  * read by `RandomForestTree.fromJSON` match them) — confirm against the exporter.
  *
  * @param treeId         identifier of this tree; not used when walking it
  * @param undefinedIndex sentinel stored in `features` to mark a leaf node
  */
case class RandomForestTree(
    treeId: Int,
    undefinedIndex: Int,
    features: Array[Int],
    thresholds: Array[Double],
    childrenLeft: Array[Int],
    childrenRight: Array[Int],
    values: Array[Double]
) extends Serializable {

  /** Routes `inputs` from the root to a leaf and returns that leaf's value.
    *
    * At each internal node the walk goes left when the tested feature is `<=` the
    * node's threshold and right otherwise, so a NaN feature value always goes right.
    *
    * @param inputs feature vector, indexed by the feature numbers stored in `features`
    * @return the entry of `values` at the leaf that was reached
    */
  def walkTree(inputs: Array[Double]): Double = {
    @tailrec
    def leafFrom(current: Int): Int = {
      val splitFeature = features(current)
      if (splitFeature != undefinedIndex) {
        val cutoff = thresholds(current)
        val next =
          if (inputs(splitFeature) <= cutoff) childrenLeft(current)
          else childrenRight(current)
        leafFrom(next)
      } else {
        current
      }
    }

    // Node 0 is the root of the tree.
    values(leafFrom(0))
  }
}
/** Random-forest regressor that predicts the mean of its member trees' outputs.
  *
  * @param name  model identifier (read from the JSON `name` field by `RandomForestRegressor.fromJSON`)
  * @param trees member trees; must be non-empty for `transform` to succeed
  */
case class RandomForestRegressor(name: String, trees: Array[RandomForestTree]) {

  /** Predicts a value for one feature vector as the arithmetic mean of every tree's output.
    *
    * @param input feature values indexed by feature number, as consumed by `RandomForestTree.walkTree`
    * @return the mean of the per-tree predictions
    * @throws IllegalArgumentException if the forest has no trees
    */
  def transform(input: Array[Double]): Double = {
    // An empty forest used to yield 0.0 / 0 == NaN silently; surface the malformed model instead.
    require(trees.nonEmpty, s"RandomForestRegressor '$name' has no trees to average")
    // Accumulate directly instead of materialising an intermediate array of per-tree results.
    val total = trees.foldLeft(0.0)((acc, tree) => acc + tree.walkTree(input))
    total / trees.length
  }
}
/** Companion for `RandomForestTree` providing JSON deserialization. */
object RandomForestTree {

  /** Builds a tree from one JSON object.
    *
    * Reads the keys `i` (tree id), `tree_undefined` (leaf sentinel), `features`,
    * `thresholds`, `children_left`, `children_right` and `values`, in that order,
    * using rapture's `as[...]` extraction.
    *
    * @param json a single serialized tree
    * @return the decoded tree
    */
  def fromJSON(json: Json): RandomForestTree = {
    val id          = json.i.as[Int]
    val leafMarker  = json.tree_undefined.as[Int]
    val featureIdx  = json.features.as[Array[Int]]
    val splitValues = json.thresholds.as[Array[Double]]
    val leftKids    = json.children_left.as[Array[Int]]
    val rightKids   = json.children_right.as[Array[Int]]
    val nodeValues  = json.values.as[Array[Double]]

    RandomForestTree(
      treeId = id,
      undefinedIndex = leafMarker,
      features = featureIdx,
      thresholds = splitValues,
      childrenLeft = leftKids,
      childrenRight = rightKids,
      values = nodeValues
    )
  }
}
/** Companion for `RandomForestRegressor` providing JSON deserialization. */
object RandomForestRegressor {

  /** Builds a regressor from a JSON object holding a `name` string and a `trees` array.
    *
    * Each element of `trees` is decoded with `RandomForestTree.fromJSON`. The jawn backend
    * needed by rapture comes from the file-level import, so no local import is repeated here.
    *
    * @param json the serialized forest
    * @return the decoded regressor
    */
  def fromJSON(json: Json): RandomForestRegressor =
    RandomForestRegressor(
      name = json.name.as[String],
      trees = json.trees.as[Array[Json]].map(RandomForestTree.fromJSON)
    )
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment