Created
April 1, 2010 12:25
-
-
Save satra/351740 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
import mvpa.suite as ms | |
# create xor dataset: 100 random 2-D points shifted to lie in [-0.5, 0.5);
# a point's target is True when both of its coordinates have the same sign
samples = np.random.rand(100, 2) - 0.5
coord_signs = np.sign(samples)
targets = coord_signs[:, 0] == coord_signs[:, 1]
# display samples: red squares where target == 1, blue squares where target == 0
for label, marker in ((1, 'rs'), (0, 'bs')):
    rows = np.nonzero(targets == label)[0]
    plt.plot(samples[rows, 0], samples[rows, 1], marker)
# create dataset; chunks is range(n_samples), i.e. one chunk id per sample
n_samples = samples.shape[0]
ds = ms.dataset_wizard(samples, targets=targets, chunks=range(n_samples))
# choose classifier
clf = ms.kNN()  # SMLR, LinearCSVMC, etc.,.
terr = ms.TransferError(clf)
# cross-validation with splitter (half split), keeping the confusion matrix
half_splitter = ms.HalfSplitter()
cvte = ms.CrossValidatedTransferError(terr,
                                      splitter=half_splitter,
                                      enable_ca=['confusion'])
res = cvte(ds)
np.mean(res)  # mean cv error (value is discarded when run as a script)
# report the confusion statistics as text, then as a figure
confusion = cvte.ca.confusion
print(confusion)
plt.figure()
confusion.plot()
plt.show()
# different splitter - leave one out; reuses the transfer error `terr` and the
# dataset `ds` built earlier in the script
loo_splitter = ms.NFoldSplitter()
cvte = ms.CrossValidatedTransferError(terr,
                                      splitter=loo_splitter,
                                      enable_ca=['confusion'])
res = cvte(ds)
np.mean(res)  # mean cv error (value is discarded when run as a script)
# report the confusion statistics as text, then as a figure
confusion = cvte.ca.confusion
print(confusion)
plt.figure()
confusion.plot()
plt.show()
def _balanced_half_cv(transfer_error):
    """Wrap *transfer_error* in the half-split cross-validation used below.

    A fresh splitter is created on every call, exactly as the inline code did.
    NOTE(review): npertarget='equal' / nrunspersplit=10 presumably balance the
    number of samples per target and repeat each split 10 times -- confirm
    against the PyMVPA splitter documentation.
    """
    return ms.CrossValidatedTransferError(
        transfer_error,
        splitter=ms.HalfSplitter(npertarget='equal', nrunspersplit=10),
        enable_ca=['confusion'])


# cycle through all binary svm classifiers, remembering the one with the
# lowest confusion error
# Both trackers must exist before the first comparison; without them the
# comparison raised NameError, which the handler below silently swallowed.
error = np.inf
bestclf = None
for clf in ms.clfswh['binary', 'svm']:
    terr = ms.TransferError(clf)
    cvte = _balanced_half_cv(terr)
    try:
        res = cvte(ds)
        np.mean(res)  # mean cv error
        confusion = cvte.ca.confusion
        # To inspect a classifier visually, plot `confusion` here and use
        # str(clf) as the figure title.
        print("%s %s" % (str(clf), confusion.error))
        print(confusion.matrix)
        if confusion.error < error:
            # clone first so error/bestclf never get out of sync if it fails
            candidate = clf.clone()
            error = confusion.error
            bestclf = candidate
    except Exception:
        # best-effort sweep: skip classifiers that fail, but (unlike a bare
        # except) let KeyboardInterrupt / SystemExit propagate
        print("could not run classifier: %s" % str(clf))

if bestclf is None:
    raise RuntimeError("none of the binary svm classifiers could be run")

# re-run the cross-validation with the best classifier and show its confusion
terr = ms.TransferError(bestclf)
cvte = _balanced_half_cv(terr)
res = cvte(ds)
np.mean(res)  # mean cv error
print(cvte.ca.confusion)
plt.figure()
cvte.ca.confusion.plot(numbers=True)
plt.suptitle(str(cvte.transerror.clf))
# display the figure, consistent with the earlier sections of the script
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment