Skip to content

Instantly share code, notes, and snippets.

@film42
Last active August 29, 2015 14:16
Show Gist options
  • Save film42/2578ee53f8bc071bd144 to your computer and use it in GitHub Desktop.
Save film42/2578ee53f8bc071bd144 to your computer and use it in GitHub Desktop.
package id3
/**
* Created by: film42 on: 3/6/15.
*/
object DecisionTreeUtils {
  /** A single row or column of values. Note: inside this object it shadows `scala.Vector`. */
  type Vector[A] = List[A]

  /** A row-major table: each inner list is one row (one sample). */
  type Matrix[A] = List[Vector[A]]

  /**
   * The 14-row "play ball" weather dataset.
   *
   * @return the five column names and the rows; the last column ("PlayBall") holds each row's label
   */
  def playBallDataset: (Vector[String], Matrix[String]) = {
    val names = List("Outlook", "Temperature", "Humidity", "Wind", "PlayBall")
    val matrix = List(
      List("Sunny", "Hot", "High", "Weak", "No"),
      List("Sunny", "Hot", "High", "Strong", "No"),
      List("Overcast", "Hot", "High", "Weak", "Yes"),
      List("Rain", "Mild", "High", "Weak", "Yes"),
      List("Rain", "Cool", "Normal", "Weak", "Yes"),
      List("Rain", "Cool", "Normal", "Strong", "No"),
      List("Overcast", "Cool", "Normal", "Strong", "Yes"),
      List("Sunny", "Mild", "High", "Weak", "No"),
      List("Sunny", "Cool", "Normal", "Weak", "Yes"),
      List("Rain", "Mild", "Normal", "Weak", "Yes"),
      List("Sunny", "Mild", "Normal", "Strong", "Yes"),
      List("Overcast", "Mild", "High", "Strong", "Yes"),
      List("Overcast", "Hot", "Normal", "Weak", "Yes"),
      List("Rain", "Mild", "High", "Strong", "No")
    )
    (names, matrix)
  }

  /** Returns `list` without the element at index `n`; an index outside the list leaves it unchanged. */
  def removeAt[A](list: List[A], n: Int): List[A] =
    (list take n) ++ (list drop (n + 1))

  /** Base-2 logarithm, the unit used for every entropy in this object. */
  def log2(x: Double): Double =
    math.log(x) / math.log(2)

  /**
   * Number of distinct values in column `n` of `matrix`.
   *
   * @throws IndexOutOfBoundsException if some row has no element at index `n`
   */
  def unique[A](matrix: Matrix[A], n: Int): Int =
    matrix.map(row => row(n)).toSet.size

  /** Number of distinct values in `column`. */
  def unique[A](column: Vector[A]): Int =
    column.toSet.size

  /** Combines maps left to right; when a key repeats, the later map's value wins. */
  def merge[A, B](coll: Iterable[Map[A, B]]): Map[A, B] =
    coll.foldLeft(Map[A, B]())(_ ++ _)

  /**
   * Most frequent value in `row`. Ties go to the value that occurs first in `row`.
   *
   * @throws UnsupportedOperationException if `row` is empty
   */
  def argmax[A](row: Vector[A]): A = {
    val counts = occurrenceMap(row)
    // maxBy keeps the first element reaching the maximum, so ties resolve by position in `row`.
    row.maxBy(counts)
  }

  /** Maps each distinct value in `vector` to the number of times it occurs. */
  def occurrenceMap[A](vector: Vector[A]): Map[A, Int] =
    vector.groupBy(identity).map { case (item, occurrences) => item -> occurrences.size }

  /** Maps each distinct value in `vector` to its relative frequency (count / size of `vector`). */
  def probabilityOfOccurrence[A](vector: Vector[A]): Map[A, Double] = {
    val total = vector.size.toDouble
    occurrenceMap(vector).map { case (item, count) => item -> (count / total) }
  }

  /**
   * Shannon entropy, in bits, of the label distribution.
   *
   * @param features unused; kept so the signature parallels [[infoSA]] and [[infoGain]]
   * @param labels   one label per row
   */
  def infoS[A, B](features: Matrix[A], labels: Vector[B]): Double = {
    val terms = probabilityOfOccurrence(labels).values.map(p => p * log2(p))
    -terms.sum
  }

  /**
   * For every value of feature column `index`, the label distribution among the rows holding
   * that value. Only the probabilities are kept, not which label each belongs to.
   *
   * Ex: Map(Rain -> List(0.6, 0.4), Sunny -> List(0.6, 0.4), Overcast -> List(1.0))
   */
  def infoProbabilities[A, B](features: Matrix[A], labels: Vector[B], index: Int): Map[A, Iterable[Double]] = {
    val pairs = features.map(row => row(index)).zip(labels)
    pairs.groupBy(_._1).map { case (value, matching) =>
      val subsetSize = matching.size.toDouble
      value -> occurrenceMap(matching.map(_._2)).values.map(count => count / subsetSize)
    }
  }

  /**
   * Expected label entropy, in bits, after splitting on feature column `index`. This is the entropy
   * of each value's label subset, weighted by that value's relative frequency.
   */
  def infoSA[A, B](features: Matrix[A], labels: Vector[B], index: Int): Double = {
    val valuePriors = probabilityOfOccurrence(features.map(row => row(index)))
    val labelDistributions = infoProbabilities(features, labels, index)
    // Mapping a Map to non-pair values yields an Iterable, so equal terms are not collapsed before summing.
    val weightedEntropies = valuePriors.map { case (value, prior) =>
      prior * labelDistributions(value).map(p => -p * log2(p)).sum
    }
    weightedEntropies.sum
  }

  /** Information gain, in bits, of splitting on feature column `index`. */
  def infoGain[A, B](matrix: Matrix[A], labels: Vector[B], index: Int): Double =
    infoS(matrix, labels) - infoSA(matrix, labels, index)

  /**
   * Builds an ID3 decision tree.
   *
   * Example, for the features and labels of [[playBallDataset]]:
   * {{{
   * Tree(Outlook,
   *   Map(Sunny ->
   *     Tree(Humidity,
   *       Map(High -> Node(No),
   *           Normal -> Node(Yes))),
   *       Overcast ->
   *         Node(Yes),
   *       Rain ->
   *         Tree(Wind,
   *           Map(Weak -> Node(Yes),
   *               Strong -> Node(No)))))
   * }}}
   *
   * @param matrix       feature rows; each row holds one value per entry of `featureNames`
   * @param labels       one label per row; must be non-empty
   * @param featureNames names of the columns of `matrix`, used to label the splits
   * @param probability  stored on the returned element; recursive calls pass the fraction of the
   *                     parent's rows that reached the child
   * @throws IllegalArgumentException if `labels` is empty, if `matrix` is non-empty and its row count
   *                                  differs from the number of labels, or if `matrix` is ragged
   */
  def generateTree[A, B](matrix: Matrix[A], labels: Vector[B], featureNames: Vector[String], probability: Double = 1.0): TreeElement = {
    require(labels.nonEmpty, "generateTree requires at least one label")
    require(matrix.isEmpty || matrix.size == labels.size,
      s"matrix has ${matrix.size} rows but ${labels.size} labels were given")

    // Transposing once rejects ragged input. The result is empty when there are no rows or no columns left.
    val columns = matrix.transpose
    if (columns.isEmpty) {
      // Nothing left to split on: predict the majority label.
      Node(argmax(labels), probability)
    } else if (unique(labels) == 1) {
      // Pure subset: every row carries the same label.
      Node(labels.head, probability)
    } else {
      // Highest information gain wins; ties go to the lowest column index.
      val best = columns.indices.maxBy(i => infoGain(matrix, labels, i))
      val remainingNames = removeAt(featureNames, best)
      val rows = matrix.zip(labels)
      val children = columns(best).distinct.map { attribute =>
        // Keep the rows with this value and drop the column that was just used.
        val (subMatrix, subLabels) = rows.collect {
          case (row, label) if row(best) == attribute => (removeAt(row, best), label)
        }.unzip
        attribute -> generateTree(subMatrix, subLabels, remainingNames, subMatrix.size / matrix.size.toDouble)
      }.toMap
      Tree(featureNames(best), probability, children)
    }
  }

  /** An element of a tree built by `generateTree`. */
  trait TreeElement {
    /** Fraction of the parent's training rows that reached this element; 1.0 unless overridden. */
    def probability: Double = 1.0
  }

  /** A split on `feature`; `children` maps each attribute value seen in training to its subtree. */
  case class Tree[A, B](feature: A, override val probability: Double, children: Map[B, TreeElement]) extends TreeElement

  /** A leaf that predicts `label`. */
  case class Node[A](label: A, override val probability: Double) extends TreeElement

  /**
   * Java-friendly wrapper that builds a tree from numeric data on construction and prints it to stdout.
   *
   * @param matrix       feature rows, one `Array[Double]` per sample
   * @param labels       one array per sample; only its first element is used as the label
   * @param featureNames one name per feature column
   */
  class DecisionTree(matrix: java.util.ArrayList[Array[Double]],
                     labels: java.util.ArrayList[Array[Double]],
                     featureNames: java.util.ArrayList[String]) {
    import scala.collection.JavaConverters._

    val matrixPrime: List[List[Double]] = matrix.asScala.toList.map(_.toList)
    val labelsPrime: List[Double] = labels.asScala.toList.map(_.head)
    val featureNamesPrime: List[String] = featureNames.asScala.toList
    val tree: TreeElement = generateTree(matrixPrime, labelsPrime, featureNamesPrime)
    println(tree)

    /**
     * Predicts the label of `featureRow`, whose columns follow the order of `featureNames`.
     *
     * If a split meets a value not seen in training, this prints a "Not Found" line. It then
     * follows the child that received the largest share of training rows.
     */
    def classify(featureRow: Array[Double]): Double = {
      @annotation.tailrec
      def leafLabel(element: TreeElement): Any = element match {
        case Tree(featureName, _, children) =>
          // Splits store feature names, so look the column up in the full, original name list.
          val featureIndex = featureNamesPrime.indexOf(featureName)
          val attribute = featureRow(featureIndex)
          val next = children.collectFirst { case (value, child) if value == attribute => child }.getOrElse {
            println(s"Not Found: $featureIndex | $featureName")
            children.maxBy(_._2.probability)._2
          }
          leafLabel(next)
        case Node(label, _) => label
      }
      leafLabel(tree).asInstanceOf[Double]
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment