Last active
December 15, 2015 18:49
-
-
Save rkwitt/5307149 to your computer and use it in GitHub Desktop.
Kitware Tech-Tip 04-03-2013 Part 3: Computing a BoW representation from a feature matrix and training / testing an SVM classifier in an N-fold cross-validation setup.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""kw-techtip-04-03-2013-p3.py | |
Compute BoW representations and train / test an SVM classifier in an N-fold
cross-validation setup.

Usage
-----
$ python kw-techtip-04-03-2013-p3.py <DataFile> <IndexFile> <ClassInfoFile>
""" | |
import sys | |
import numpy as np | |
# scikit-learn imports | |
from sklearn.cluster import KMeans | |
from sklearn.metrics import accuracy_score | |
from sklearn.cross_validation import KFold | |
from sklearn.cross_validation import ShuffleSplit | |
from sklearn.grid_search import GridSearchCV | |
from sklearn import preprocessing | |
from sklearn import svm | |
# Number of codewords (k-means clusters); also the length of every BoW vector.
N_WORDS = 32


def _bow_histograms(codebook, features, owner, samples, n_words):
    """Return one normalized codeword histogram per entry of *samples*.

    Parameters
    ----------
    codebook : fitted quantizer exposing ``predict()``; it must map each
        feature row to an integer codeword in ``[0, n_words)``.
    features : 2-D array, one feature vector per row.
    owner : 1-D array, aligned with the rows of *features*, holding the
        sample number each feature row belongs to.
    samples : sequence of sample numbers to build histograms for.
    n_words : number of codewords, i.e. the histogram length.

    Returns
    -------
    Array of shape ``(len(samples), n_words)``; row ``k`` is the histogram
    of ``samples[k]``.
    """
    hists = np.zeros((len(samples), n_words))
    # Unit-width bins [0,1), [1,2), ..., so with density=True every
    # histogram holds relative codeword frequencies that sum to one.
    edges = np.arange(n_words + 1)
    for row, sample in enumerate(samples):
        rows = np.where(owner == sample)[0]
        if rows.size == 0:
            # A sample without any feature rows keeps an all-zero histogram
            # rather than crashing predict() on an empty array.
            continue
        words = codebook.predict(features[rows, :])
        hists[row, :] = np.histogram(words, bins=edges, density=True)[0]
    return hists


dat = np.genfromtxt(sys.argv[1])  # Data: one feature vector per row
idx = np.genfromtxt(sys.argv[2])  # Index: sample number of each row of dat
lab = np.genfromtxt(sys.argv[3])  # Class labels: one per sample

# Create the evaluation splits: 5 random splits, each holding out 20% of
# the samples for testing.
cv = ShuffleSplit(len(lab), n_iter=5, test_size=0.2)
for n, (trn, tst) in enumerate(cv):
    # Rows of dat that belong to the training samples of this split.
    pos = np.concatenate([np.where(idx == i)[0] for i in trn])

    # Learn a codebook (i.e., vector-quantization of features) from the
    # training features only.
    print("[Fold: %d (%d/%d)] Learning the codebook ..." % (n, len(trn), len(tst)))
    cb = KMeans(N_WORDS, init="k-means++", n_init=10, max_iter=500)
    cb.fit(dat[pos, :])

    # BoW representation of the training samples.
    X = _bow_histograms(cb, dat, idx, trn, N_WORDS)

    # Cross-validate SVM classifier parameters (gamma, C) on the
    # training portion of the data in a 5-fold cross-validation
    # setup. The number of folds is a constructor argument of
    # GridSearchCV, not an argument of fit().
    parameters = [{'kernel': ['rbf'],
                   'gamma': np.logspace(-6, 2, 10),
                   'C': [1, 10, 100, 1000]}]
    clf = GridSearchCV(svm.SVC(C=1), parameters, cv=5)
    clf.fit(X, lab[trn])

    # And now, test on the held-out samples, quantized with the codebook
    # learned from the training portion.
    Y = _bow_histograms(cb, dat, idx, tst, N_WORDS)
    score = clf.score(Y, lab[tst])
    print(score * 100)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment