Last active
August 29, 2015 14:16
-
-
Save film42/2578ee53f8bc071bd144 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package id3 | |
/** | |
* Created by: film42 on: 3/6/15. | |
*/ | |
object DecisionTreeUtils { | |
type Vector[A] = List[A] | |
type Matrix[A] = List[Vector[A]] | |
/**
 * The small "play ball" weather dataset used to exercise the ID3 functions.
 *
 * @return a pair of (column names, data rows). Every row lists its values in
 *         the same order as the column names, so the last value of each row
 *         is the `PlayBall` outcome.
 */
def playBallDataset: (Vector[String], Matrix[String]) = {
  // Keeps the table below free of repeated `List(...)` wrappers.
  def row(values: String*): List[String] = values.toList

  val header = row("Outlook", "Temperature", "Humidity", "Wind", "PlayBall")
  val rows = List(
    row("Sunny", "Hot", "High", "Weak", "No"),
    row("Sunny", "Hot", "High", "Strong", "No"),
    row("Overcast", "Hot", "High", "Weak", "Yes"),
    row("Rain", "Mild", "High", "Weak", "Yes"),
    row("Rain", "Cool", "Normal", "Weak", "Yes"),
    row("Rain", "Cool", "Normal", "Strong", "No"),
    row("Overcast", "Cool", "Normal", "Strong", "Yes"),
    row("Sunny", "Mild", "High", "Weak", "No"),
    row("Sunny", "Cool", "Normal", "Weak", "Yes"),
    row("Rain", "Mild", "Normal", "Weak", "Yes"),
    row("Sunny", "Mild", "Normal", "Strong", "Yes"),
    row("Overcast", "Mild", "High", "Strong", "Yes"),
    row("Overcast", "Hot", "Normal", "Weak", "Yes"),
    row("Rain", "Mild", "High", "Strong", "No")
  )
  (header, rows)
}
/**
 * Returns a copy of `list` without the element at position `n`.
 *
 * When `n` is negative or past the end of the list no element has that
 * position, so the list comes back unchanged.
 *
 * @param list the source list
 * @param n    zero-based position of the element to drop
 */
def removeAt[A](list: List[A], n: Int): List[A] =
  list.zipWithIndex.collect {
    case (element, position) if position != n => element
  }
/**
 * Base-2 logarithm, computed by change of base from the natural logarithm.
 *
 * @param x the value to take the logarithm of
 * @return `ln(x) / ln(2)`
 */
def log2(x: Double): Double = {
  val naturalLog = math.log(x)
  val lnOfTwo = math.log(2)
  naturalLog / lnOfTwo
}
/**
 * Number of distinct values found in column `n` of `matrix`.
 *
 * The previous implementation ignored `n` and always inspected the last
 * value of every row; the requested column is now honoured.
 *
 * @param matrix rows of values; every row must have an element at index `n`
 * @param n      zero-based index of the column to inspect
 */
def unique[A](matrix: Matrix[A], n: Int): Int =
  matrix.map(row => row(n)).toSet.size

/**
 * Number of distinct values in a single column.
 *
 * @param column the values to inspect; an empty column yields 0
 */
def unique[A](column: Vector[A]): Int =
  column.toSet.size
/**
 * Collapses several maps into one.
 *
 * Maps are applied in iteration order, so when the same key occurs more
 * than once the value from the later map wins. An empty input yields an
 * empty map.
 *
 * @param coll the maps to combine
 */
def merge[A, B](coll: Iterable[Map[A, B]]): Map[A, B] = {
  val entries = for {
    single <- coll.iterator
    entry  <- single.iterator
  } yield entry
  entries.toMap
}
/**
 * Returns the most frequent value of `row`.
 *
 * Throws (from `maxBy`) when `row` is empty, because there is no value to
 * return.
 *
 * @param row the values to tally
 */
def argmax[A](row: Vector[A]): A = {
  val tally = row.groupBy(identity).map { case (value, hits) => (value, hits.size) }
  val (winner, _) = tally.maxBy { case (_, count) => count }
  winner
}
/**
 * Counts how many times each distinct value occurs in `vector`.
 *
 * The counts are accumulated in a single pass over the input instead of
 * re-scanning the whole vector once per distinct value.
 *
 * @param vector the values to count
 * @return a map from each distinct value to its number of occurrences;
 *         empty when `vector` is empty
 */
def occurrenceMap[A](vector: Vector[A]): Map[A, Int] =
  vector.foldLeft(Map.empty[A, Int]) { (counts, item) =>
    counts.updated(item, counts.getOrElse(item, 0) + 1)
  }
/**
 * Relative frequency of each distinct value in `vector`.
 *
 * Occurrences are counted once via `occurrenceMap` and then divided by the
 * vector size, rather than re-scanning the vector for every distinct value.
 *
 * @param vector the values to inspect
 * @return a map from each distinct value to `occurrences / vector.size`;
 *         empty when `vector` is empty
 */
def probabilityOfOccurrence[A](vector: Vector[A]): Map[A, Double] = {
  val total = vector.size.toDouble
  occurrenceMap(vector).map { case (item, count) => item -> count / total }
}
/**
 * Entropy of the label distribution: the negated sum of `p * log2(p)` over
 * the relative frequency `p` of every distinct label.
 *
 * @param features not used by this computation; accepted so the signature
 *                 lines up with `infoSA`
 * @param labels   the class label of every row
 */
def infoS[A, B](features: Matrix[A], labels: Vector[B]): Double = {
  val labelDistribution = probabilityOfOccurrence(labels)
  val weightedLogs = for (p <- labelDistribution.values) yield log2(p) * p
  -1 * weightedLogs.sum
}
/**
 * For every distinct value of the feature column at `index`, lists the
 * relative frequencies of the labels seen on the rows holding that value.
 *
 * Ex: Map(Rain -> List(0.6, 0.4), Sunny -> List(0.6, 0.4), Overcast -> List(1.0))
 *
 * @param features the data rows
 * @param labels   the class label of every row, in row order
 * @param index    zero-based index of the feature column to inspect
 */
def infoProbabilities[A, B](features: Matrix[A], labels: Vector[B], index: Int): Map[A, Iterable[Double]] = {
  val columns = features.transpose
  val labelledValues = columns(index).zip(labels)
  val rowsByValue = labelledValues.groupBy { case (value, _) => value }

  rowsByValue.map { case (value, pairs) =>
    val groupSize = pairs.size.toDouble
    val labelCounts = occurrenceMap(pairs.map { case (_, label) => label })
    // Mapping a Map to plain Doubles yields an Iterable, so equal frequencies are all kept.
    val distribution: Iterable[Double] = labelCounts.map { case (_, count) => count / groupSize }
    value -> distribution
  }
}
/**
 * Entropy that remains after splitting on the feature column at `index`:
 * the entropy of the labels under each feature value, weighted by how
 * often that value occurs in the column.
 *
 * @param features the data rows
 * @param labels   the class label of every row, in row order
 * @param index    zero-based index of the feature column to split on
 */
def infoSA[A, B](features: Matrix[A], labels: Vector[B], index: Int): Double = {
  // Entropy of one label distribution: the sum of -p * log2(p).
  def entropyOf(distribution: Iterable[Double]): Double =
    distribution.map(p => p * log2(p) * -1).sum

  val columns = features.transpose
  val valueShare = probabilityOfOccurrence(columns(index))
  val labelSplit = infoProbabilities(features, labels, index)

  //
  // Ex: List(0.3467680694480959, 0.0, 0.3467680694480959)
  //
  val weightedEntropies = valueShare.map { case (value, share) =>
    share * entropyOf(labelSplit(value))
  }
  weightedEntropies.sum
}
/**
 * Information gain of splitting on the feature column at `index`: the label
 * entropy (`infoS`) minus the entropy remaining after the split (`infoSA`).
 *
 * @param matrix the data rows
 * @param labels the class label of every row, in row order
 * @param index  zero-based index of the candidate feature column
 */
def infoGain[A, B](matrix: Matrix[A], labels: Vector[B], index: Int): Double = {
  val entropyBefore = infoS(matrix, labels)
  val entropyAfter = infoSA(matrix, labels, index)
  entropyBefore - entropyAfter
}
//
// Example:
// Tree(Outlook,
//   Map(Sunny ->
//         Tree(Humidity,
//           Map(High -> Node(No),
//               Normal -> Node(Yes))),
//       Overcast ->
//         Node(Yes),
//       Rain ->
//         Tree(Wind,
//           Map(Weak -> Node(Yes),
//               Strong -> Node(No)))))
//
/**
 * Recursively builds an ID3 decision tree.
 *
 * At every level the feature column with the highest `infoGain` is chosen,
 * the rows are partitioned by their value in that column, the column is
 * removed, and a subtree is built for each partition.
 *
 * @param matrix       the data rows; each row holds one value per feature
 * @param labels       the class label of every row, in row order
 * @param featureNames the name of every remaining feature column, in column order
 * @param probability  the share of the parent's rows that reached this
 *                     element (1.0 for the root)
 * @return a `Node` when no feature columns are left (most frequent label)
 *         or when all labels agree; otherwise a `Tree` keyed by the values
 *         of the chosen feature
 */
def generateTree[A, B](matrix: Matrix[A], labels: Vector[B], featureNames: Vector[String], probability: Double = 1.0): TreeElement = {
  val defaultLabel = argmax(labels)
  // Transposed once and reused; an empty matrix also transposes to an empty list.
  val columns = matrix.transpose

  // Check for an empty branch
  if (columns.isEmpty) {
    Node(defaultLabel, probability)
  }
  // Check for a pure node (1 class)
  else if (unique(labels) == 1) {
    Node(labels.head, probability)
  }
  // Choose which feature (column) is best
  else {
    // Ex: Map(0 -> 0.247, 1 -> 0.0292, 2 -> 0.152, 3 -> 0.0481)
    val gain = columns.indices.map(column => column -> infoGain(matrix, labels, column)).toMap
    val (bestFeatureIndex, _) = gain.maxBy { case (_, value) => value }

    val attributeSet: Set[A] = columns(bestFeatureIndex).toSet
    val featuresNamesPrime = removeAt(featureNames, bestFeatureIndex)
    // Pair every row with its label once, so partitioning is a single pass
    // per attribute instead of repeated positional lookups into the lists.
    val labelledRows = matrix.zip(labels)

    val children = attributeSet.foldLeft(Map.empty[A, TreeElement]) { (acc, attribute) =>
      val partition = labelledRows.collect {
        case (featureRow, label) if featureRow(bestFeatureIndex) == attribute =>
          (removeAt(featureRow, bestFeatureIndex), label)
      }
      val (matrixPrime, labelsPrime) = partition.unzip
      val newProbability = matrixPrime.size / matrix.size.toDouble
      acc + (attribute -> generateTree(matrixPrime, labelsPrime, featuresNamesPrime, newProbability))
    }

    Tree(featureNames(bestFeatureIndex), probability, children)
  }
}
/**
 * An element of a decision tree: either an inner [[Tree]] or a leaf [[Node]].
 *
 * Sealed so that the compiler can check pattern matches over tree elements
 * for exhaustiveness.
 */
sealed trait TreeElement {
  /** Share of the parent's rows that reached this element; defaults to 1.0. */
  def probability: Double = 1.0
}

/**
 * Inner element that splits on one feature.
 *
 * @param feature     identifies the feature this element splits on
 * @param probability share of the parent's rows that reached this element
 * @param children    the subtree to follow for each value of the feature
 */
final case class Tree[A, B](feature: A, override val probability: Double, children: Map[B, TreeElement]) extends TreeElement

/**
 * Leaf element carrying the predicted label.
 *
 * @param label       the label predicted for rows that reach this leaf
 * @param probability share of the parent's rows that reached this leaf
 */
final case class Node[A](label: A, override val probability: Double) extends TreeElement
/**
 * Java-friendly wrapper that trains an ID3 tree on construction and
 * classifies rows of doubles with it.
 *
 * The generated tree is printed to standard output when the instance is
 * created.
 *
 * @param matrix       the training rows, one array of feature values per row
 * @param labels       one array per training row; its first value is used as the label
 * @param featureNames the name of every feature column, in column order
 */
class DecisionTree(matrix: java.util.ArrayList[Array[Double]],
                   labels: java.util.ArrayList[Array[Double]],
                   featureNames: java.util.ArrayList[String]) {
  import scala.collection.JavaConverters._

  val matrixPrime: List[List[Double]] = matrix.asScala.toList.map(_.toList)
  // First column of the label rows.
  val labelsPrime: List[Double] = labels.asScala.toList.transpose.head
  val featureNamesPrime: List[String] = featureNames.asScala.toList
  val tree: TreeElement = generateTree(matrixPrime, labelsPrime, featureNamesPrime)
  println(tree)

  /**
   * Walks the trained tree for `featureRow` and returns the label of the
   * leaf that is reached.
   *
   * When a row value has no matching branch, a "Not Found" line is printed
   * and the walk continues down the child with the highest probability.
   *
   * @param featureRow the feature values, in the same column order as the
   *                   feature names given to the constructor
   */
  def classify(featureRow: Array[Double]): Double = {
    @annotation.tailrec
    def classifyRecur(treeElement: TreeElement): Double = treeElement match {
      case Tree(featureName, _, children) =>
        // Subtrees drop the columns already used, so the name is resolved
        // against the full list to index into the full row.
        val featureIndex = featureNamesPrime.indexOf(featureName)
        val attribute = featureRow(featureIndex)
        val childTree = children.get(attribute) match {
          case Some(child) => child
          case None =>
            // Unseen value: fall back to the first child with the highest probability.
            val likeliest = children.values.maxBy(_.probability)
            println(s"Not Found: $featureIndex | $featureName")
            likeliest
        }
        classifyRecur(childTree)
      case Node(label: Double, _) => label
      case Node(other, _) =>
        // Same failure type the former unchecked cast produced for a non-Double label.
        throw new ClassCastException(s"Expected a Double label but found: $other")
    }

    classifyRecur(tree)
  }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment