Skip to content

Instantly share code, notes, and snippets.

@mamrehn
Last active April 28, 2023 14:53
Show Gist options
  • Save mamrehn/125d13f6638465c67578 to your computer and use it in GitHub Desktop.
Save mamrehn/125d13f6638465c67578 to your computer and use it in GitHub Desktop.
Simple example for a persistent classifier model in Weka (http://www.cs.waikato.ac.nz/ml/weka/). The Decision Tree (J48) can be exchanged with an arbitrary classifier.
//package ...;
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import weka.classifiers.Classifier;
import weka.classifiers.trees.J48;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
/**
 * Minimal example of persisting a Weka classifier.
 *
 * <p>The constructor trains a {@link J48} decision tree on an ARFF dataset,
 * classifies one randomly generated instance, serializes the model to disk,
 * reads it back and classifies the same instance again so both results can be
 * compared on the console. {@code J48} can be replaced by any other
 * {@link Classifier} implementation.
 */
public class WekaSaveClassifier {

    /** Default ARFF dataset used by {@link #main(String[])}. */
    protected static final String DATA_PATH = "data/unbalanced.arff";

    /** Default location the serialized model is written to and read from. */
    protected static final String CLASSIFIER_PATH = "save/cls.model";

    /**
     * Runs the demo with {@link #DATA_PATH} and {@link #CLASSIFIER_PATH}.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // Source: http://weka.wikispaces.com/Serialization
        try {
            new WekaSaveClassifier(DATA_PATH, CLASSIFIER_PATH);
        } catch (Exception e) {
            // Top-level boundary of the demo: report the failure and end.
            e.printStackTrace();
        }
    }

    /**
     * Executes the whole train / classify / save / reload / classify sequence.
     *
     * @param dataPath       path of the ARFF file used for training
     * @param classifierPath path the trained model is serialized to and then
     *                       deserialized from
     * @throws Exception if loading the data, training, classifying or saving
     *                   the model fails
     */
    public WekaSaveClassifier(String dataPath, String classifierPath) throws Exception {
        System.out.println("[LOG]\tget classifier");
        Classifier cls = new J48();

        System.out.println("[LOG]\tload dataset");
        final Instances instDataset = loadData(dataPath);

        System.out.println("[LOG]\ttrain classifier with dataset");
        cls.buildClassifier(instDataset);

        // new unknown instance
        final Instance newInst = getUserInput(instDataset);
        System.out.println(
                "[LOG]\tclassification before saving the cls: "
                        + cls.classifyInstance(newInst));

        System.out.println("[LOG]\tserialize model:\t" + classifierPath);
        saveClassifier(cls, classifierPath);

        System.out.println("[LOG]\tdeserialize model:\t" + classifierPath);
        final Classifier reloaded = loadClassifier(classifierPath);
        if (null != reloaded) {
            cls = reloaded;
        } else {
            // Best effort: keep using the in-memory model so the demo still
            // prints a second classification result.
            System.err.println("Could not load classifier from " + classifierPath);
        }

        System.out.println(
                "[LOG]\tclassification after reloading the cls: "
                        + cls.classifyInstance(newInst));
    }

    /**
     * Collects user input for a new, unlabeled instance.
     * Currently random/dummy input is generated for demonstration: every
     * numeric non-class attribute receives a uniformly distributed value in
     * {@code [-2.5, 2.5)}; the class attribute and any non-numeric attribute
     * are left missing.
     *
     * @param dataset the dataset (i.e. from an .arff file) holding the new
     *                instance's feature vector structure
     * @return the new instance given by the user
     */
    private Instance getUserInput(final Instances dataset) {
        // Derive the layout from the dataset instead of hard-coding 33
        // attributes, so datasets of any width work.
        final int numAttributes = dataset.numAttributes();
        final int classIndex = dataset.classIndex();

        // DenseInstance(int) starts with all values missing.
        final Instance in = new DenseInstance(numAttributes);
        in.setDataset(dataset);
        if (classIndex >= 0) {
            in.setClassMissing(); // not necessary
        }

        for (int i = 0; i < numAttributes; ++i) {
            // A random double is only a meaningful value for numeric attributes.
            if (i == classIndex || !dataset.attribute(i).isNumeric()) {
                continue;
            }
            in.setValue(i, -2.5d + Math.random() * 5.0d);
        }
        //System.out.println(in);
        return in;
    }

    /**
     * Reads an ARFF file and marks its last attribute as the class attribute.
     *
     * @param path path of the ARFF file
     * @return the loaded dataset with the class index set to the last attribute
     * @throws FileNotFoundException if {@code path} cannot be opened
     * @throws IOException           if reading the file fails
     */
    protected Instances loadData(final String path)
            throws FileNotFoundException, IOException {
        // try-with-resources: the reader is closed even when parsing fails.
        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            final Instances inst = new Instances(reader);
            inst.setClassIndex(inst.numAttributes() - 1);
            return inst;
        }
    }

    /**
     * Serializes the given classifier to a file.
     *
     * @param cls  the classifier to persist
     * @param path destination file
     * @throws Exception if the model cannot be written
     */
    protected void saveClassifier(final Classifier cls, final String path)
            throws Exception {
        weka.core.SerializationHelper.write(path, cls);
    }

    /**
     * Deserializes a classifier from a file.
     *
     * <p>Failures are reported via the stack trace on stderr and signalled to
     * the caller by a {@code null} return value rather than an exception.
     *
     * @param classifierPath file holding the serialized model
     * @return the deserialized classifier, or {@code null} if it could not be
     *         read
     */
    protected Classifier loadClassifier(String classifierPath) {
        Classifier cls = null;
        try {
            cls = (Classifier) weka.core.SerializationHelper.read(classifierPath);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return cls;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment