Last active
October 28, 2016 13:44
-
-
Save gaphex/1451f61e7dbc837faf9be0fbd369cae5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Minibatch iterator over a MongoDB collection of documents holding ``questions`` lists.

Sample usage::

    cli = pymongo.MongoClient()
    col = cli['wiki_answers']['gold']
    itr = WikianswersIterator(col, cache_size=2048)
    for minibatch in itr:
        process(minibatch)
"""
import random | |
import pymongo | |
import numpy as np | |
from itertools import chain | |
class WikianswersIterator(object):
    """Iterate over training minibatches drawn from a MongoDB collection.

    Documents are pulled into RAM ``cache_size`` at a time (a "megabatch").
    Each usable document of a megabatch then yields one minibatch dict with
    keys ``'sequence'``, ``'pos_samples'`` and ``'neg_samples'`` (see
    ``minibatch_generator``).

    Args:
        collection: pymongo collection whose documents carry a ``questions``
            list.
        cache_size: number of documents fetched per megabatch (must be >= 1).
        mode: ``'random'`` samples every megabatch with ``$sample`` and keeps
            going indefinitely; ``'sequential'`` pages through the collection
            once in ``_id`` order and then stops.

    Raises:
        ValueError: if ``mode`` is unknown or ``cache_size`` < 1.
    """

    # In random mode, stop (StopIteration) after this many consecutive
    # megabatches that produce no minibatch, instead of spinning forever on a
    # collection without usable documents.
    max_empty_megabatches = 10

    def __init__(self, collection, cache_size=256, mode='random'):
        if mode == 'random':
            prepare_fun = self.prepare_random_megabatch
        elif mode == 'sequential':
            prepare_fun = self.prepare_sequential_megabatch
        else:
            raise ValueError(
                "mode must be 'random' or 'sequential', got %r" % (mode,))
        if cache_size < 1:
            # limit(0) means "no limit" in MongoDB, and cur_index would never
            # advance, so sequential mode would re-read everything forever.
            raise ValueError("cache_size must be >= 1, got %r" % (cache_size,))
        self.col = collection  # the mongo collection we iterate over
        self.cache_size = cache_size  # documents kept in RAM per megabatch
        self.mode = mode
        self.prepare_fun = prepare_fun
        self.cur_index = 0  # next document offset in sequential mode
        # Collection.count() was deprecated in PyMongo 3.7 and removed in 4.0;
        # only fall back to it for old drivers lacking count_documents().
        count_documents = getattr(collection, 'count_documents', None)
        if count_documents is not None:
            self.count = count_documents({})
        else:
            self.count = collection.count()
        # Start from an exhausted iterator so the first __next__ fetches a
        # megabatch; no AttributeError juggling needed.
        self.current_batch = None
        self.batch_iterator = iter(())

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next minibatch, fetching new megabatches as needed.

        Raises:
            StopIteration: in sequential mode once the collection is
                exhausted; in random mode after ``max_empty_megabatches``
                consecutive megabatches produced nothing.
        """
        fetched = 0
        while True:
            try:
                return next(self.batch_iterator)
            except StopIteration:
                pass
            if self.mode == 'random' and fetched >= self.max_empty_megabatches:
                raise StopIteration
            # prepare_sequential_megabatch raises StopIteration itself when no
            # data is left; it propagates to the caller and ends iteration.
            self.current_batch = self.prepare_fun(size=self.cache_size)
            self.batch_iterator = self.minibatch_generator(self.current_batch)
            fetched += 1

    next = __next__  # Python 2 iterator protocol / backward compatibility

    def prepare_random_megabatch(self, size):
        """Fetch ``size`` randomly sampled documents as question arrays."""
        docs = list(self.col.aggregate([{"$sample": {"size": size}}]))
        return self._to_question_arrays(docs)

    def prepare_sequential_megabatch(self, size):
        """Fetch the next page of up to ``size`` documents in ``_id`` order.

        The last page may be shorter than ``size``; it is still returned.

        Raises:
            StopIteration: once all documents counted at construction time
                have been read.
        """
        if self.cur_index >= self.count:
            print("Out of data!")
            raise StopIteration
        # skip/limit paging is only stable under an explicit sort; _id is
        # always indexed.
        cursor = self.col.find().sort('_id', 1).skip(self.cur_index).limit(size)
        docs = list(cursor)
        self.cur_index += size
        return self._to_question_arrays(docs)

    @staticmethod
    def _to_question_arrays(docs):
        """Convert documents to a 1-D object array of per-document question arrays.

        Documents whose ``questions`` field is missing or empty are dropped.
        The object array is built explicitly because ``np.array`` on ragged
        lists raises in NumPy >= 1.24, and silently builds a 2-D array when
        all lists happen to have the same length.
        """
        question_lists = [doc.get('questions') for doc in docs]
        question_lists = [q for q in question_lists if q]
        out = np.empty(len(question_lists), dtype=object)
        for i, questions in enumerate(question_lists):
            out[i] = np.array(questions)
        return out

    @staticmethod
    def minibatch_generator(batch, min_size=8):
        """Yield one minibatch dict per usable document of ``batch``.

        For each document (a sequence of questions):

        * a random question becomes ``'sequence'``;
        * the document's remaining questions become ``'pos_samples'``;
        * the same number of questions, drawn without replacement from all
          *other* documents, become ``'neg_samples'``.

        A document is skipped when it would have ``min_size`` or fewer
        positives, or when the other documents hold too few questions to
        draw the negatives from.

        Args:
            batch: NumPy object array or any sequence of question sequences.
            min_size: positives must number strictly more than this.
        """
        docs = list(batch)
        # Flatten once. A document's negatives are drawn from this pool with
        # the document's own slice [doc_start, doc_start + n) excluded, rather
        # than rebuilding the pool of every other document on each iteration.
        pool = list(chain.from_iterable(docs))
        total = len(pool)
        start = 0
        for doc in docs:
            n = len(doc)
            doc_start = start
            start += n
            n_pos = n - 1
            n_other = total - n
            if n_pos <= min_size or n_other < n_pos:
                continue
            seq_id = random.randrange(n)
            positives = [q for k, q in enumerate(doc) if k != seq_id]
            # Sample positions in the pool with this document removed, then
            # shift positions at/after its slice past it.
            picks = random.sample(range(n_other), n_pos)
            negatives = [pool[p if p < doc_start else p + n] for p in picks]
            yield {
                'sequence': doc[seq_id],
                'pos_samples': positives,
                'neg_samples': negatives,
            }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
а зачем у тебя вынесены функции из класса?