Last active
August 29, 2015 14:07
-
-
Save pqcfox/61a581d7b638fb429655 to your computer and use it in GitHub Desktop.
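A program that runs pairwise (one-vs-one) SVM-Light classification of the Caltech101 corpus on Glimpse features, then tallies the votes into per-image predictions, an overall accuracy figure, and a confusion matrix.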
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import collections
import copy
import itertools
import os
import random
import re
import shutil
import string
import subprocess

from PIL import Image
def _read_lines(path):
    """Return every line of *path* (newlines kept) and close the file."""
    with open(path, 'r') as handle:
        return handle.readlines()


def _write_lines(path, lines):
    """Write *lines* to *path* unchanged and close (flush) the file."""
    with open(path, 'w') as handle:
        handle.write(''.join(lines))


# Pairwise data sets live in directories named "<classA>-<classB>".
dirs = [f for f in os.listdir('.') if os.path.isdir(f) and '-' in f]
# Split pairwise data into individual classes.
for d in dirs:
    classes = d.split('-')
    train_lines = _read_lines(os.path.join(d, '{0}.train'.format(d)))
    test_lines = _read_lines(os.path.join(d, '{0}.test'.format(d)))
    train_names_lines = _read_lines(os.path.join(d, '{0}.names.train'.format(d)))
    test_names_lines = _read_lines(os.path.join(d, '{0}.names.test'.format(d)))
    for n, c in enumerate(classes):
        print("Generating class data for {0}...".format(c))
        if not os.path.exists(c):
            os.mkdir(c)
        # Rows of class n are selected by the single leading character
        # str(n + 1); that character is rewritten as +1 (first class of the
        # pair) or -1 (second class).
        label = str(n + 1)
        svm = '1' if n == 0 else '-1'
        _write_lines(os.path.join(c, '{0}.train'.format(c)),
                     [svm + l[1:] for l in train_lines if l[0] == label])
        _write_lines(os.path.join(c, '{0}.test'.format(c)),
                     [svm + l[1:] for l in test_lines if l[0] == label])
        # Image-name rows are selected by the "/<class>/" path component and
        # get their leading character replaced with '1'.
        marker = '/{0}/'.format(c)
        _write_lines(os.path.join(c, '{0}.names.train'.format(c)),
                     ['1' + l[1:] for l in train_names_lines if marker in l])
        _write_lines(os.path.join(c, '{0}.names.test'.format(c)),
                     ['1' + l[1:] for l in test_names_lines if marker in l])
    # The pair directory is deleted once its classes have been extracted.
    shutil.rmtree(d)
def _relabel(lines, label):
    """Return *lines* with each SVM-light target replaced by *label*.

    Only the first whitespace-separated field of a line is replaced; the
    features and any trailing comment are kept. Every returned line ends in
    exactly one newline, so results from several files can be concatenated
    safely. Blank lines are dropped.
    """
    relabeled = []
    for line in lines:
        fields = line.split(None, 1)
        if fields:
            relabeled.append(' '.join([label] + fields[1:]).rstrip() + '\n')
    return relabeled


# Every directory left after the split is a single class; sort so pair names
# (and hence the .train/.svm file names) are reproducible between runs.
dirs = sorted([f for f in os.listdir('.') if os.path.isdir(f)])
pairs = list(itertools.combinations(dirs, r=2))
# Make exhaustive pairwise data from individual classes.
# Train an svm for each pairwise data set.
for pair in pairs:
    pair_name = '-'.join(pair)
    pair_train = '{0}.train'.format(pair_name)
    pair_train_lines = []
    # When predictions are tallied, a positive score from
    # "<first>-<second>.svm" counts as a vote for <first>, so the first
    # class must be trained as +1 and the second as -1. The labels stored
    # in the per-class files cannot be relied on for this: they depend on
    # the class's position in whichever pair directory it was last
    # extracted from.
    for label, c in zip(('1', '-1'), pair):
        with open(os.path.join(c, '{0}.train'.format(c)), 'r') as handle:
            pair_train_lines.extend(_relabel(handle, label))
    with open(pair_train, 'w') as handle:
        handle.write(''.join(pair_train_lines))
    print("Training SVM using {0}...".format(pair_train))
    # Argument list instead of a shell string: names need no quoting.
    status = subprocess.call(['../svm_light/svm_learn', '-v', '0',
                              pair_train, '{0}.svm'.format(pair_name)])
    if status != 0:
        print("Warning: svm_learn exited with status {0} for {1}".format(
            status, pair_train))
# Every trained model in the working directory, in a stable order.
svms = sorted(f for f in os.listdir('.') if f.endswith('.svm'))
classes = sorted([f for f in os.listdir('.') if os.path.isdir(f)])
# Use each svm to classify each class of images.
# Output files are named "<class>-<pair>.predictions".
for c in classes:
    test = os.path.join(c, '{0}.test'.format(c))
    for svm in svms:
        pairname = os.path.splitext(svm)[0]
        pred = '{0}-{1}.predictions'.format(c, pairname)
        print("Testing {0} using {1}...".format(c, svm))
        # Argument list instead of a shell string: names need no quoting.
        status = subprocess.call(['../svm_light/svm_classify', '-v', '0',
                                  test, svm, pred])
        if status != 0:
            print("Warning: svm_classify exited with status {0} for {1}".format(
                status, pred))
# Each predictions file is named "<class>-<first>-<second>.predictions".
_PRED_NAME = re.compile(r'.*-(.*)-(.*)\.predictions')
preds = sorted(f for f in os.listdir('.') if f.endswith('.predictions'))
votes = []
# Accumulate predictions into a list of [image, actual class, voted class].
for c in classes:
    class_preds = [p for p in preds if p.split('-')[0] == c]
    class_images = os.path.join(c, '{0}.names.test'.format(c))
    with open(class_images, 'r') as names_file:
        # The second space-separated field of each row is the image path.
        class_images_lines = [i.split(' ')[1]
                              for i in names_file.read().splitlines()]
    print("Accumulating results for {0}...".format(c))
    for pred in class_preds:
        # A positive score votes for the first class of the pair; anything
        # else votes for the second.
        first, second = _PRED_NAME.findall(pred)[0]
        with open(pred, 'r') as pred_file:
            values = [float(v) for v in pred_file.read().splitlines()]
        # Predictions are matched to images by position, so the counts must
        # agree; otherwise votes would be silently misattributed or dropped.
        if len(values) != len(class_images_lines):
            raise ValueError(
                "{0} has {1} predictions but {2} lists {3} images".format(
                    pred, len(values), class_images, len(class_images_lines)))
        for image, value in zip(class_images_lines, values):
            votes.append([image, c, first if value > 0 else second])
# Count up final votes from each svm output.
print("Tallying final votes...")
# Group every vote by image in one pass instead of rescanning all votes for
# each image. The actual class is taken from the first vote seen.
image_actual = {}
image_ballots = collections.defaultdict(list)
for image, actual, predicted in votes:
    image_actual.setdefault(image, actual)
    image_ballots[image].append(predicted)
all_images = sorted(image_ballots)
results = []
for image in all_images:
    tally = collections.Counter(image_ballots[image])
    top = max(tally.values())
    # Ties are broken at random; sorting the candidates makes the choice
    # reproducible under a seeded `random`.
    winners = sorted(k for k, count in tally.items() if count == top)
    results.append([image, image_actual[image], random.choice(winners)])
# Format the output for results.txt as left-justified columns.
print("Formatting output...")
titles = ['IMAGE', 'ACTUAL CATEGORY', 'PREDICTED CATEGORY']
format_results = [titles] + [list(row) for row in results]
# Size every column from the header as well as the data rows; sizing from the
# data alone misaligns any column whose title is longer than all its values.
# This also works when there are no result rows at all.
maxlens = [max(len(row[n]) for row in format_results)
           for n in range(len(titles))]
format_results = [[cell.ljust(width) for cell, width in zip(row, maxlens)]
                  for row in format_results]
with open('results.txt', 'w') as f:
    for line in format_results:
        f.write('\t\t'.join(line) + '\n')
# Calculate overall accuracy: the fraction of images whose predicted category
# equals their actual category. NaN when there is nothing to score, rather
# than a ZeroDivisionError.
print("Calculating accuracy...")
correct = len([k for k in results if k[1] == k[2]])
accuracy = correct / float(len(results)) if results else float('nan')
# Generate a confusion matrix in matrix.png: row = actual class, column =
# predicted class, brightness proportional to the number of images.
print("Generating confusion matrix...")
pair_counts = collections.Counter((k[1], k[2]) for k in results)
output_matrix = [[pair_counts[(actual, predicted)] for predicted in classes]
                 for actual in classes]
highest = max([i for row in output_matrix for i in row] or [0])
# Map counts onto integer grey levels 0..255 for the 'L' image; an all-zero
# matrix stays black instead of dividing by zero.
scale = 255.0 / highest if highest else 0.0
output_matrix = [[int(round(i * scale)) for i in row] for row in output_matrix]
side = len(classes)
confusion = Image.new('L', (side, side))
confusion.putdata([i for row in output_matrix for i in row])
# Nearest-neighbour upscaling keeps each cell a sharp 20x20 square.
confusion = confusion.resize((20 * side, 20 * side), Image.NEAREST)
confusion.save("matrix.png")
# Save class names into classes.txt (one per line), then print a summary of
# where every output was written and the overall accuracy as a percentage.
print("Saving class names...")
with open('classes.txt', 'w') as f:
    f.writelines("{0}\n".format(name) for name in classes)
print("Done. Results in ./results.txt. Confusion matrix in ./matrix.png. Classes in ./classes.txt")
percent = accuracy * 100
print("Accuracy is {0:.3f}%.".format(percent))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment