Skip to content

Instantly share code, notes, and snippets.

@atqamar
Created May 5, 2016 07:25
Show Gist options
  • Save atqamar/92aaad0ced779aecd092b8f0b90b884b to your computer and use it in GitHub Desktop.
'''
Parses crop file to create forward pass embeddings
'''
import os
import sys
import simplejson as json
import time
import multiprocessing as mp
import numpy as np
from sparkey import HashReader
model_root = '/data/validation/scripts'
sys.path.insert(0, model_root)
from lasagne_model import MultiModLasagneModel
from images import my_transformer as im_transform
from images import get_image_from_path
from config import product_images_root, product_metadata_path, ProductForwardsConfig
# Filesystem locations resolved from the shared config module.
IMAGES_ROOT = product_images_root  # root directory holding the product image files
PRODUCT_PATH = product_metadata_path  # sparkey store path prefix ('.spi'/'.spl' appended below)
EMBEDDINGS_PATH = ProductForwardsConfig.path  # output file for JSON-lines embeddings (appended to)
def hr(prefix):
    """Open a sparkey ``HashReader`` for the given path prefix.

    Expects ``<prefix>.spi`` (index) and ``<prefix>.spl`` (log) to exist.
    (Was a lambda assignment; PEP 8 E731 prefers ``def`` for named callables.)
    """
    return HashReader('%s.spi' % prefix, '%s.spl' % prefix)


def pretty_float(f):
    """Normalize any zero value (0, 0.0, -0.0) to the integer 0; pass everything else through."""
    return 0 if f == 0 else f
def transform_worker(idx, product_data, im_path, im_shape):
    """Load one product image and apply the model's input transform.

    Runs inside a multiprocessing pool.  Returns a tuple of
    ``(product_id, gender, transformed_image)``, or ``None`` when the image
    at ``im_path`` cannot be loaded.
    """
    image = get_image_from_path(im_path)
    if image is None:
        return None
    gender = product_data['gender']
    pid = product_data['id']
    transformed = im_transform(image, im_shape)
    # Progress indicator; '\r' keeps the counter on a single console line.
    sys.stdout.write('\rTransforming image from product #%d...' % idx)
    sys.stdout.flush()
    return (pid, gender, transformed)
def batch_generator(l, n):
    """Yield successive slices of ``l`` of length ``n`` (the last may be shorter).

    Uses ``range`` instead of the Python-2-only ``xrange``: for this
    start-index iteration the two behave identically under Python 2 (range
    merely materializes the small index list), and the helper also runs
    unchanged under Python 3.
    """
    for start in range(0, len(l), n):
        yield l[start:start + n]
def main():
model_prefix = ProductForwardsConfig.model_prefix
model_iteration = ProductForwardsConfig.model_iteration
batch_size = ProductForwardsConfig.batch_size
buffer_n = ProductForwardsConfig.buffer_n
layer = ProductForwardsConfig.layer
product_hr = hr(PRODUCT_PATH)
pool = mp.Pool(mp.cpu_count() + 2)
product_item_embeddings = open(EMBEDDINGS_PATH, 'ab')
model = MultiModLasagneModel(model_prefix=model_prefix, model_iteration=model_iteration, layers=[layer])
w, h = model.im_shape
jobs = []
num_products = product_hr.__len__()
t0 = time.time()
for idx, (pid, jsondata) in enumerate(product_hr.iteritems()):
data = json.loads(jsondata)
if not data.get('gender', None):
continue
image_ids = data['image_ids']
for p in image_ids:
image_path = os.path.join(IMAGES_ROOT, p[0], p[1], p) + '.jpeg'
job = pool.apply_async(transform_worker, (idx, data, image_path, (w, h)))
jobs.append(job)
if len(jobs) > buffer_n or idx == num_products - 1:
pool.close()
pool.join()
print "\nProcessing batches..."
for batch_job in batch_generator(jobs, batch_size):
batch_in = np.zeros((batch_size, 3, w, h), dtype=np.float32)
data_array = [None] * batch_size
for idx, job in enumerate(batch_job):
load = job.get()
if load is None:
continue
pid, gender, image = load
batch_in[idx] = image
data_array[idx] = {'pid': pid, 'gender': gender}
embeddings = model.create_embeddings(batch_in)
for idx, blobs in enumerate(embeddings):
if data_array[idx] is not None:
data = data_array[idx]
blob = blobs[layer]
data['embedding'] = blob.tolist()
out = json.dumps(data)
product_item_embeddings.write(out + '\n')
product_item_embeddings.flush()
pool = mp.Pool(mp.cpu_count() + 2)
jobs = []
# print time stats
t_delta = time.time() - t0
t_rem = (num_products - idx) * (t_delta / idx)
print 'Elapsed: %0.3fs Remaining: %0.3fs' % (t_delta, t_rem)
pool.close()
pool.join()
product_item_embeddings.close()
if __name__ == '__main__':
main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment