Created
February 23, 2012 20:12
-
-
Save mpitid/1894835 to your computer and use it in GitHub Desktop.
Implementation of a simple training program in Python, Java and Scala
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// vim: set ts=4 sw=4 et: | |
import java.io.FileReader; | |
import java.io.BufferedReader; | |
import java.io.IOException; | |
import java.util.Iterator; | |
import java.util.Scanner; | |
import java.util.HashMap; | |
import java.util.Map; | |
/**
 * Simple training program: reads whitespace-separated tokens from a file,
 * derives a handful of prefix/suffix features from each token and keeps an
 * approximate count of how often every feature was seen on a token starting
 * with a lowercase versus an uppercase character.
 */
public class JTrain {
    /**
     * Keep track of feature frequencies.
     * Each value is a two-element array: index 0 holds the counter for
     * non-uppercase tokens, index 1 the counter for uppercase tokens.
     */
    static HashMap<String, int[]> counter = new HashMap<String, int[]>();

    /**
     * Program entry point.
     *
     * @param args args[0] is the path of the training file
     * @throws IOException if the training file cannot be opened
     */
    public static void main(String[] args) throws IOException {
        // Fail with a usage message instead of an ArrayIndexOutOfBoundsException
        // when the file name is missing.
        if (args.length < 1) {
            System.err.println("usage: JTrain <filename>");
            System.exit(1);
            return;
        }
        Scanner train = null;
        try {
            train = new Scanner(
                        new BufferedReader(
                            new FileReader(args[0])));
            for (String token : Iter.string(train)) {
                if (!token.isEmpty()) {
                    // The case of the first character selects the counter slot;
                    // the features themselves are computed on the lowercased token.
                    boolean upper = Character.isUpperCase(token.charAt(0));
                    for (String feature : features(token.toLowerCase()))
                        count(feature, upper);
                }
            }
        } finally {
            if (train != null)
                train.close();
        }
        for (Map.Entry<String, int[]> entry : counter.entrySet())
            System.out.printf("%s %d %d\n",
                              entry.getKey(),
                              entry.getValue()[0],
                              entry.getValue()[1]);
    }

    /**
     * Record one occurrence of a feature.
     *
     * @param key   the feature string
     * @param upper true to update the uppercase slot (index 1),
     *              false for the lowercase slot (index 0)
     */
    public static void count(String key, boolean upper) {
        int[] counts = counter.get(key);
        if (counts == null) {
            counts = new int[] {0, 0};
            counter.put(key, counts);
        }
        int index = upper ? 1 : 0;
        // Exact counting would be counts[index]++; the approximate counter is used instead.
        counts[index] = roughCount(counts[index]);
    }

    /**
     * Compute the features of a word.
     *
     * @param word the (already lowercased) token
     * @return the word itself, its first two and three characters and its
     *         last two and three characters, in that order
     */
    public static String[] features(String word) {
        String[] fs = {
            word,
            substring(word,  0,  2),
            substring(word,  0,  3),
            substring(word, -2, -1),
            substring(word, -3, -1),
        };
        return fs;
    }

    /**
     * Substring indexing similar to Python slices.
     *
     * A negative {@code i} counts back from the end of the string. A negative
     * {@code j} counts back from one past the end, so {@code -1} means
     * "through the last character". Out-of-range indices are clamped, and an
     * empty string is returned when the start lies at or beyond the end.
     *
     * @param s the source string
     * @param i start index (inclusive), may be negative
     * @param j end index (exclusive), may be negative
     * @return the selected part of {@code s}, never null
     */
    public static String substring(String s, int i, int j) {
        int n = s.length();
        i = i < 0 ? n + i : i;
        j = j < 0 ? n + 1 + j : j;
        int start = Math.min(Math.max(i, 0), n);
        // Never let the end precede the start: String.substring would throw.
        int end = Math.max(start, Math.min(j, n));
        return s.substring(start, end);
    }

    /**
     * Probabilistic counting, catch every log(N) occurrences.
     *
     * The counter is incremented with probability 2^-value.
     *
     * @param value the current counter value
     * @return {@code value + 1} with probability 2^-value, otherwise {@code value}
     */
    public static int roughCount(int value) {
        double n = Math.random();
        // Math.pow avoids the int shift wrap-around of (1 << value) for value >= 31.
        return (n < Math.pow(2.0, -value)) ? value + 1 : value;
    }

    /**
     * Iterable wrapper around an Iterator object.
     *
     * {@link #iterator()} always hands back the wrapped iterator itself, so
     * an instance can be traversed only once.
     */
    static class Iter<T> implements Iterable<T> {
        private Iterator<T> t = null;

        Iter(Iterator<T> t) {
            assert (t != null);
            this.t = t;
        }

        public Iterator<T> iterator() { return t; }

        /** Convenience factory for wrapping an iterator over strings. */
        public static Iterable<String> string(Iterator<String> t) { return new Iter<String>(t); }
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Simple training program: reads tokens separated by spaces or tabs from a
 * file, derives prefix/suffix features from each token and keeps an
 * approximate count of how often every feature was seen on a token starting
 * with a lowercase versus an uppercase character.
 */
object STrain {
  /** Random source used by the approximate counter. */
  val rnd = new util.Random()

  /** Feature -> two counters: index 0 for non-uppercase tokens, index 1 for uppercase tokens. */
  val map = new collection.mutable.HashMap[String, Array[Int]]

  /**
   * Program entry point.
   *
   * @param args args(0) is the path of the training file; a usage line is
   *             printed to standard error when it is missing
   */
  def main(args: Array[String]): Unit = {
    if (args.length > 0) {
      val source = io.Source.fromFile(args(0))
      try {
        // Build the map of feature frequencies.
        for {
          line <- source.getLines()
          token <- line.split("[ \t]+") if !token.isEmpty
        } {
          // The case of the first character selects the counter slot; the
          // features themselves are computed on the lowercased token.
          val upper = token(0).isUpper
          for (feature <- features(token.toLowerCase))
            count(feature, upper)
        }
      } finally {
        // Release the file handle even if reading fails part-way through.
        source.close()
      }
      // Print the resulting map.
      for ((key, value) <- map)
        println("%s %d %d".format(key, value(0), value(1)))
    } else
      Console.err.println("usage: STrain <filename>")
  }

  /**
   * Record one occurrence of a feature.
   *
   * @param key   the feature string
   * @param upper true to update the uppercase slot (index 1),
   *              false for the lowercase slot (index 0)
   */
  def count(key: String, upper: Boolean): Unit = {
    val idx = if (upper) 1 else 0
    val counts = map.getOrElseUpdate(key, Array(0, 0))
    counts(idx) = roughCount(counts(idx))
  }

  /**
   * Probabilistic counter step: increments with probability 2^-value.
   *
   * @param value the current counter value
   * @return `value + 1` with probability 2^-value, otherwise `value`
   */
  def roughCount(value: Int): Int = {
    // math.pow avoids the Int shift wrap-around of (1 << value) for value >= 31.
    if (rnd.nextDouble() < math.pow(2.0, -value)) value + 1 else value
  }

  /**
   * Compute the features of a word.
   *
   * @param word the (already lowercased) token
   * @return the word itself, its first two and three characters and its
   *         last two and three characters, in that order
   */
  def features(word: String): Array[String] = Array(
    word,
    slice(word,  0,  2),
    slice(word,  0,  3),
    slice(word, -2, -1),
    slice(word, -3, -1)
  )

  /**
   * Substring selection with negative indices.
   *
   * A negative `from` counts back from the end of the string. A negative
   * `upto` counts back from one past the end, so `-1` means "through the
   * last character". Out-of-range indices are clamped.
   *
   * @param s    the source string
   * @param from start index (inclusive), may be negative
   * @param upto end index (exclusive), may be negative
   * @return the selected part of `s`
   */
  def slice(s: String, from: Int, upto: Int): String = {
    val n = s.length
    val i = from + (if (from < 0) n else 0)
    val j = upto + (if (upto < 0) n + 1 else 0)
    s.slice(0 max i, 0 max (n min j))
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import sys, random | |
from collections import defaultdict | |
def main(args):
    """Count token features in a training file and print the counters.

    args is an argv-style list: args[0] is the program name and args[1] the
    path of the training file. Every whitespace-separated token of the file
    contributes its features (computed on the lowercased token) to one of two
    approximate counters, selected by whether the token's first character is
    uppercase.

    Prints one "feature lower upper" line per feature to standard output.
    Returns 0 on success, or 1 after printing a usage line when the file
    name is missing.
    """
    # sys.stdout.write is used instead of the print statement so the function
    # parses and behaves identically under both Python 2 and Python 3.
    if len(args) < 2:
        sys.stdout.write("usage: %s <filename>\n" % args[0])
        return 1

    # feature -> [counter for non-uppercase tokens, counter for uppercase tokens]
    counts = defaultdict(lambda: [0, 0])

    def count(f, i):
        # i is a bool used as a list index: False -> slot 0, True -> slot 1.
        counts[f][i] = rough_count(counts[f][i])

    with open(args[1]) as f:
        for token in (t for line in f for t in line.split() if t):
            upper = token[0].isupper()
            for feature in features(token.lower()):
                count(feature, upper)

    # items() exists on both Python 2 and 3 (iteritems() is Python-2-only).
    for k, (l, u) in counts.items():
        sys.stdout.write("%s %d %d\n" % (k, l, u))
    return 0
def features(word):
    """Return the feature tuple of a word.

    The tuple holds, in order: the word itself, its last three characters,
    its first three characters, its last two characters and its first two
    characters. Words shorter than a slice simply yield the whole word.
    """
    suffix3 = word[-3:]
    prefix3 = word[:3]
    suffix2 = word[-2:]
    prefix2 = word[:2]
    return (word, suffix3, prefix3, suffix2, prefix2)
def rough_count(val):
    """Probabilistic counter step.

    Returns val + 1 with probability 1 / 2**val, otherwise returns val
    unchanged.
    """
    draw = random.random()
    threshold = 1.0 / (1 << val)
    if draw < threshold:
        return val + 1
    return val
# Run main() only when executed as a script and use its return value as the
# process exit status.
if __name__ == '__main__':
    exit_status = main(sys.argv)
    sys.exit(exit_status)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment