Last active
April 28, 2023 14:53
-
-
Save mamrehn/125d13f6638465c67578 to your computer and use it in GitHub Desktop.
A simple example of persisting a classifier model in Weka (http://www.cs.waikato.ac.nz/ml/weka/): the model is trained, saved to disk, and loaded again. The J48 decision tree can be replaced with any other Weka classifier.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//package ...; | |
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;

import weka.classifiers.Classifier;
import weka.classifiers.trees.J48;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
public class WekaSaveClassifier { | |
protected final static String DATA_PATH = "data/unbalanced.arff"; | |
protected final static String CLASSIFIER_PATH = "save/cls.model"; | |
public static void main(String[] args) { | |
// Source: http://weka.wikispaces.com/Serialization | |
try { | |
new WekaSaveClassifier(DATA_PATH, CLASSIFIER_PATH); | |
} catch (Exception e) { | |
e.printStackTrace(); | |
} | |
} | |
public WekaSaveClassifier(String dataPath, String classifierPath) throws Exception { | |
System.out.println("[LOG]\tget classifier"); | |
Classifier cls = new J48(); | |
System.out.println("[LOG]\tload dataset"); | |
Instances instDataset = null; | |
instDataset = loadData(dataPath); | |
System.out.println("[LOG]\ttrain classifier with dataset"); | |
cls.buildClassifier(instDataset); | |
// new unknown instance | |
Instance newInst = getUserInput(instDataset); | |
System.out.println( | |
"[LOG]\tclassification before saving the cls: " + | |
cls.classifyInstance(newInst) | |
); | |
System.out.println("[LOG]\tserialize model:\t" + classifierPath); | |
saveClassifier(cls, classifierPath); | |
System.out.println("[LOG]\tdeserialize model:\t" + classifierPath); | |
Classifier c = loadClassifier(classifierPath); | |
if (null != c) | |
cls = c; | |
else | |
System.err.println("Could not load classifier from " + classifierPath); | |
System.out.println( | |
"[LOG]\tclassification after reloading the cls: " + | |
cls.classifyInstance(newInst) | |
); | |
} | |
/** | |
* This class collects user input. | |
* Currently random/dummy input is generated for demonstration | |
* @param dataset the dataset (i.e. from an .arff file) holding the new instances feature vector structure | |
* @return the new instance given by the user | |
*/ | |
private Instance getUserInput(final Instances dataset) { | |
//dataset.firstInstance().numAttributes() == 33 // -1 class | |
Instance in = new DenseInstance(33); | |
in.setDataset(dataset); | |
in.setClassMissing(); // not necessary | |
for(int i=0; i<32; ++i) | |
in.setValue(i, -2.5d + Math.random()*5.0d); | |
//System.out.println(in); | |
return in; | |
} | |
protected Instances loadData(final String path) | |
throws FileNotFoundException, IOException { | |
Instances inst = new Instances(new BufferedReader(new FileReader(path))); | |
inst.setClassIndex(inst.numAttributes() - 1); | |
return inst; | |
} | |
protected void saveClassifier(final Classifier cls, final String path) | |
throws Exception { | |
weka.core.SerializationHelper.write(path, cls); | |
} | |
protected Classifier loadClassifier(String classifierPath) { | |
Classifier cls = null; | |
try { | |
cls = (Classifier) weka.core.SerializationHelper | |
.read(classifierPath); | |
} catch (Exception e) { | |
e.printStackTrace(); | |
} | |
return cls; | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.