Last active
March 31, 2016 05:32
-
-
Save salamanders/8e7054f62b53eb772895 to your computer and use it in GitHub Desktop.
Ultra-hacky attempt to give EdwardRaff/JSAT a workout by blindly trying every classifier
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.io.File; | |
import java.io.IOException; | |
import java.lang.reflect.InvocationTargetException; | |
import java.lang.reflect.Modifier; | |
import java.util.List; | |
import java.util.Set; | |
import java.util.concurrent.Callable; | |
import java.util.concurrent.TimeUnit; | |
import java.util.logging.Logger; | |
import java.util.stream.Collectors; | |
import com.google.common.reflect.ClassPath; | |
import com.google.common.reflect.ClassPath.ClassInfo; | |
import com.google.common.util.concurrent.SimpleTimeLimiter; | |
import com.google.common.util.concurrent.TimeLimiter; | |
import com.google.common.util.concurrent.UncheckedTimeoutException; | |
import jsat.classifiers.ClassificationDataSet; | |
import jsat.classifiers.ClassificationModelEvaluation; | |
import jsat.classifiers.Classifier; | |
import jsat.classifiers.bayesian.NaiveBayes; | |
import jsat.classifiers.boosting.ModestAdaBoost; | |
import jsat.classifiers.svm.SupportVectorLearner; | |
import jsat.classifiers.svm.SupportVectorLearner.CacheMode; | |
import jsat.classifiers.trees.DecisionStump; | |
import jsat.distributions.kernels.KernelTrick; | |
import jsat.distributions.kernels.RBFKernel; | |
import jsat.exceptions.FailedToFitException; | |
import jsat.io.LIBSVMLoader; | |
import jsat.parameters.RandomSearch; | |
/** | |
* @author Benjamin Hill | |
*/ | |
public class TryAllClassifiers { | |
private static final String DATA_FILE = "diabetes.libsvm"; // diabetes, mushrooms; | |
private static final Logger LOG = Logger.getLogger(TryAllClassifiers.class.getName()); | |
private static final int GUESSED_SMALL_PARAM = 50; | |
private static final TimeLimiter TIME_LIMITER = new SimpleTimeLimiter(); | |
private static final Class<? extends Classifier> WEAK_CLASSIFIER = DecisionStump.class; | |
/** | |
* Try to instantiate via hacky methods. | |
* | |
* @param classifierClass | |
* @return instantiated class | |
*/ | |
private static Classifier buildClassifierInstance(Class<Classifier> classifierClass) { | |
// No arg | |
try { | |
return classifierClass.newInstance(); | |
} catch (final InstantiationException | IllegalAccessException e) { | |
// no luck | |
} | |
// Single int | |
try { | |
return classifierClass.getConstructor(Integer.TYPE).newInstance(GUESSED_SMALL_PARAM); | |
} catch (final NoSuchMethodException e) { | |
// ignore | |
} catch (final InvocationTargetException | InstantiationException | IllegalAccessException e) { | |
System.err.println( | |
"Bad constructing:" + classifierClass.getCanonicalName() + " " + e.getClass() + " " + e.getMessage()); | |
} | |
// Single Classifier - use NB | |
try { | |
return classifierClass.getConstructor(Classifier.class).newInstance(WEAK_CLASSIFIER.newInstance()); | |
} catch (final NoSuchMethodException e) { | |
// ignore | |
} catch (final InvocationTargetException | InstantiationException | IllegalAccessException e) { | |
System.err.println("Bad constructing:" + classifierClass.getCanonicalName() + "\t" + e.getMessage()); | |
} | |
// Classifier + Iterations | |
try { | |
return classifierClass.getConstructor(Classifier.class, Integer.TYPE).newInstance(WEAK_CLASSIFIER.newInstance(), | |
GUESSED_SMALL_PARAM); | |
} catch (final NoSuchMethodException e) { | |
// ignore | |
} catch (final InvocationTargetException | InstantiationException | IllegalAccessException e) { | |
System.err.println("Bad constructing:" + classifierClass.getCanonicalName() + "\t" + e.getMessage()); | |
} | |
// KernelTrick | |
try { | |
return classifierClass.getConstructor(KernelTrick.class).newInstance(new RBFKernel(0.5)); | |
} catch (final NoSuchMethodException e) { | |
// ignore | |
} catch (final InvocationTargetException | InstantiationException | IllegalAccessException e) { | |
System.err.println("Bad constructing:" + classifierClass.getCanonicalName() + "\t" + e.getMessage()); | |
} | |
// KernelTrick + param | |
try { | |
return classifierClass.getConstructor(KernelTrick.class, Integer.TYPE).newInstance(new RBFKernel(0.5), | |
GUESSED_SMALL_PARAM); | |
} catch (final NoSuchMethodException e) { | |
// ignore | |
} catch (final InvocationTargetException | InstantiationException | IllegalAccessException e) { | |
System.err.println("Bad constructing:" + classifierClass.getCanonicalName() + "\t" + e.getMessage()); | |
} | |
System.err.println("Unable to find a way to construct:" + classifierClass.getCanonicalName()); | |
return null; | |
} | |
/** | |
* Attempts to instantiate as many models as possible | |
* | |
* @return | |
* @throws IOException | |
*/ | |
@SuppressWarnings("unchecked") | |
private static Set<Classifier> getModels() throws IOException { | |
return ClassPath.from(Thread.currentThread().getContextClassLoader()) | |
.getTopLevelClassesRecursive("jsat.classifiers").stream().map(ClassInfo::load) // ClassInfo to Class | |
.filter(Classifier.class::isAssignableFrom) | |
.filter(possibleModelClass -> !Modifier.isAbstract(possibleModelClass.getModifiers())) | |
.map(possibleModelClass -> (Class<Classifier>) possibleModelClass) // cast only | |
.map(TryAllClassifiers::buildClassifierInstance) // the hard work | |
.filter(inst -> inst != null) // successes only | |
.collect(Collectors.toSet()); | |
} | |
/** | |
* A few minutes, nothing too long. | |
* | |
* @param model | |
* @return | |
*/ | |
private static String trainModel(final Classifier model) { | |
return trainAndTest(model, 10, TimeUnit.MINUTES); | |
} | |
/** | |
* Trains the model, returns bare stats. Even tries autoAddParameters to improve! | |
* | |
* @param model | |
* @return | |
* @throws IOException | |
*/ | |
private static String trainAndTest(final Classifier model, final int time, final TimeUnit unit) { | |
Callable<String> myCallable = () -> { | |
final long startTime = System.currentTimeMillis(); | |
if (model instanceof SupportVectorLearner) { | |
((SupportVectorLearner) model).setCacheMode(CacheMode.FULL);// Small dataset, so we can do this | |
} | |
final ClassificationDataSet dataset = LIBSVMLoader.loadC(new File(DATA_FILE)); | |
ClassificationModelEvaluation cme = new ClassificationModelEvaluation(model, dataset); | |
cme.evaluateCrossValidation(10); | |
final double originalErrorRate = cme.getErrorRate(); | |
// TODO: tunedErrorRate | |
/* | |
* double tunedErrorRate = 0; try { final List<ClassificationDataSet> splits = dataset.randomSplit(0.75, 0.25); | |
* final ClassificationDataSet train = splits.get(0); final ClassificationDataSet test = splits.get(1); final | |
* RandomSearch search = new RandomSearch(model, 3); // this method adds parameters, and returns the number of | |
* parameters added if (search.autoAddParameters(train) > 0) { // that way we only do the search if there are any | |
* parameters to actually tune search.trainC(dataset); Classifier tunedModel = search.getTrainedClassifier(); cme | |
* = new ClassificationModelEvaluation(tunedModel, train); cme.evaluateTestSet(test); tunedErrorRate = | |
* cme.getErrorRate(); } } catch (final FailedToFitException ex) { // ignore, it doesn't like tuning. } | |
*/ | |
final long elapsedTime = System.currentTimeMillis() - startTime; | |
return String.format("%s\t%d\t%.3f", model.getClass().getName(), elapsedTime, originalErrorRate); | |
}; | |
try { | |
return TIME_LIMITER.callWithTimeout(myCallable, time, unit, true); | |
} catch (final InterruptedException | UncheckedTimeoutException e) { | |
return String.format("%s\tTIMEOUT", model.getClass().getName()); | |
} catch (final Throwable ex) { | |
return String.format("%s\tERROR\t%s", model.getClass().getName(), ex.getMessage()); | |
} | |
} | |
/** | |
* @param args | |
* @throws Exception | |
*/ | |
public static void main(String[] args) throws Exception { | |
System.out.println(String.format("%s\t%s\t%s\t%s", "Model", "time", "errorRate", "tunedErrorRate")); | |
getModels().stream().parallel() // may overlap with the internal parallel training | |
.map(TryAllClassifiers::trainModel).forEach(System.out::println); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment