Skip to content

Instantly share code, notes, and snippets.

@PirosB3
Created November 18, 2013 16:23
Show Gist options
  • Save PirosB3/7530728 to your computer and use it in GitHub Desktop.
Save PirosB3/7530728 to your computer and use it in GitHub Desktop.
MNIST
import argparse
from collections import defaultdict
import cPickle, gzip
import multiprocessing
import sys
import Image
import numpy as np
from sklearn.datasets import fetch_mldata
# Shape each flat sample array is reshaped to in _worker before it is resized.
CURRENT_SIZE = (28, 28)
def _log(text):
    """Print *text* to standard output and flush the stream right away.

    Args:
        text: the message to print.
    """
    # The parenthesised single-argument form gives the same output under the
    # Python 2 print statement and the Python 3 print function.
    print(text)
    sys.stdout.flush()
def _worker(results, queue, dest_size, lock):
    """Consume ``(data, target)`` items from *queue* forever, resizing each one.

    For every item the flat ``data`` array is reshaped to ``CURRENT_SIZE``,
    converted to an image, resized to *dest_size*, converted back to an array
    and appended to the list stored under ``results[target]``.

    Args:
        results: shared mapping from target to a list of resized arrays. An
            entry for every incoming target must already exist.
        queue: joinable queue that yields ``(data, target)`` tuples.
        dest_size: size argument handed to ``Image.resize``.
        lock: lock serialising the read-modify-write of ``results``.

    This function never returns normally. If processing an item raises, the
    exception propagates, but the item is still marked as done and the lock
    is released, so neither ``queue.join()`` nor the other workers are left
    blocked by the failure.
    """
    _log("Worker started")
    while True:
        item = queue.get()
        try:
            data, target = item
            _log("Started reshaping target: %s" % target)
            image = Image.fromarray(data.reshape(*CURRENT_SIZE))
            resized = np.asarray(image.resize(dest_size))
            # The context manager guarantees the lock is released even when
            # the update below fails.
            with lock:
                # Fetch, mutate and store back: when ``results`` is a manager
                # dict proxy (as created in main), appending to the fetched
                # list alone would only change a local copy.
                bucket = results[target]
                bucket.append(resized)
                _log("Number for this class is %s" % len(bucket))
                results[target] = bucket
        finally:
            # Every get() must be matched by exactly one task_done(), or the
            # producer's queue.join() would wait forever.
            queue.task_done()
        _log("Finished reshaping target: %s" % target)
def main(resize, n_testing, n_for_class):
    """Resize up to *n_for_class* MNIST samples per class using worker processes.

    The dataset is fetched, ``cpu_count() * 2`` daemon worker processes are
    started, and samples are queued for them until each class has reached its
    quota. The call blocks until every queued sample has been processed.

    Args:
        resize: edge length of the square size every sample is resized to.
        n_testing: accepted so existing callers keep working; not used here.
        n_for_class: maximum number of samples queued for each class.

    Returns:
        A plain ``dict`` mapping each target to the list of resized arrays
        collected for it.
    """
    # NOTE(review): fetch_mldata is no longer provided by recent scikit-learn
    # releases; confirm the version this script is run against.
    mnist = fetch_mldata('MNIST original')

    # Create shared variables
    lock = multiprocessing.Lock()
    resize_shape = (resize, resize)
    manager = multiprocessing.Manager()
    results = manager.dict()
    queue = multiprocessing.JoinableQueue()

    # Create processes
    for _ in range(multiprocessing.cpu_count() * 2):
        p = multiprocessing.Process(
            target=_worker, args=(results, queue, resize_shape, lock))
        p.daemon = True
        p.start()

    # Initialize results dictionary so every target already has a list
    for target in np.unique(mnist.target):
        results[target] = []

    # Queue samples in dataset order, skipping classes that hit their quota.
    class_count = defaultdict(int)
    for index, target in enumerate(mnist.target):
        if class_count[target] < n_for_class:
            class_count[target] += 1
            queue.put((mnist.data[index], target))
    _log("Done adding to queue")

    # Wait until the workers have called task_done() for every queued item.
    queue.join()

    # NOTE(review): interactive breakpoint kept so `results` can still be
    # inspected by hand; it needs the third-party ipdb package and blocks
    # non-interactive runs. Remove once callers consume the return value.
    import ipdb
    ipdb.set_trace()

    # Copy out of the manager proxy so the data outlives the manager process.
    return dict(results)
if __name__ == '__main__':
    # Command-line entry point: all three integer options are required.
    parser = argparse.ArgumentParser(description='Process MNIST dataset to CORTEX input')
    parser.add_argument('--resize', type=int, required=True, help='resize from 28x28 to')
    # This value caps how many samples of each class main() queues; the
    # previous help text was a copy of the --n_testing one.
    parser.add_argument('--n_for_class', type=int, required=True, help='Maximum number of examples to process per class')
    parser.add_argument('--n_testing', type=int, required=True, help='Number of testing examples')
    args = parser.parse_args()
    main(args.resize, args.n_testing, args.n_for_class)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment