Skip to content

Instantly share code, notes, and snippets.

@c0rp-aubakirov
Last active January 14, 2016 05:36
Show Gist options
  • Save c0rp-aubakirov/7b2470dc7d3b58ec4c33 to your computer and use it in GitHub Desktop.
Save c0rp-aubakirov/7b2470dc7d3b58ec4c33 to your computer and use it in GitHub Desktop.
Lucene 5.3.1 Classification with Leave-one-out cross validation
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.Classifier;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.search.*;
import org.apache.lucene.util.BytesRef;
import java.io.IOException;
import java.util.*;
import java.util.logging.Logger;
/**
* User: Sanzhar Aubakirov
* Date: 1/13/16
*/
public class Classify {
/** Logger that receives the class-name legend and the confusion-matrix reports. */
private final Logger logger = Logger.getLogger("CLASSIFIER");
/**
* Number of subsets (K) the document list is divided into for K-fold cross validation.
* Assigned once in the constructor.
*/
private final int CROSS_VALIDATION_SETS_NUMBER;
/**
* Confusion matrix accumulated over all folds.
* The first index is the class stored in the document, the second index is the class
* assigned by the classifier; both are positions taken from {@link #typeIndex}.
*/
private int[][] confusionMatrix;
/** Maps a class name to its position in {@link #classes}, i.e. its row/column in the confusion matrix. */
private Map<String, Integer> typeIndex;
/** Class names to detect, as passed to the constructor. */
private String[] classes;
/** Short labels ("C0", "C1", ...) used instead of the full class names when printing the matrix. */
private String[] shortNameClasses;
/** Analyzer handed to the classifier; while {@code null} a {@link StandardAnalyzer} is used instead. */
private Analyzer analyzer = null;
/**
 * Creates a cross-validation harness for the given classes.
 * <p>
 * During K-fold cross validation the whole document set is divided into K subsets.
 *
 * @param crossValidationSetsNumber number of subsets (K) for cross validation; must be positive
 * @param classes                   array of classes to detect; must not be {@code null}
 * @throws IllegalArgumentException if {@code crossValidationSetsNumber} is zero or negative
 * @throws NullPointerException     if {@code classes} is {@code null}
 */
public Classify(final int crossValidationSetsNumber, final String[] classes) {
    // Fail here rather than later with a division by zero / negative capacity while splitting.
    if (crossValidationSetsNumber <= 0) {
        throw new IllegalArgumentException(
                "crossValidationSetsNumber must be positive, got " + crossValidationSetsNumber);
    }
    Objects.requireNonNull(classes, "classes");
    this.CROSS_VALIDATION_SETS_NUMBER = crossValidationSetsNumber;
    // Defensive copy: the name-to-index map is built from the contents right now, so a later
    // modification of the caller's array must not make the stored names disagree with it.
    this.classes = classes.clone();
    initConfusionMatrix(this.classes);
}
/**
 * Allocates an all-zero square confusion matrix with one row and column per class and
 * builds the lookup structures that go with it: the class-name to index map and the
 * short labels "C0", "C1", ... used when the matrix is printed.
 *
 * @param classes class names; the position of a name in this array becomes its matrix index
 */
private void initConfusionMatrix(final String[] classes) {
    final int classCount = classes.length;
    confusionMatrix = new int[classCount][classCount];
    typeIndex = new HashMap<>();
    shortNameClasses = new String[classCount];
    int position = 0;
    for (final String className : classes) {
        typeIndex.put(className, position);
        shortNameClasses[position] = "C" + position;
        position++;
    }
}
/**
 * Trains and tests the classifier using K-fold cross validation.
 * <p>
 * The documents are shuffled (on a private copy, the caller's list is left untouched) and
 * split into K subsets. For every fold {@code i} the remaining subsets are indexed under
 * {@code /tmp/train<i>} and used for training, while subset {@code i} is indexed under
 * {@code /tmp/test<i>}, read back and classified. The per-fold and the accumulated
 * confusion matrices are written to the logger.
 * <p>
 * The returned matrix is the instance's own accumulator: it is not reset between calls,
 * so calling this method again adds to the previous counts.
 *
 * @param classifier     some implementation of the Classifier interface, for example KNearestNeighborClassifier
 * @param documents      list of documents that will be used for training and testing the classifier
 * @param fieldNames     array of fields that will be used by the classifier; the fields should exist in the documents
 * @param classFieldName the name of the field containing the class assigned to documents
 * @return the accumulated confusion matrix, indexed as {@code [stored class][assigned class]}
 * @throws IOException If there is a low-level I/O error.
 */
public int[][] crossValidation(final Classifier classifier, final List<Document> documents,
                               final String[] fieldNames,
                               final String classFieldName) throws IOException {
    // Shuffle a copy: shuffling in place would reorder the caller's list and would fail
    // outright on an unmodifiable list.
    final List<Document> shuffled = new ArrayList<>(documents);
    Collections.shuffle(shuffled);
    final List<List<Document>> folds = getCrossValidationSets(shuffled);
    printClassesShortNames();
    final StringBuilder output = new StringBuilder();
    // The analyzer does not change between folds, so resolve it once.
    final Analyzer trainingAnalyzer = getAnalyzer();
    // K-fold cross-validation
    for (int i = 0; i < CROSS_VALIDATION_SETS_NUMBER; i++) {
        // Training set = every fold except the one under test
        final List<Document> trainingSet = new ArrayList<>();
        for (int k = 0; k < CROSS_VALIDATION_SETS_NUMBER; k++) {
            if (k != i) {
                trainingSet.addAll(folds.get(k));
            }
        }
        final MessageIndexer indexerTrain = new MessageIndexer("/tmp/train" + i);
        indexerTrain.index(true, trainingSet);

        // Testing set = the fold under test
        final List<Document> testingSet = new ArrayList<>(folds.get(i));
        final MessageIndexer indexerTest = new MessageIndexer("/tmp/test" + i);
        indexerTest.index(true, testingSet);

        // NOTE(review): the training reader is intentionally left open. The classifier is
        // presumably still using it when assignClass is called (also by the caller after
        // this method returns) -- confirm for the Classifier implementations in use.
        final IndexReader irTrain = indexerTrain.readIndex();
        final LeafReader wrap = SlowCompositeReaderWrapper.wrap(irTrain);
        classifier.train(wrap, fieldNames, classFieldName, trainingAnalyzer, new MatchAllDocsQuery());

        // Load the stored test documents, then release the test reader right away.
        final Map<Integer, Document> idToDocumentMap = new HashMap<>();
        try (IndexReader irTest = indexerTest.readIndex()) {
            final IndexSearcher testSearcher = new IndexSearcher(irTest);
            final TopDocs testDocs = testSearcher.search(new MatchAllDocsQuery(), testingSet.size());
            for (final ScoreDoc scoreDoc : testDocs.scoreDocs) {
                idToDocumentMap.put(scoreDoc.doc, irTest.document(scoreDoc.doc));
            }
        }
        checkClassifier(idToDocumentMap, classifier, fieldNames, classFieldName, output);
    }
    appendResults(confusionMatrix, "\n===Average result===\n", output);
    logger.info(output.toString());
    return confusionMatrix;
}
/**
 * Logs a legend that pairs every short label ("C0", "C1", ...) with the full class name.
 * The short labels are what the confusion-matrix output uses, to keep it compact.
 */
private void printClassesShortNames() {
    final StringBuilder legend = new StringBuilder("\n\tClass names are shortened\n\n");
    int index = 0;
    for (final String className : classes) {
        legend.append("C").append(index).append("\t").append(className).append("\n");
        index++;
    }
    legend.append("\n");
    logger.info(legend.toString());
}
/**
 * Divides the list of all documents into K consecutive subsets for K-fold cross validation,
 * where K = CROSS_VALIDATION_SETS_NUMBER.
 * <p>
 * Every document ends up in exactly one subset. When the list size is not a multiple of K,
 * the first {@code size % K} subsets receive one extra document, so subset sizes differ by
 * at most one and no document is left out of the evaluation.
 *
 * @param documents documents to split; the list itself is not modified
 * @return K newly allocated lists that together contain every document, in list order
 * @throws IllegalArgumentException if there are fewer documents than subsets, because at
 *                                  least one subset would then be empty
 */
private List<List<Document>> getCrossValidationSets(final List<Document> documents) {
    final int totalSize = documents.size();
    if (totalSize < CROSS_VALIDATION_SETS_NUMBER) {
        throw new IllegalArgumentException("Cannot split " + totalSize + " documents into "
                + CROSS_VALIDATION_SETS_NUMBER + " non-empty cross validation sets");
    }
    final int baseSize = totalSize / CROSS_VALIDATION_SETS_NUMBER;
    final int remainder = totalSize % CROSS_VALIDATION_SETS_NUMBER;
    final List<List<Document>> sets = new ArrayList<>(CROSS_VALIDATION_SETS_NUMBER);
    int from = 0;
    for (int i = 0; i < CROSS_VALIDATION_SETS_NUMBER; i++) {
        // Spread the leftover documents over the first subsets instead of dropping them.
        final int to = from + baseSize + (i < remainder ? 1 : 0);
        sets.add(new ArrayList<>(documents.subList(from, to)));
        from = to;
    }
    return sets;
}
// http://soleami.com/blog/comparing-document-classification-functions-of-lucene-and-mahout.html
/**
 * Classifies the testing documents of one fold and accumulates the confusion matrix.
 * <p>
 * For every document the values of the given fields are joined with spaces and passed to
 * the classifier. The pair (class stored in the document, class assigned by the classifier)
 * is counted in a per-fold matrix, which is appended to {@code output} and then added to
 * the instance-wide confusion matrix.
 *
 * @param testingSet     validation set, keyed by document id; only the values are used
 * @param classifier     trained classifier
 * @param fieldNames     array of fields that will be used by the classifier; a field that has
 *                       no value in a document is skipped
 * @param classFieldName the name of the field containing the class assigned to documents
 * @param output         StringBuilder to print results
 * @throws IOException              If there is a low-level I/O error.
 * @throws IllegalArgumentException if a document's class is not one of the configured classes
 * @throws IllegalStateException    if the classifier assigns no class, or a class that is not
 *                                  one of the configured classes
 */
private void checkClassifier(final Map<Integer, Document> testingSet, final Classifier classifier,
                             final String[] fieldNames, final String classFieldName,
                             StringBuilder output) throws IOException {
    // init local confusion matrix
    final int[][] tempConfusionMatrix = new int[classes.length][classes.length];
    for (final Document document : testingSet.values()) {
        final String correctAnswer = document.get(classFieldName);
        // Keep the boxed value: unboxing a missing entry would be a bare NullPointerException.
        final Integer correctIndex = typeIndex.get(correctAnswer);
        if (correctIndex == null) {
            throw new IllegalArgumentException("Document has class '" + correctAnswer
                    + "' in field '" + classFieldName + "', which is not one of the configured classes");
        }

        final StringBuilder text = new StringBuilder();
        for (final String fieldName : fieldNames) {
            final String value = document.get(fieldName);
            // A missing field must not leak the literal text "null" into the classified text.
            if (value != null) {
                text.append(value).append(" ");
            }
        }

        final ClassificationResult<BytesRef> result = classifier.assignClass(text.toString());
        final BytesRef assignedClass = result == null ? null : result.getAssignedClass();
        if (assignedClass == null) {
            throw new IllegalStateException("Classifier did not assign a class to a document of class '"
                    + correctAnswer + "'");
        }
        final String classified = assignedClass.utf8ToString();
        final Integer classifiedIndex = typeIndex.get(classified);
        if (classifiedIndex == null) {
            throw new IllegalStateException("Classifier assigned class '" + classified
                    + "', which is not one of the configured classes");
        }
        tempConfusionMatrix[correctIndex][classifiedIndex]++;
    }
    appendResults(tempConfusionMatrix, "\n==Iteration results==\n", output);
    // Fold results are added to the matrix shared by all folds.
    for (int i = 0; i < tempConfusionMatrix.length; i++) {
        for (int j = 0; j < tempConfusionMatrix[i].length; j++) {
            confusionMatrix[i][j] += tempConfusionMatrix[i][j];
        }
    }
}
/**
 * Renders a confusion matrix in a human readable, tab separated form and appends it to
 * {@code output}, followed by the accuracy rate, the error rate and the number of counted
 * documents. Diagonal cells count as correct classifications, all other cells as errors.
 *
 * @param confusionMatrix matrix itself
 * @param header          header of the output
 * @param output          StringBuilder the report is appended to
 */
private void appendResults(final int[][] confusionMatrix, final String header, final StringBuilder output) {
    // Column header row: one short label per class
    final StringBuilder table = new StringBuilder("\t\t");
    for (final String label : shortNameClasses) {
        table.append(label).append("\t");
    }
    table.append("\n");

    final int size = typeIndex.size();
    int correct = 0;
    int wrong = 0;
    int total = 0;
    for (int row = 0; row < size; row++) {
        table.append("\t").append(shortNameClasses[row]).append("\t");
        for (int col = 0; col < size; col++) {
            final int count = confusionMatrix[row][col];
            table.append(count).append("\t");
            if (row == col) {
                correct += count;
            } else {
                wrong += count;
            }
            total += count;
        }
        table.append("\n");
    }

    final float classified = (float) (correct + wrong);
    final float accrate = (float) correct / classified;
    final float errrate = (float) wrong / classified;

    output.append(header);
    output.append("\n\tConfusion matrix:\n");
    output.append(table.toString());
    output.append("\n\tAccuracy rate\t").append(accrate);
    output.append("\n\tError rate\t").append(errrate);
    output.append("\n\tDocs #\t").append(total);
    output.append("\n==\t==\n\n");
}
/**
 * Sets the Analyzer that is handed to the Classifier for analyzing document fields.
 * Passing {@code null} makes the next cross validation use a {@link StandardAnalyzer}.
 *
 * @param analyzer the analyzer used to tokenize / filter the unseen text
 */
public void setAnalyzerWrapper(final Analyzer analyzer) {
    this.analyzer = analyzer;
}
/**
 * Returns the configured analyzer, or a new {@link StandardAnalyzer} when none has been set.
 */
private Analyzer getAnalyzer() {
    return analyzer != null ? analyzer : new StandardAnalyzer();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment