Skip to content

Instantly share code, notes, and snippets.

@mutsune
Created September 21, 2018 15:15
Show Gist options
  • Save mutsune/ee4a6dcb836c5cb2669ca3f350644077 to your computer and use it in GitHub Desktop.
Save mutsune/ee4a6dcb836c5cb2669ca3f350644077 to your computer and use it in GitHub Desktop.
import java.io.File

import scala.collection.mutable
object TrainSampleWithScala extends App {
  // NOTE(review): `docs` is not defined anywhere in this file, and the App body never calls
  // train()/predict(). Presumably the documents must be loaded (and these methods invoked)
  // before this entry point does anything useful — confirm.

  /** Trains a model over `docs` with the AVERAGE solver and 10 iterations.
    *
    * @return the trained model (the original procedure syntax silently discarded it)
    */
  def train(): PerceptronModel = {
    val parameter: Parameter = new Parameter(SolverType.AVERAGE, 10)
    PerceptronModel.train(docs, parameter)
  }

  /** Decodes every document in `docs` with `model`.
    *
    * @param model a trained model
    * @return each document mapped to the event clusters of its best-scoring tree
    */
  def predict(model: PerceptronModel): Map[Doc, Array[EventCoreference]] = {
    val predictor = new Predictor(model)
    docs.map(d => d -> predictor.predict(d).eventCoreferences).toMap
  }
}
case class CoreferenceFeatureExtractor(docs: Array[Doc]) {
  // Feature table shared by all documents and handed to `featureExtraction`.
  // NOTE(review): presumably `featureExtraction` records each document's features into it as a
  // side effect — confirm against its definition (it is not in this file).
  private val featureTable = mutable.Map.empty[String, Double]

  /** Immutable snapshot of the feature table.
    *
    * Only `correctTrees()` passes the table to `featureExtraction`, so call it first.
    */
  def features: Map[String, Double] = featureTable.toMap

  /** Runs `featureExtraction` once per document, in `docs` order, and returns the trees.
    *
    * The previous version mapped over `docs` twice (nested, shadowing the outer document) and
    * ended its block in a `val` definition, so it produced `Unit`s instead of trees.
    * NOTE(review): `featureExtraction` is assumed to return the document's `Tree` — confirm.
    */
  def correctTrees(): Array[Tree] = docs.map(doc => featureExtraction(doc, featureTable))
}
case class PerceptronModel(weights: Map[String, Double], allFeatures: Map[String, Double]) {
  /** Projects `features` onto the model's full feature space.
    *
    * Every key of `allFeatures` appears in the result. It carries the value from `features`
    * when present and `0.0` otherwise. Keys of `features` that are not in `allFeatures` are
    * dropped.
    *
    * @param features a (possibly sparse) feature map
    * @return a map over exactly the keys of `allFeatures`
    */
  def interpolateFeatures(features: Map[String, Double]): Map[String, Double] =
    // Look up by `k`: the former `features.get(_)` expanded to a lambda, not a lookup.
    allFeatures.keys.map(k => k -> features.getOrElse(k, 0.0)).toMap
}
object PerceptronModel {
  /** Trains a structured perceptron over `docs`.
    *
    * Each pass visits every document. It finds the highest-scoring candidate tree under the
    * weights learned so far. If that tree differs from the gold tree (`doc.tree`), every
    * feature moves by `gold - predicted`.
    *
    * With `SolverType.AVERAGE`, the returned weights are the mean of the weight vector taken
    * after every document step. Otherwise they are the final weight vector.
    *
    * @param docs      training documents, each exposing its gold `tree`
    * @param parameter solver type and number of passes over `docs`
    * @return a model holding the learned weights and the extractor's feature table
    * @throws UnsupportedOperationException if some document yields no candidate trees
    */
  def train(docs: Array[Doc], parameter: Parameter): PerceptronModel = {
    val featureExtractor = CoreferenceFeatureExtractor(docs)
    // NOTE(review): the trees are unused here. The call is kept presumably because it fills the
    // feature table that `features` reads next — confirm.
    featureExtractor.correctTrees()
    val features = featureExtractor.features

    val weights = mutable.Map.empty[String, Double]
    val weightSums = mutable.Map.empty[String, Double]

    for {
      _ <- 1 to parameter.numberOfIterations
      doc <- docs
    } {
      val goldTree = doc.tree
      // Decode with the current weights; a model frozen before training would never improve.
      val predictedTree = bestTree(weights, doc)
      if (goldTree != predictedTree) {
        val goldFeatures = goldTree.features
        val predictedFeatures = predictedTree.features
        // Visit the union of keys so features firing only on the wrong tree are penalised too.
        (goldFeatures.keySet ++ predictedFeatures.keySet).foreach { k =>
          weights(k) = weights.getOrElse(k, 0.0) +
            goldFeatures.getOrElse(k, 0.0) - predictedFeatures.getOrElse(k, 0.0)
        }
      }
      // Accumulate the current weight vector for averaging; absent sums start at 0.
      weights.foreach { case (k, w) =>
        weightSums(k) = weightSums.getOrElse(k, 0.0) + w
      }
    }

    val learned: Map[String, Double] =
      if (parameter.solverType == SolverType.AVERAGE) {
        // weightSums is empty when no step ran, so this never divides by zero in practice.
        val steps = parameter.numberOfIterations.toDouble * docs.length
        weightSums.iterator.map { case (k, s) => k -> s / steps }.toMap
      } else {
        weights.toMap
      }

    PerceptronModel(learned, features)
  }

  /** Highest-scoring candidate tree for `doc` under `weights` (same rule as `Predictor`). */
  private def bestTree(weights: collection.Map[String, Double], doc: Doc): Tree =
    CoreferenceFeatureExtractor.possibleTrees(doc).maxBy(t => score(weights, t.features))

  /** Linear score `w · f(<a, m>)`, where m is a mention and a is an antecedent of m.
    *
    * A feature with no weight contributes 0, as does a weight with no feature.
    *
    * @param weights  the weight vector
    * @param features the features of one candidate
    * @return the dot product of `weights` and `features`
    */
  def score(weights: collection.Map[String, Double], features: collection.Map[String, Double]): Double =
    features.foldLeft(0.0) { case (acc, (k, v)) => acc + weights.getOrElse(k, 0.0) * v }
}
case class Predictor(model: PerceptronModel) {
  /** Returns the candidate tree for `doc` with the highest score under `model.weights`.
    *
    * NOTE(review): `CoreferenceFeatureExtractor.possibleTrees` is not defined in this file;
    * presumably it lives in a companion object elsewhere — confirm.
    *
    * @param doc the document to decode
    * @return the best-scoring candidate tree (the old `Int` return type did not match the body)
    * @throws UnsupportedOperationException if `doc` yields no candidate trees (`maxBy` on empty)
    */
  def predict(doc: Doc): Tree = {
    val possibleTrees: Array[Tree] = CoreferenceFeatureExtractor.possibleTrees(doc)
    // Name the lambda parameter: a bare `_.features` inside score(...) would make the
    // argument itself a function rather than the candidate's features.
    possibleTrees.maxBy(tree => PerceptronModel.score(model.weights, tree.features))
  }
}
case class Tree(eventCoreferences: Array[EventCoreference], features: Map[String, Double]) {
  // Order-insensitive view of the clustering: a set of clusters, each a set of its events.
  // Trees listing the same clusters in a different order, or a cluster's events in a different
  // order, compare equal. `features` takes no part in equality.
  private val clusters = eventCoreferences.iterator.map(_.events.toSet).toSet

  override def hashCode(): Int = clusters.hashCode()

  // Compare the clusters themselves; the old version compared only a 32-bit product hash, so
  // any collision (e.g. one cluster hashing to -31 zeroes the product) equated distinct trees.
  override def equals(other: Any): Boolean = other match {
    case that: Tree => (that canEqual this) && this.clusters == that.clusters
    case _          => false
  }

  // Defining canEqual suppresses the synthesized case-class one. No `override` is needed
  // because it implements the abstract member of scala.Equals.
  def canEqual(other: Any): Boolean = other.isInstanceOf[Tree]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment