Created
September 21, 2018 15:15
-
-
Save mutsune/ee4a6dcb836c5cb2669ca3f350644077 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.io.File

import scala.collection.mutable
/**
  * Sample driver showing how to train a [[PerceptronModel]] and use it for prediction.
  *
  * NOTE(review): `docs` is referenced by both methods but is not defined anywhere in the
  * visible source; it must be supplied (presumably an `Array[Doc]` loaded from files)
  * before this object compiles.
  */
object TrainSampleWithScala extends App {

  /**
    * Trains a perceptron model on `docs`.
    *
    * @return the trained model (the original procedure-syntax definition discarded it)
    */
  def train(): PerceptronModel = {
    // NOTE(review): the second argument is presumably the number of iterations - confirm
    // against the definition of Parameter.
    val parameter: Parameter = new Parameter(SolverType.AVERAGE, 10)
    PerceptronModel.train(docs, parameter)
  }

  /**
    * Predicts the event coreference clusters of every document in `docs`.
    *
    * @param model a trained model
    * @return each document paired with the clusters of its best-scoring tree
    */
  def predict(model: PerceptronModel): Map[Doc, Array[EventCoreference]] = {
    val predictor: Predictor = new Predictor(model)
    // Predictor.predict yields the best Tree; the clusters are its eventCoreferences.
    docs.map(d => d -> predictor.predict(d).eventCoreferences).toMap
  }
}
/**
  * Extracts trees and their features from a fixed set of documents.
  *
  * NOTE(review): `featureExtraction` is not defined in the visible source. It is assumed
  * to take a document plus a mutable feature accumulator and to return the document's
  * `Tree` - confirm against its definition.
  *
  * @param docs the documents to extract from
  */
case class CoreferenceFeatureExtractor(docs: Array[Doc]) {

  // Shared across all documents so that every extraction contributes to one feature set.
  private val accumulatedFeatures = mutable.Map.empty[String, Double]

  /**
    * Immutable snapshot of the features accumulated so far. It is empty until
    * [[correctTrees]] has been called.
    */
  def features: Map[String, Double] = accumulatedFeatures.toMap

  /**
    * Runs feature extraction once per document and returns the resulting trees, in the
    * same order as `docs`. Each call hands the shared accumulator to the extraction.
    */
  def correctTrees(): Array[Tree] =
    docs.map(doc => featureExtraction(doc, accumulatedFeatures))
}
/**
  * A perceptron model: a weight per feature name plus the full feature vocabulary.
  *
  * @param weights     weight of each feature, keyed by feature name
  * @param allFeatures every feature known to the model; only its keys are used here
  */
case class PerceptronModel(weights: Map[String, Double], allFeatures: Map[String, Double]) {

  /**
    * Expands a sparse feature map to the model's full vocabulary.
    *
    * @param features sparse feature values
    * @return a map with exactly the keys of `allFeatures`: the value from `features`
    *         where present, `0.0` otherwise. Keys of `features` that are not in
    *         `allFeatures` are dropped.
    */
  def interpolateFeatures(features: Map[String, Double]): Map[String, Double] =
    allFeatures.keys.map(k => k -> features.getOrElse(k, 0.0)).toMap
}
/** Training and scoring routines for [[PerceptronModel]]. */
object PerceptronModel {

  /**
    * Trains a perceptron over `docs`.
    *
    * For every iteration and every document, the current weights predict a tree. When it
    * differs from the document's gold tree, each weight is moved by
    * `gold feature value - predicted feature value`, with a feature absent from a tree
    * counting as `0.0`. With `SolverType.AVERAGE` the returned weights are the mean of
    * the weight vector over all (iteration, document) steps; otherwise they are the
    * final weights.
    *
    * @param docs      training documents; each provides its gold tree as `doc.tree`
    * @param parameter supplies `numberOfIterations` and `solverType`
    * @return a model carrying the trained weights and the extractor's feature set
    */
  def train(docs: Array[Doc], parameter: Parameter): PerceptronModel = {
    val featureExtractor = CoreferenceFeatureExtractor(docs)
    // Called before reading `features`, as in the original ordering.
    // NOTE(review): presumably this is what populates the extractor's features - confirm.
    featureExtractor.correctTrees()
    val allFeatures: Map[String, Double] = featureExtractor.features.toMap
    val iterations: Int = parameter.numberOfIterations

    // A default of 0.0 lets `m(k) += x` work for keys that have not been seen yet.
    val weights = mutable.Map.empty[String, Double].withDefaultValue(0.0)
    val summedWeights = mutable.Map.empty[String, Double].withDefaultValue(0.0)

    for {
      _   <- 1 to iterations
      doc <- docs
    } {
      val goldTree = doc.tree
      // Predict with a snapshot of the weights learned so far.
      val predictedTree: Tree =
        Predictor(PerceptronModel(weights.toMap, allFeatures)).predict(doc)

      if (goldTree != predictedTree) {
        val goldFeatures = goldTree.features
        val predictedFeatures = predictedTree.features
        // Take the union of keys so features seen only in the wrong prediction are penalised.
        (goldFeatures.keySet ++ predictedFeatures.keySet).foreach { k =>
          weights(k) += goldFeatures.getOrElse(k, 0.0) - predictedFeatures.getOrElse(k, 0.0)
        }
      }

      // Accumulate after every step, including steps without an update, for averaging.
      weights.foreach { case (k, w) => summedWeights(k) += w }
    }

    val trainedWeights: Map[String, Double] =
      if (parameter.solverType == SolverType.AVERAGE) {
        // With zero steps `summedWeights` is empty, so no division takes place.
        val steps = iterations.toDouble * docs.length
        summedWeights.map { case (k, total) => k -> total / steps }.toMap
      } else {
        weights.toMap
      }

    PerceptronModel(trainedWeights, allFeatures)
  }

  /**
    * Given a set of weights and a feature vector, returns the score
    * w * f(<a, m>), where m is a mention and a is an antecedent of m.
    *
    * @param weights  weight per feature name
    * @param features feature values; a feature missing here contributes `0.0`
    * @return the dot product of `weights` and `features` over the keys of `weights`
    */
  def score(weights: Map[String, Double], features: Map[String, Double]): Double =
    weights.foldLeft(0.0) { case (acc, (k, w)) => acc + w * features.getOrElse(k, 0.0) }
}
/**
  * Picks the best-scoring candidate tree for a document under a given model.
  *
  * @param model the model whose weights are used for scoring
  */
case class Predictor(model: PerceptronModel) {

  /**
    * Returns the candidate tree of `doc` with the highest score under `model.weights`.
    *
    * Candidates come from `CoreferenceFeatureExtractor.possibleTrees`, which is not
    * defined in the visible source.
    *
    * @param doc the document to predict for
    * @return the highest-scoring candidate tree
    * @throws UnsupportedOperationException if there are no candidate trees (from `maxBy`)
    */
  def predict(doc: Doc): Tree = {
    val possibleTrees: Array[Tree] = CoreferenceFeatureExtractor.possibleTrees(doc)
    // Name the parameter: a placeholder `_` would expand inside the argument of `score`.
    possibleTrees.maxBy(tree => PerceptronModel.score(model.weights, tree.features))
  }
}
/**
  * A coreference structure for a document together with its feature vector.
  *
  * Equality and hashing consider only the clustering in `eventCoreferences`; `features`
  * is ignored. Both are insensitive to the order of clusters and to the order of events
  * within a cluster, as the original commutative hash was. Duplicate events or clusters
  * are collapsed. The overrides are needed because the generated case-class equality
  * would compare the `Array` field by reference.
  *
  * @param eventCoreferences the event clusters of this tree
  * @param features          feature values extracted for this tree
  */
case class Tree(eventCoreferences: Array[EventCoreference], features: Map[String, Double]) {

  // Canonical, order-insensitive form of the clustering: a set of event sets.
  // Lazy so it is only built for trees that are actually compared or hashed.
  private lazy val clusterSignature = eventCoreferences.iterator.map(_.events.toSet).toSet

  /** Hash of the canonical clustering; consistent with [[equals]]. */
  override def hashCode: Int = clusterSignature.hashCode

  /**
    * Two trees are equal when their clusterings are equal. The structures are compared
    * directly, so a hash collision cannot make two different trees equal.
    */
  override def equals(other: Any): Boolean = other match {
    case that: Tree =>
      (that canEqual this) && (this.clusterSignature == that.clusterSignature)
    case _ => false
  }

  // No `override` needed: this implements the abstract `Equals.canEqual`.
  def canEqual(other: Any): Boolean = other.isInstanceOf[Tree]
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment