Skip to content

Instantly share code, notes, and snippets.

@pqcfox
Last active August 29, 2015 14:07
Show Gist options
Save pqcfox/61a581d7b638fb429655 to your computer and use it in GitHub Desktop.
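A program that sets up the Caltech101 corpus for pairwise (one-vs-one) classification with Glimpse and SVM-Light. It then trains one SVM per class pair, tallies their votes per test image, and writes the results, a confusion-matrix image, and the class list.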
import collections
import copy
import itertools
import os
import random
import re
import shutil
import string
import subprocess

from PIL import Image
dirs = [ f for f in os.listdir('.') if os.path.isdir(f) and '-' in f ]
# Split pairwise data into individual classes.
#
# Each pairwise directory <d> ("<classA>-<classB>") holds <d>.train,
# <d>.test, <d>.names.train and <d>.names.test.  For the n-th class of the
# pair (0-based), feature lines whose first character is str(n+1) are kept,
# with that character replaced by the SVM label (1 for the first class, -1
# otherwise).  Name lines are kept when they contain '/<class>/', with their
# first character replaced by '1'.  The pairwise directory is then deleted.
# NOTE(review): a class that appears in several pairwise directories has its
# files overwritten by whichever directory is processed last ('w' mode).
for d in dirs:
    pair_classes = d.split('-')
    # Read all four pairwise files up front, closing each one, so nothing is
    # still open when the directory is removed below.
    pair_lines = {}
    for suffix in ('train', 'test', 'names.train', 'names.test'):
        with open(os.path.join(d, '{0}.{1}'.format(d, suffix)), 'r') as handle:
            pair_lines[suffix] = handle.readlines()
    for n, c in enumerate(pair_classes):
        print("Generating class data for {0}...".format(c))
        if not os.path.exists(c):
            os.mkdir(c)
        svm = 1 if n == 0 else -1
        index_char = str(n + 1)
        class_marker = '/{0}/'.format(c)
        class_data = {}
        for suffix in ('train', 'test'):
            class_data[suffix] = [ str(svm) + l[1:] for l in pair_lines[suffix]
                                   if l.startswith(index_char) ]
        for suffix in ('names.train', 'names.test'):
            class_data[suffix] = [ '1' + l[1:] for l in pair_lines[suffix]
                                   if class_marker in l ]
        for suffix, lines in class_data.items():
            with open(os.path.join(c, '{0}.{1}'.format(c, suffix)), 'w') as handle:
                handle.write(''.join(lines))
    shutil.rmtree(d)
dirs = sorted(f for f in os.listdir('.') if os.path.isdir(f))
pairs = list(itertools.combinations(dirs, r=2))


def _relabel(line, label):
    """Return *line* with its leading label token replaced by *label*.

    The label is the first whitespace-separated token of the line; the rest
    of the line (including its trailing newline) is kept unchanged.  Blank
    lines are returned as-is.
    """
    if not line.strip():
        return line
    parts = line.split(None, 1)
    rest = parts[1] if len(parts) > 1 else '\n'
    return '{0} {1}'.format(label, rest)


# Make exhaustive pairwise data from individual classes.
# Train an svm for each pairwise data set.
for pair in pairs:
    pair_name = '-'.join(pair)
    pair_train = '{0}.train'.format(pair_name)
    # Relabel explicitly: the first class of the pair becomes 1 and the
    # second -1.  The per-class files keep whatever label the class had in
    # its original pairwise directory, so two classes could otherwise share
    # a label (or have it inverted).  Downstream, a positive svm_classify
    # score is read as the first class in the pair name.
    pair_train_lines = []
    for c, label in zip(pair, ('1', '-1')):
        with open(os.path.join(c, '{0}.train'.format(c)), 'r') as handle:
            pair_train_lines.extend(_relabel(l, label) for l in handle)
    with open(pair_train, 'w') as handle:
        handle.write(''.join(pair_train_lines))
    print("Training SVM using {0}...".format(pair_train))
    # Argument list instead of a shell string: file names are passed
    # verbatim, with no shell quoting or interpretation.
    status = subprocess.call(['../svm_light/svm_learn', '-v', '0',
                              pair_train, '{0}.svm'.format(pair_name)])
    if status != 0:
        print("Warning: svm_learn exited with status {0} for {1}".format(status, pair_train))
# Pairwise models written by svm_learn ("<a>-<b>.svm") and the class
# directories, sorted so runs are reproducible.
svms = sorted(f for f in os.listdir('.') if f.endswith('.svm'))
classes = sorted([ f for f in os.listdir('.') if os.path.isdir(f) ])
# Use each svm to classify each class of images.
for c in classes:
    # The test file depends only on the class, not on the model.
    test = os.path.join(c, '{0}.test'.format(c))
    for svm in svms:
        pairname = svm.split('.')[0]
        pred = '{0}-{1}.predictions'.format(c, pairname)
        print("Testing {0} using {1}...".format(c, svm))
        # Argument list instead of a shell string: file names are passed
        # verbatim, with no shell quoting or interpretation.
        status = subprocess.call(['../svm_light/svm_classify', '-v', '0', test, svm, pred])
        if status != 0:
            print("Warning: svm_classify exited with status {0} for {1}".format(status, pred))
preds = [ f for f in os.listdir('.') if f.endswith('.predictions') ]
votes = []
# Prediction files are named "<class>-<a>-<b>.predictions"; the two groups
# are the classes of the SVM pair.  Compiled once, matched once per file.
pair_pattern = re.compile(r'.*-(.*)-(.*)\.predictions')
# Accumulate predictions into a list of [image, actual class, predicted class].
for c in classes:
    class_preds = [ p for p in preds if p.split('-')[0] == c ]
    class_images = os.path.join(c, '{0}.names.test'.format(c))
    with open(class_images, 'r') as handle:
        # Names lines are "<label> <image path>"; keep the path part.
        class_images_lines = [ i.split(' ', 1)[1] for i in handle.read().splitlines() ]
    print("Accumulating results for {0}...".format(c))
    for pred in class_preds:
        match = pair_pattern.match(pred)
        if match is None:
            print("Warning: cannot parse class pair from {0}; skipping.".format(pred))
            continue
        pred_classes = match.groups()
        with open(pred, 'r') as handle:
            class_pred_lines = handle.read().splitlines()
        # Predictions are paired with test images by line position.
        if len(class_pred_lines) != len(class_images_lines):
            print("Warning: {0} has {1} predictions for {2} test images.".format(
                pred, len(class_pred_lines), len(class_images_lines)))
        for image, line in zip(class_images_lines, class_pred_lines):
            # A positive score votes for the first class of the pair.
            pred_class = pred_classes[0] if float(line) > 0 else pred_classes[1]
            votes.append([image, c, pred_class])
# Group votes by image in a single pass (instead of rescanning every vote
# for every image), remembering the actual class from each image's first vote.
image_votes = collections.defaultdict(list)
image_actual = {}
for image, actual, predicted in votes:
    image_votes[image].append(predicted)
    image_actual.setdefault(image, actual)
all_images = sorted(image_votes)
results = []
# Count up final votes from each svm output.
print("Tallying final votes...")
for image in all_images:
    counts = collections.Counter(image_votes[image])
    top = max(counts.values())
    # Ties are broken uniformly at random among the most-voted classes; the
    # candidates are sorted so a seeded run is reproducible.
    winners = sorted(k for k, v in counts.items() if v == top)
    results.append([image, image_actual[image], random.choice(winners)])
# Format the output for results.txt: a header row plus one row per image,
# each cell left-justified to its column width, cells joined by two tabs.
print("Formatting output...")
titles = ['IMAGE', 'ACTUAL CATEGORY', 'PREDICTED CATEGORY']
rows = [titles] + results
# Widths include the header row, so a title longer than every value in its
# column (e.g. 'PREDICTED CATEGORY') still lines up; this also works when
# there are no results.
maxlens = [ max(len(row[n]) for row in rows) for n in range(len(titles)) ]
format_results = [ [ cell.ljust(width) for cell, width in zip(row, maxlens) ]
                   for row in rows ]
with open('results.txt', 'w') as f:
    for line in format_results:
        f.write('\t\t'.join(line) + '\n')
# Calculate overall accuracy: the fraction of images whose predicted class
# equals their actual class.  NaN when nothing was classified, rather than
# a ZeroDivisionError.
print("Calculating accuracy...")
correct = sum(1 for image, actual, predicted in results if actual == predicted)
accuracy = correct / float(len(results)) if results else float('nan')
# Generate a confusion matrix in matrix.png: row = actual class, column =
# predicted class (both in `classes` order).  Each cell's grey level is its
# count scaled so the largest count is 255.
print("Generating confusion matrix...")
pair_counts = collections.Counter((k[1], k[2]) for k in results)
output_matrix = [ [ pair_counts[(actual, predicted)] for predicted in classes ]
                  for actual in classes ]
highest = max([ i for row in output_matrix for i in row ] or [0])
# An all-zero matrix stays black instead of dividing by zero.
scale = 255.0 / highest if highest else 0.0
# 'L' images hold 8-bit integers, so pass ints rather than floats.
pixels = [ int(round(i * scale)) for row in output_matrix for i in row ]
confusion = Image.new('L', (len(classes), len(classes)))
confusion.putdata(pixels)
# NEAREST keeps each cell a crisp 20x20 square; newer Pillow versions
# otherwise default to a smoothing (bicubic) filter for 'L' images.
confusion = confusion.resize((20*len(classes), 20*len(classes)), Image.NEAREST)
confusion.save("matrix.png")
# Save class names into classes.txt, one per line, in the same order as the
# confusion-matrix rows and columns; then report the outputs and accuracy.
print("Saving class names...")
with open('classes.txt', 'w') as names_file:
    names_file.writelines(name + "\n" for name in classes)
print("Done. Results in ./results.txt. Confusion matrix in ./matrix.png. Classes in ./classes.txt")
print("Accuracy is {0:.3f}%.".format(accuracy*100))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment