Skip to content

Instantly share code, notes, and snippets.

@jsuereth
Forked from mpitid/JTrain.java
Created February 23, 2012 20:16
Show Gist options
  • Save jsuereth/1894852 to your computer and use it in GitHub Desktop.
Save jsuereth/1894852 to your computer and use it in GitHub Desktop.
Implementation of a simple training program in Python, Java and Scala
// vim: set ts=4 sw=4 et:
import java.io.FileReader;
import java.io.BufferedReader;
import java.io.IOException;
import java.util.Iterator;
import java.util.Scanner;
import java.util.HashMap;
import java.util.Map;
public class JTrain {
    /**
     * Feature frequencies: each feature string maps to a two-slot array where
     * slot 0 holds the approximate count of occurrences in tokens that start
     * with a non-uppercase character and slot 1 the count for tokens that start
     * with an uppercase character.
     */
    static HashMap<String, int[]> counter = new HashMap<String, int[]>();

    /**
     * Reads whitespace-separated tokens from the file named by {@code args[0]},
     * counts the features of every token and prints one
     * {@code "<feature> <lower> <upper>"} line per feature.
     *
     * <p>When no file name is given, prints a usage line to stderr and returns,
     * like the Scala and Python versions of this program.
     *
     * @param args command-line arguments; {@code args[0]} is the input file
     * @throws IOException if the file cannot be opened or a read fails
     */
    public static void main(String[] args) throws IOException {
        if (args.length < 1) {
            System.err.println("usage: JTrain <filename>");
            return;
        }
        Scanner train = null;
        try {
            train = new Scanner(new BufferedReader(new FileReader(args[0])));
            for (String token : Iter.string(train)) {
                if (!token.isEmpty()) {
                    boolean upper = Character.isUpperCase(token.charAt(0));
                    // Locale.ROOT keeps lowercasing independent of the JVM's default
                    // locale (e.g. the Turkish dotted/dotless 'i').
                    for (String feature : features(token.toLowerCase(java.util.Locale.ROOT))) {
                        count(feature, upper);
                    }
                }
            }
            // Scanner swallows I/O errors from its source; surface them instead of
            // silently printing counts for a partially read file.
            IOException readError = train.ioException();
            if (readError != null) {
                throw readError;
            }
        } finally {
            if (train != null) {
                train.close();
            }
        }
        for (Map.Entry<String, int[]> entry : counter.entrySet()) {
            System.out.printf("%s %d %d\n",
                    entry.getKey(),
                    entry.getValue()[0],
                    entry.getValue()[1]);
        }
    }

    /**
     * Records one (probabilistic) occurrence of {@code key}.
     *
     * @param key   the feature string
     * @param upper whether the originating token started with an uppercase letter;
     *              selects slot 1 ({@code true}) or slot 0 ({@code false})
     */
    public static void count(String key, boolean upper) {
        int index = upper ? 1 : 0;
        int[] slots = counter.get(key);
        if (slots == null) {
            slots = new int[] {0, 0};
            counter.put(key, slots);
        }
        slots[index] = roughCount(slots[index]);
    }

    /**
     * Returns the features of a word: the word itself, its 2- and 3-character
     * prefixes and its 2- and 3-character suffixes (shorter words yield shorter
     * or repeated features).
     *
     * @param word the (already lowercased) word
     * @return a five-element array of features
     */
    public static String[] features(String word) {
        String[] fs = {
            word,
            substring(word, 0, 2),
            substring(word, 0, 3),
            substring(word, -2, -1),
            substring(word, -3, -1),
        };
        return fs;
    }

    /**
     * Substring indexing similar to Python slices.
     *
     * <p>A negative start {@code i} counts from the end ({@code n + i}). A negative
     * end {@code j} is mapped to {@code n + 1 + j}, so {@code -1} means "up to and
     * including the last character" (unlike Python, where {@code -1} excludes it).
     * Out-of-range indices are clamped and an empty range yields {@code ""},
     * as Python slicing does, instead of throwing.
     *
     * @param s the source string
     * @param i start index (inclusive), possibly negative
     * @param j end index (exclusive), possibly negative
     * @return the selected substring, never {@code null}
     */
    public static String substring(String s, int i, int j) {
        int n = s.length();
        int start = i < 0 ? n + i : i;
        int end = j < 0 ? n + 1 + j : j;
        start = Math.max(0, Math.min(start, n));
        end = Math.max(start, Math.min(end, n));
        return s.substring(start, end);
    }

    /**
     * Probabilistic counting, catch every log(N) occurrences: increments
     * {@code value} with probability {@code 2^-value}.
     *
     * @param value the current rough count (non-negative)
     * @return {@code value + 1} with probability {@code 2^-value}, else {@code value}
     */
    public static int roughCount(int value) {
        // Math.scalb(1.0, -value) is exactly 2^-value; unlike 1.0 / (1 << value)
        // it does not wrap around once value reaches 31 (1 << 32 == 1 in Java).
        return Math.random() < Math.scalb(1.0, -value) ? value + 1 : value;
    }

    /**
     * Adapts an {@link Iterator} to {@link Iterable} so it can drive a for-each
     * loop. The wrapped iterator itself is returned, so the result can be
     * traversed only once.
     */
    static class Iter<T> implements Iterable<T> {
        private Iterator<T> t = null;

        Iter(Iterator<T> t) {
            assert t != null;
            this.t = t;
        }

        @Override
        public Iterator<T> iterator() {
            return t;
        }

        /** Wraps a String iterator (such as a {@link Scanner}) for a for-each loop. */
        public static Iterable<String> string(Iterator<String> t) {
            return new Iter<String>(t);
        }
    }
}
object STrain {
  val rnd = new util.Random()
  // Note: on Scala 2.9.x a java.util.HashMap may perform better than
  // collection.mutable.HashMap.
  val map = new collection.mutable.HashMap[String, Array[Int]]

  /** Entry point: trains on the file named by args(0) and prints the counts. */
  def main(args: Array[String]): Unit = {
    if (args.length > 0) {
      // Build the map of feature frequencies.
      train(args(0))
      printResults()
    } else
      Console.err.println("usage: STrain <filename>")
  }

  /** Prints one "<feature> <lower> <upper>" line per counted feature. */
  def printResults(): Unit =
    for ((key, Array(l, u)) <- map)
      println("%s %d %d".format(key, l, u))

  /**
   * Counts the features of every space/tab-separated token in `filename`.
   * The source is closed even if reading fails.
   */
  def train(filename: String): Unit = {
    val source = io.Source.fromFile(filename)
    try {
      for {
        line <- source.getLines()
        token <- line.split("[ \t]+")
        if !token.isEmpty
        upper = token(0).isUpper
        // Locale.ROOT keeps lowercasing independent of the default locale.
        feature <- features(token.toLowerCase(java.util.Locale.ROOT))
      } count(feature, upper)
    } finally {
      source.close()
    }
  }

  /** Records one probabilistic occurrence of `key` in slot 1 (upper) or 0. */
  def count(key: String, upper: Boolean): Unit = {
    val idx = if (upper) 1 else 0
    val slots = map.getOrElseUpdate(key, Array(0, 0))
    slots(idx) = roughCount(slots(idx))
  }

  /**
   * Increments `value` with probability 2^-value.
   * Math.scalb(1.0, -value) is exactly 2^-value and, unlike 1 << value,
   * does not wrap around once value reaches 31.
   */
  def roughCount(value: Int): Int =
    if (rnd.nextDouble() < java.lang.Math.scalb(1.0, -value)) value + 1
    else value

  /** The word, its 2- and 3-char prefixes and its 2- and 3-char suffixes. */
  def features(word: String): Array[String] = Array(
    word,
    word take 2,
    word take 3,
    word takeRight 2,
    word takeRight 3)
}
#!/usr/bin/env python
import sys, random
from collections import defaultdict
def main(args):
    """Count rough feature frequencies in the file named by ``args[1]``.

    Every whitespace-separated token is lowercased and split into features
    (see ``features``). Each feature keeps two probabilistic counters, one for
    tokens whose first character is uppercase and one for all others. One
    ``"<feature> <lower> <upper>"`` line per feature is written to stdout.

    Written with ``sys.stdout.write``/``dict.items`` so it runs unchanged on
    both Python 2 and Python 3.

    Args:
        args: argv-style list; ``args[0]`` is the program name and
            ``args[1]`` the input file.

    Returns:
        0 on success, 1 (after writing a usage line to stderr) when no file
        name was given.
    """
    if len(args) < 2:
        # Usage goes to stderr, like the Scala version of this program.
        sys.stderr.write("usage: %s <filename>\n" % args[0])
        return 1
    counts = defaultdict(lambda: [0, 0])

    def count(f, i):
        counts[f][i] = rough_count(counts[f][i])

    with open(args[1]) as f:
        # str.split() with no separator never yields empty strings,
        # so token[0] is always valid.
        for token in (t for line in f for t in line.split()):
            upper = token[0].isupper()
            for feature in features(token.lower()):
                count(feature, int(upper))
    for k, (l, u) in counts.items():
        sys.stdout.write("%s %d %d\n" % (k, l, u))
    return 0
def features(word):
    """Return the features of ``word`` as a 5-tuple.

    In order: the word itself, its last three characters, its first three
    characters, its last two characters and its first two characters.
    Slicing never raises, so short words simply yield shorter (or repeated)
    features.
    """
    last_three, first_three = word[-3:], word[:3]
    last_two, first_two = word[-2:], word[:2]
    return (word, last_three, first_three, last_two, first_two)
def rough_count(val):
    """Probabilistically increment a logarithmic counter.

    Returns ``val + 1`` with probability ``1 / 2**val`` and ``val`` otherwise,
    drawing exactly one number from ``random.random()``.
    """
    draw = random.random()
    if draw < 1.0 / (1 << val):
        return val + 1
    return val
# Script entry point: forward main()'s return value as the process exit status.
if __name__ == "__main__":
    status = main(sys.argv)
    sys.exit(status)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment