Created
April 5, 2013 11:58
-
-
Save thomasjungblut/5318761 to your computer and use it in GitHub Desktop.
Simple part-of-speech (POS) tagger using a hidden Markov model (HMM), reaching ~91.82% accuracy with a small training set of 70k words and 10k test words.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package de.jungblut.ml; | |
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.Lists;

import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.sparse.SparseDoubleVector;
import de.jungblut.math.tuple.Tuple;
import de.jungblut.ner.SequenceFeatureExtractor;
import de.jungblut.ner.SparseFeatureExtractorHelper;
import de.jungblut.nlp.HMM;
public class POSTagger { | |
static BiMap<String, Integer> indexTag = HashBiMap.create(); | |
static int id = 0; | |
static { | |
indexTag.put("SENTENCE_BEGIN", id++); | |
} | |
static boolean printMisprediction = false; | |
public static void main(String[] args) throws IOException { | |
String inputFile = "files/pos/training.pos"; | |
String inputDevelopmentFile = "files/pos/development.pos"; | |
List<String> words = new ArrayList<>(); | |
List<Integer> labels = new ArrayList<>(); | |
read(inputFile, words, labels); | |
List<String> developmentWords = new ArrayList<>(); | |
List<Integer> developmentLabels = new ArrayList<>(); | |
read(inputDevelopmentFile, developmentWords, developmentLabels); | |
SparseFeatureExtractorHelper<String> extractorHelper = new SparseFeatureExtractorHelper<>( | |
words, labels, new POSExtractor()); | |
Tuple<DoubleVector[], DenseDoubleVector[]> vectorize = extractorHelper | |
.vectorize(); | |
String[] dictionary = extractorHelper.getDictionary(); | |
DoubleVector[] features = vectorize.getFirst(); | |
DenseDoubleVector[] state = vectorize.getSecond(); | |
HMM hmm = new HMM(dictionary.length, indexTag.size()); | |
hmm.trainSupervised(features, state); | |
Tuple<DoubleVector[], DenseDoubleVector[]> vectorizeAdditionals = extractorHelper | |
.vectorizeAdditionals(developmentWords, developmentLabels); | |
DoubleVector[] testFeatures = vectorizeAdditionals.getFirst(); | |
DenseDoubleVector[] testLabels = vectorizeAdditionals.getSecond(); | |
int correct = 0; | |
DoubleVector lastPrediction = new SparseDoubleVector(indexTag.size()); | |
lastPrediction.set(0, 1d); | |
for (int i = 0; i < testFeatures.length; i++) { | |
DoubleVector feat = testFeatures[i]; | |
DenseDoubleVector outcome = testLabels[i]; | |
DoubleVector predicted = hmm.predict(feat, lastPrediction); | |
int predictedHiddenState = predicted.maxIndex(); | |
if (predictedHiddenState == outcome.maxIndex()) { | |
correct++; | |
} else if (printMisprediction) { | |
System.out | |
.println("\"" + developmentWords.get(i) + "\" -> Predicted: \"" | |
+ indexTag.inverse().get(predictedHiddenState) | |
+ "\" But should be: " | |
+ indexTag.inverse().get(outcome.maxIndex())); | |
} | |
lastPrediction = predicted; | |
} | |
System.out.println(correct + "/" + testFeatures.length + "= " | |
+ (correct / (double) testFeatures.length * 100d) + "% Accuracy."); | |
} | |
private static void read(String inputFile, List<String> words, | |
List<Integer> labels) throws IOException, FileNotFoundException { | |
try (BufferedReader br = new BufferedReader(new FileReader(inputFile))) { | |
String line; | |
while ((line = br.readLine()) != null) { | |
String[] split = line.split("\t"); | |
if (split.length == 0) { | |
words.add("SENTENCE_BEGIN"); | |
labels.add(indexTag.get("SENTENCE_BEGIN")); | |
} else { | |
words.add(split[0]); | |
String tag = split[1]; | |
if (!indexTag.containsKey(tag)) { | |
indexTag.put(tag, id++); | |
} | |
labels.add(indexTag.get(tag)); | |
} | |
} | |
} | |
} | |
static class POSExtractor implements SequenceFeatureExtractor<String> { | |
private static final int SUFFIX_LENGTH = 4; | |
private static final Pattern punct = Pattern | |
.compile("[!#%*+;,/<=>?@^_`{|}~]"); | |
@Override | |
public List<String> computeFeatures(List<String> words, int prevLabel, | |
int position) { | |
ArrayList<String> features = Lists.newArrayList(); | |
String word = words.get(position); | |
features.add("current=" + word); | |
features.add("prevlabel=" + prevLabel); | |
if (position > 0) { | |
features.add("prev=" + words.get(position - 1)); | |
} | |
if (position < words.size() - 1) { | |
features.add("next=" + words.get(position + 1)); | |
} | |
if (word.indexOf('-') != -1) { | |
features.add("hyphen"); | |
} | |
if (word.equals("...")) { | |
features.add("threedots"); | |
} else if (punct.matcher(word).find()) { | |
features.add("punct"); | |
} | |
if (word.indexOf('&') != -1) { | |
features.add("amp"); | |
} | |
if (word.indexOf('$') != -1) { | |
features.add("curr"); | |
} | |
if (word.indexOf('(') != -1) { | |
features.add("leftbrace"); | |
} | |
if (word.indexOf(')') != -1) { | |
features.add("rightbrace"); | |
} | |
if (word.indexOf('\'') != -1) { | |
features.add("singlequote"); | |
} | |
if (word.indexOf('"') != -1) { | |
features.add("doublequote"); | |
} | |
String[] prefs = getPrefixes(word); | |
for (int i = 0; i < prefs.length; i++) { | |
features.add("pre=" + prefs[i]); | |
} | |
String[] suffs = getSuffixes(word); | |
for (int i = 0; i < suffs.length; i++) { | |
features.add("suf=" + suffs[i]); | |
} | |
return features; | |
} | |
protected static String[] getSuffixes(String lex) { | |
String[] suffs = new String[SUFFIX_LENGTH]; | |
for (int li = 0, ll = SUFFIX_LENGTH; li < ll; li++) { | |
suffs[li] = lex.substring(Math.max(lex.length() - li - 1, 0)); | |
} | |
return suffs; | |
} | |
protected static String[] getPrefixes(String lex) { | |
String[] prefs = new String[SUFFIX_LENGTH]; | |
for (int li = 0, ll = SUFFIX_LENGTH; li < ll; li++) { | |
prefs[li] = lex.substring(0, Math.min(li + 1, lex.length())); | |
} | |
return prefs; | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment