Last active
August 29, 2015 14:01
-
-
Save myui/5bef4e4dc8b89d3819c6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public final class TrainNewsGroups { | |
public static void main(String[] args) throws IOException { | |
File base = new File(args[0]); | |
Multiset<String> overallCounts = HashMultiset.create(); | |
int leakType = 0; | |
if (args.length > 1) { | |
leakType = Integer.parseInt(args[1]); | |
} | |
Dictionary newsGroups = new Dictionary(); | |
NewsgroupHelper helper = new NewsgroupHelper(); | |
helper.getEncoder().setProbes(2); | |
AdaptiveLogisticRegression learningAlgorithm = | |
new AdaptiveLogisticRegression(20, NewsgroupHelper.FEATURES, new L1()); | |
learningAlgorithm.setInterval(800); | |
learningAlgorithm.setAveragingWindow(500); | |
List<File> files = Lists.newArrayList(); | |
for (File newsgroup : base.listFiles()) { | |
if (newsgroup.isDirectory()) { | |
newsGroups.intern(newsgroup.getName()); | |
files.addAll(Arrays.asList(newsgroup.listFiles())); | |
} | |
} | |
Collections.shuffle(files); | |
System.out.println(files.size() + " training files"); | |
SGDInfo info = new SGDInfo(); | |
int k = 0; | |
for (File file : files) { | |
String ng = file.getParentFile().getName(); | |
int actual = newsGroups.intern(ng); | |
Vector v = helper.encodeFeatureVector(file, actual, leakType, overallCounts); | |
learningAlgorithm.train(actual, v); | |
k++; | |
State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest(); | |
SGDHelper.analyzeState(info, leakType, k, best); | |
} | |
learningAlgorithm.close(); | |
SGDHelper.dissect(leakType, newsGroups, learningAlgorithm, files, overallCounts); | |
System.out.println("exiting main"); | |
File modelFile = new File(System.getProperty("java.io.tmpdir"), "news-group.model"); | |
ModelSerializer.writeBinary(modelFile.getAbsolutePath(), | |
learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); | |
List<Integer> counts = Lists.newArrayList(); | |
System.out.println("Word counts"); | |
for (String count : overallCounts.elementSet()) { | |
counts.add(overallCounts.count(count)); | |
} | |
Collections.sort(counts, Ordering.natural().reverse()); | |
k = 0; | |
for (Integer count : counts) { | |
System.out.println(k + "\t" + count); | |
k++; | |
if (k > 1000) { | |
break; | |
} | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment