'''
Parses the crop file to create forward-pass embeddings.

Reads product metadata from a Sparkey store, transforms each product image
in a multiprocessing pool, runs batched forward passes through the Lasagne
model, and appends one JSON record (pid, gender, embedding) per image to
the embeddings file.
'''
import os
import sys
import time
import multiprocessing as mp

import simplejson as json
import numpy as np
from sparkey import HashReader

model_root = '/data/validation/scripts'
sys.path.insert(0, model_root)

from lasagne_model import MultiModLasagneModel
from images import my_transformer as im_transform
from images import get_image_from_path
from config import product_images_root, product_metadata_path, ProductForwardsConfig

IMAGES_ROOT = product_images_root
PRODUCT_PATH = product_metadata_path
EMBEDDINGS_PATH = ProductForwardsConfig.path

# Open a Sparkey index/log pair given the common path prefix.
hr = lambda x: HashReader('%s.spi' % x, '%s.spl' % x)
pretty_float = lambda f: 0 if f == 0 else f
def transform_worker(idx, product_data, im_path, im_shape):
    '''Load one product image and transform it to the model's input shape.'''
    im = get_image_from_path(im_path)
    if im is None:
        return None
    gender = product_data['gender']
    pid = product_data['id']
    new_im = im_transform(im, im_shape)
    sys.stdout.write('\rTransforming image from product #%d...' % idx)
    sys.stdout.flush()
    return (pid, gender, new_im)
def batch_generator(l, n):
    '''Yield successive n-sized chunks from list l.'''
    for i in xrange(0, len(l), n):
        yield l[i:i + n]
def main():
    model_prefix = ProductForwardsConfig.model_prefix
    model_iteration = ProductForwardsConfig.model_iteration
    batch_size = ProductForwardsConfig.batch_size
    buffer_n = ProductForwardsConfig.buffer_n
    layer = ProductForwardsConfig.layer

    product_hr = hr(PRODUCT_PATH)
    pool = mp.Pool(mp.cpu_count() + 2)
    product_item_embeddings = open(EMBEDDINGS_PATH, 'ab')
    model = MultiModLasagneModel(model_prefix=model_prefix,
                                 model_iteration=model_iteration,
                                 layers=[layer])
    w, h = model.im_shape

    jobs = []
    num_products = len(product_hr)
    t0 = time.time()
    for idx, (pid, jsondata) in enumerate(product_hr.iteritems()):
        data = json.loads(jsondata)
        if not data.get('gender', None):
            continue
        image_ids = data['image_ids']
        # Queue an async transform job for every image of this product.
        for p in image_ids:
            image_path = os.path.join(IMAGES_ROOT, p[0], p[1], p) + '.jpeg'
            job = pool.apply_async(transform_worker, (idx, data, image_path, (w, h)))
            jobs.append(job)
        # Once enough jobs are buffered (or we hit the last product),
        # drain the pool and run the forward passes in batches.
        if len(jobs) > buffer_n or idx == num_products - 1:
            pool.close()
            pool.join()
            print "\nProcessing batches..."
            for batch_job in batch_generator(jobs, batch_size):
                batch_in = np.zeros((batch_size, 3, w, h), dtype=np.float32)
                data_array = [None] * batch_size
                for b_idx, job in enumerate(batch_job):
                    load = job.get()
                    if load is None:
                        continue
                    pid, gender, image = load
                    batch_in[b_idx] = image
                    data_array[b_idx] = {'pid': pid, 'gender': gender}
                embeddings = model.create_embeddings(batch_in)
                for e_idx, blobs in enumerate(embeddings):
                    if data_array[e_idx] is not None:
                        data = data_array[e_idx]
                        blob = blobs[layer]
                        data['embedding'] = blob.tolist()
                        out = json.dumps(data)
                        product_item_embeddings.write(out + '\n')
            product_item_embeddings.flush()
            # Start a fresh pool for the next buffer of transform jobs.
            pool = mp.Pool(mp.cpu_count() + 2)
            jobs = []
            # Print time stats.
            t_delta = time.time() - t0
            t_rem = (num_products - idx) * (t_delta / (idx + 1))
            print 'Elapsed: %0.3fs Remaining: %0.3fs' % (t_delta, t_rem)
    pool.close()
    pool.join()
    product_item_embeddings.close()


if __name__ == '__main__':
    main()
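
# Usage sketch (not part of the original script): each line appended to
# EMBEDDINGS_PATH above is a standalone JSON object with 'pid', 'gender' and
# 'embedding' keys, so the output can be consumed downstream line by line,
# for example:
#
#   with open(EMBEDDINGS_PATH) as fh:
#       for line in fh:
#           record = json.loads(line)
#           vector = np.array(record['embedding'], dtype=np.float32)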