Skip to content

Instantly share code, notes, and snippets.

@ezietsman
Created October 9, 2017 22:30
Show Gist options
  • Save ezietsman/28791f411b605d0e3c421b9759429ed4 to your computer and use it in GitHub Desktop.
Save ezietsman/28791f411b605d0e3c421b9759429ed4 to your computer and use it in GitHub Desktop.
Thread-safe iterator for BSON files, for Kaggle use.
def grouper(n, iterable):
    '''
    Yield successive tuples of up to n items taken from iterable.

    Every chunk holds n items except possibly the last one, which holds
    whatever is left over. An empty iterable yields nothing.

    Raises ValueError if n is smaller than 1. Because this is a generator,
    the error surfaces when iteration starts, not when grouper() is called.
    '''
    # A chunk size of 0 used to end the generator immediately, which hides a
    # bad batch size as "no data". n=None is left alone: islice treats it as
    # "no limit", so everything ends up in a single chunk.
    if n is not None and n < 1:
        raise ValueError('n must be at least 1')

    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, n))
        if not chunk:
            # Source exhausted: stop without yielding an empty tuple.
            return
        yield chunk
class threadsafe_iter:
    """
    Wraps an iterable/iterator/generator and serializes calls to its
    `__next__` method behind a lock, so only one thread at a time advances
    the underlying iterator.
    """

    def __init__(self, it):
        # iter() returns an iterator unchanged, and turns a plain iterable
        # (list, tuple, ...) into one, so both kinds of input work.
        self.it = iter(it)
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        # Hold the lock only while advancing; StopIteration from the
        # wrapped iterator propagates to the caller unchanged.
        with self.lock:
            return next(self.it)
def threadsafe_generator(f):
    """
    A decorator that takes a generator function and makes it thread-safe.

    Calling the decorated function returns the original generator wrapped
    in `threadsafe_iter`. The wrapper keeps the decorated function's name
    and docstring.
    """
    # Imported here because this block cannot rely on a module-level import.
    import functools

    @functools.wraps(f)
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))

    return g
@threadsafe_generator
def get_features_label(documents, batch_size=32, return_labels=True):
    '''
    Yield batches of image arrays, optionally paired with encoded labels.

    Only the FIRST entry of each document's 'imgs' list is decoded; its
    pixel values are cast to float32 and divided by 255 (presumably 8-bit
    images, giving values in [0, 1] - confirm against the data).

    documents: iterable of dict-like records. Each must provide an 'imgs'
        sequence whose first item exposes the image bytes under 'picture';
        'category_id' is optional.
    batch_size: number of documents per batch; the last batch may be
        smaller.
    return_labels: if true, yield (X, y) tuples; otherwise yield X only and
        skip label encoding altogether.

    y holds, per document, the result of the module-level
    `labelencoder.transform([category])`, or None when the document has no
    category (missing, None or empty string).
    '''
    for batch in grouper(batch_size, documents):
        images = []
        labels = []
        for document in batch:
            picture = document.get('imgs')[0].get('picture', None)
            im = imread(io.BytesIO(picture))
            images.append(im.astype('float32') / 255.0)

            if not return_labels:
                # Labels are never yielded, so do not pay for encoding them.
                continue

            category = document.get('category_id', '')
            # Test for "missing" explicitly rather than by truthiness, so a
            # legitimate category id of 0 is still encoded.
            if category is None or category == '':
                labels.append(None)
            else:
                labels.append(labelencoder.transform([category]))

        if return_labels:
            yield np.array(images), np.array(labels)
        else:
            yield np.array(images)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment