Last active
January 14, 2016 05:36
-
-
Save c0rp-aubakirov/7b2470dc7d3b58ec4c33 to your computer and use it in GitHub Desktop.
Lucene 5.3.1 Classification with Leave-one-out cross validation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.lucene.analysis.Analyzer; | |
import org.apache.lucene.analysis.standard.StandardAnalyzer; | |
import org.apache.lucene.classification.ClassificationResult; | |
import org.apache.lucene.classification.Classifier; | |
import org.apache.lucene.document.Document; | |
import org.apache.lucene.index.IndexReader; | |
import org.apache.lucene.index.LeafReader; | |
import org.apache.lucene.index.SlowCompositeReaderWrapper; | |
import org.apache.lucene.search.*; | |
import org.apache.lucene.util.BytesRef; | |
import java.io.IOException; | |
import java.util.*; | |
import java.util.logging.Logger; | |
/** | |
* User: Sanzhar Aubakirov | |
* Date: 1/13/16 | |
*/ | |
public class Classify { | |
private final Logger logger = Logger.getLogger("CLASSIFIER"); | |
/** | |
* It is for K-fold cross validation | |
* CROSS_VALIDATION_SETS_NUMBER is a number of subsets | |
*/ | |
private final int CROSS_VALIDATION_SETS_NUMBER; | |
private int[][] confusionMatrix; | |
private Map<String, Integer> typeIndex; | |
private String[] classes; | |
private String[] shortNameClasses; | |
private Analyzer analyzer = null; | |
/**
 * Creates an evaluator for K-fold cross validation, in which the whole
 * document set is divided into K subsets.
 *
 * @param crossValidationSetsNumber the number of subsets (K) used for cross validation
 * @param classes                   array of class names to detect
 */
public Classify(final int crossValidationSetsNumber, final String[] classes) {
    // Fail fast: a non-positive subset count would otherwise only blow up later,
    // deep inside the code that splits the documents into subsets.
    if (crossValidationSetsNumber <= 0) {
        throw new IllegalArgumentException(
                "crossValidationSetsNumber must be positive, got " + crossValidationSetsNumber);
    }
    Objects.requireNonNull(classes, "classes");
    this.CROSS_VALIDATION_SETS_NUMBER = crossValidationSetsNumber;
    // Defensive copy: later changes to the caller's array must not desynchronise
    // the stored class names from the index map and confusion matrix built below.
    this.classes = classes.clone();
    initConfusionMatrix(this.classes);
}
private void initConfusionMatrix(final String[] classes) { | |
confusionMatrix = new int[classes.length][classes.length]; | |
typeIndex = new HashMap<>(); | |
for (int i = 0; i < classes.length; i++) { | |
typeIndex.put(classes[i], i); | |
} | |
shortNameClasses = new String[classes.length]; | |
for (int i = 0; i < classes.length; i++) { | |
final String shortName = "C" + i; | |
shortNameClasses[i] = shortName; | |
} | |
} | |
/**
 * Trains and tests the classifier using K-fold cross validation.
 *
 * @param classifier     an implementation of the Classifier interface, for example KNearestNeighborClassifier
 * @param documents      list of documents used for training and testing the classifier
 * @param fieldNames     array of fields used by the classifier; each field should exist in every Document
 * @param classFieldName the name of the field containing the class assigned to documents
 * @return the confusion matrix accumulated over all iterations
 * @throws IOException if there is a low-level I/O error
 */
public int[][] crossValidation(final Classifier classifier, final List<Document> documents, | |
final String[] fieldNames, | |
final String classFieldName) throws IOException { | |
Collections.shuffle(documents); // shuffle documents | |
final List<List<Document>> set = getCrossValidationSets(documents); | |
printClassesShortNames(); | |
final StringBuilder output = new StringBuilder(); | |
// K-fold cross-validation | |
for (int i = 0; i < CROSS_VALIDATION_SETS_NUMBER; i++) { | |
// Create TrainingSet and index it | |
final List<Document> trainingSet = new ArrayList<>(); | |
for (int k = 0; k < CROSS_VALIDATION_SETS_NUMBER; k++) { | |
if (k == i) continue; // do not include test set to training | |
trainingSet.addAll(set.get(k)); | |
} | |
final MessageIndexer indexerTrain = new MessageIndexer("/tmp/train" + i); | |
indexerTrain.index(true, trainingSet); | |
// Create TestingSet and index it | |
final List<Document> testingSet = new ArrayList<>(); | |
testingSet.addAll(set.get(i)); | |
final MessageIndexer indexerTest = new MessageIndexer("/tmp/test" + i); | |
indexerTest.index(true, testingSet); | |
final IndexReader irTrain = indexerTrain.readIndex(); | |
final LeafReader wrap = SlowCompositeReaderWrapper.wrap(irTrain); | |
final Analyzer analyzer = getAnalyzer(); | |
classifier.train(wrap, fieldNames, classFieldName, analyzer, new MatchAllDocsQuery()); | |
final IndexReader irTest = indexerTest.readIndex(); | |
final IndexSearcher testSearcher = new IndexSearcher(irTest); | |
final Query qTest = new MatchAllDocsQuery(); | |
final TopDocs testDocs = testSearcher.search(qTest, testingSet.size()); | |
final Map<Integer, Document> idToDocumentMap = new HashMap<>(); | |
for (ScoreDoc scoreDoc : testDocs.scoreDocs) { | |
idToDocumentMap.put(scoreDoc.doc, irTest.document(scoreDoc.doc)); | |
} | |
checkClassifier(idToDocumentMap, classifier, fieldNames, classFieldName, output); | |
} | |
appendResults(confusionMatrix, "\n===Average result===\n", output); | |
logger.info(output.toString()); | |
return confusionMatrix; | |
} | |
/**
 * Logs the mapping from short class names to full class names. Class names are
 * shortened only to keep the logged confusion matrix readable.
 */
private void printClassesShortNames() { | |
final StringBuilder forShortNames = new StringBuilder(); | |
for (int i = 0; i < classes.length; i++) { | |
final String shortName = "C" + i; | |
forShortNames.append(shortName).append("\t").append(classes[i]).append("\n"); | |
} | |
logger.info("\n\tClass names are shortened\n\n" + forShortNames.toString() + "\n"); | |
} | |
/**
 * Used for K-fold cross validation.
 * <p>
 * Divides the list of all documents into K sets,
 * where K = CROSS_VALIDATION_SETS_NUMBER.
 */
private List<List<Document>> getCrossValidationSets(final List<Document> documents) { | |
final int TOTAL_SET_SIZE = documents.size(); | |
final int EACH_SUBSET_SIZE = TOTAL_SET_SIZE / CROSS_VALIDATION_SETS_NUMBER; | |
final List<List<Document>> set = new ArrayList<>(CROSS_VALIDATION_SETS_NUMBER); | |
for (int i = 0; i < CROSS_VALIDATION_SETS_NUMBER; i++) { | |
final List<Document> setItem = new ArrayList<>(EACH_SUBSET_SIZE); | |
for (int j = i * EACH_SUBSET_SIZE; j < (i + 1) * EACH_SUBSET_SIZE; j++) { | |
final Document document = documents.get(j); | |
setItem.add(document); | |
} | |
set.add(setItem); | |
} | |
return set; | |
} | |
// http://soleami.com/blog/comparing-document-classification-functions-of-lucene-and-mahout.html
/**
 * Classifies the testing documents and accumulates the results into the confusion matrix.
 *
 * @param testingSet     test documents, keyed by their document id in the test index
 * @param classifier     trained classifier
 * @param fieldNames     array of fields used by the classifier; each field should exist in every Document
 * @param classFieldName the name of the field containing the class assigned to documents
 * @param output         StringBuilder to which the results are appended
 * @throws IOException if there is a low-level I/O error
 */
private void checkClassifier(final Map<Integer, Document> testingSet, final Classifier classifier, | |
final String[] fieldNames, final String classFieldName, | |
StringBuilder output) throws IOException { | |
// init local confusion matrix | |
final int[][] tempConfusionMatrix = new int[classes.length][classes.length]; | |
for (Document document : testingSet.values()) { | |
final String correctAnswer = document.get(classFieldName); | |
final int cai = typeIndex.get(correctAnswer); | |
final StringBuilder text = new StringBuilder(); | |
for (String fieldName : fieldNames) { | |
text.append(document.get(fieldName)).append(" "); | |
} | |
final ClassificationResult<BytesRef> result = classifier.assignClass(text.toString()); | |
final String classified = result.getAssignedClass().utf8ToString(); | |
final int cli = typeIndex.get(classified); | |
tempConfusionMatrix[cai][cli]++; | |
} | |
appendResults(tempConfusionMatrix, "\n==Iteration results==\n", output); | |
for (int i = 0; i < tempConfusionMatrix.length; i++) { | |
for (int j = 0; j < tempConfusionMatrix.length; j++) { | |
confusionMatrix[i][j] += tempConfusionMatrix[i][j]; | |
} | |
} | |
} | |
/**
 * Appends the confusion matrix to the output in a human-readable form.
 *
 * @param confusionMatrix the matrix itself
 * @param header          header of the output
 * @param output          StringBuilder to which the formatted result is appended
 */
private void appendResults(final int[][] confusionMatrix, final String header, final StringBuilder output) { | |
// build matrix for output | |
final StringBuilder matrix = new StringBuilder(); | |
// add header | |
matrix.append("\t\t"); | |
for (String clazz : shortNameClasses) { | |
matrix.append(clazz).append("\t"); | |
} | |
matrix.append("\n"); | |
int fc = 0, tc = 0, total = 0; | |
for (int i = 0; i < typeIndex.size(); i++) { | |
matrix.append("\t").append(shortNameClasses[i]).append("\t"); | |
for (int j = 0; j < typeIndex.size(); j++) { | |
matrix.append(confusionMatrix[i][j]).append("\t"); | |
if (i == j) { | |
tc += confusionMatrix[i][j]; | |
} else { | |
fc += confusionMatrix[i][j]; | |
} | |
total += confusionMatrix[i][j]; | |
} | |
matrix.append("\n"); | |
} | |
float accrate = (float) tc / (float) (tc + fc); | |
float errrate = (float) fc / (float) (tc + fc); | |
output.append(header) | |
.append("\n\tConfusion matrix:\n") | |
.append(matrix.toString()) | |
.append("\n\tAccuracy rate\t") | |
.append(accrate) | |
.append("") | |
.append("\n\tError rate\t") | |
.append(errrate) | |
.append("") | |
.append("\n\tDocs #\t") | |
.append(total) | |
.append("\n==\t==\n\n"); | |
} | |
/**
 * Sets the Analyzer that the Classifier will use to analyze document fields.
 *
 * @param analyzer the analyzer used to tokenize / filter the unseen text
 */
public void setAnalyzerWrapper(Analyzer analyzer) { | |
this.analyzer = analyzer; | |
} | |
private Analyzer getAnalyzer() { | |
if (analyzer == null) { | |
return new StandardAnalyzer(); | |
} | |
return analyzer; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment