Skip to content

Instantly share code, notes, and snippets.

@rkwitt
Last active December 15, 2015 18:49
Show Gist options
  • Save rkwitt/5307149 to your computer and use it in GitHub Desktop.
Save rkwitt/5307149 to your computer and use it in GitHub Desktop.
Kitware Tech-Tip 04-03-2013 Part 3: Computing a BoW representation from a feature matrix and training / testing a SVM classifier in a N-fold cross-validation setup.
"""kw-techtip-04-03-2013-p3.py
Compute BoW representations and train / test an SVM classifier over repeated
random train/test splits (ShuffleSplit); the SVM parameters are tuned by
cross-validation on each training split.
Usage
-----
$ python kw-techtip-04-03-2013-p3.py <DataFile> <IndexFile> <ClassInfoFile>
"""
import sys
import numpy as np
# scikit-learn imports
from sklearn.cluster import KMeans
from sklearn.metrics import accuracy_score
from sklearn.cross_validation import KFold
from sklearn.cross_validation import ShuffleSplit
from sklearn.grid_search import GridSearchCV
from sklearn import preprocessing
from sklearn import svm
# Size of the visual-word codebook: the number of k-means clusters and the
# length of every BoW histogram.
N_WORDS = 32


def feature_rows(idx, ids):
    """Return the row numbers of all feature vectors belonging to *ids*.

    Parameters
    ----------
    idx : numpy.ndarray
        One entry per row of the feature matrix, giving the position (in the
        class-label file) of the sample that row belongs to.
    ids : iterable of int
        Sample positions to collect.

    Returns
    -------
    numpy.ndarray
        Integer row numbers grouped in the order of *ids* (rows of one sample
        stay in file order). Empty when nothing matches or *ids* is empty.
    """
    chunks = [np.where(idx == i)[0] for i in ids]
    if not chunks:
        # np.concatenate rejects an empty list.
        return np.zeros(0, dtype=int)
    return np.concatenate(chunks)


def bow_histogram(words, n_words=N_WORDS):
    """Return the normalized histogram of codeword assignments *words*.

    Parameters
    ----------
    words : array_like of int
        Codeword index (in ``[0, n_words)``) of each feature vector.
    n_words : int
        Codebook size, i.e. the number of histogram bins.

    Returns
    -------
    numpy.ndarray
        Length-``n_words`` vector of relative frequencies. Empty input gives
        an all-zero vector instead of NaNs.
    """
    words = np.asarray(words)
    if words.size == 0:
        return np.zeros(n_words)
    # Unit-width bins [0,1), ..., [n_words-1, n_words]: with density=True
    # the counts become relative frequencies that sum to 1.
    hist, _ = np.histogram(words, bins=np.arange(n_words + 1), density=True)
    return hist


def encode(codebook, dat, idx, ids, n_words=N_WORDS):
    """Return the BoW matrix of the samples *ids*, one histogram per row.

    Parameters
    ----------
    codebook : object with a ``predict`` method (a fitted KMeans here)
        Maps feature vectors to codeword indices in ``[0, n_words)``.
    dat : numpy.ndarray
        Feature matrix, one feature vector per row.
    idx : numpy.ndarray
        Sample position of every row of *dat*.
    ids : sequence of int
        Samples to encode.
    n_words : int
        Codebook size.

    Returns
    -------
    numpy.ndarray
        Shape ``(len(ids), n_words)``. Samples without any feature row keep
        an all-zero histogram.
    """
    X = np.zeros((len(ids), n_words))
    for row, i in enumerate(ids):
        rows = np.where(idx == i)[0]
        if rows.size:  # do not call predict() on an empty selection
            X[row, :] = bow_histogram(codebook.predict(dat[rows, :]), n_words)
    return X


def main(argv=None):
    """Evaluate a BoW + RBF-SVM pipeline over random train/test splits.

    Parameters
    ----------
    argv : list of str, optional
        ``[DataFile, IndexFile, ClassInfoFile]``; defaults to ``sys.argv[1:]``.

    For each split this prints a progress line, then the test accuracy in
    percent.
    """
    if argv is None:
        argv = sys.argv[1:]
    if len(argv) != 3:
        sys.exit("Usage: python kw-techtip-04-03-2013-p3.py "
                 "<DataFile> <IndexFile> <ClassInfoFile>")

    dat = np.genfromtxt(argv[0])  # feature matrix, one vector per row
    idx = np.genfromtxt(argv[1])  # sample position of every row of dat
    lab = np.genfromtxt(argv[2])  # class label of every sample

    # SVM parameter grid searched on every training split.
    parameters = [{'kernel': ['rbf'],
                   'gamma': np.logspace(-6, 2, 10),
                   'C': [1, 10, 100, 1000]}]

    # 5 random 80/20 train/test splits. NOTE: this is the pre-0.18
    # scikit-learn ShuffleSplit signature (n, n_iter=..., test_size=...).
    cv = ShuffleSplit(len(lab), n_iter=5, test_size=0.2)
    for n, (trn, tst) in enumerate(cv):
        # Learn the codebook (vector quantization of features) from the
        # training samples only, so no test data shapes the representation.
        print("[Fold: %d (%d/%d)] Learning the codebook ..."
              % (n, len(trn), len(tst)))
        cb = KMeans(N_WORDS, init="k-means++", n_init=10, max_iter=500)
        cb.fit(dat[feature_rows(idx, trn), :])

        # Cross-validate (gamma, C) on the training portion in a 5-fold
        # setup. `cv` must go to the constructor: GridSearchCV.fit does not
        # accept it (it is ignored or forwarded to SVC.fit, by version).
        X = encode(cb, dat, idx, trn)
        clf = GridSearchCV(svm.SVC(C=1), parameters, cv=5)
        clf.fit(X, lab[trn])

        # And now, test ...
        Y = encode(cb, dat, idx, tst)
        score = clf.score(Y, lab[tst])
        print(score * 100)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment