Skip to content

Instantly share code, notes, and snippets.

@gaphex
Last active October 28, 2016 13:44
Show Gist options
  • Save gaphex/1451f61e7dbc837faf9be0fbd369cae5 to your computer and use it in GitHub Desktop.
Save gaphex/1451f61e7dbc837faf9be0fbd369cae5 to your computer and use it in GitHub Desktop.
"""
Sample usage:

    cli = pymongo.MongoClient()
    col = cli['wiki_answers']['gold']
    itr = WikianswersIterator(col, cache_size=2048)
    for minibatch in itr:
        process(minibatch)
"""
import random
import pymongo
import numpy as np
from itertools import chain
class WikianswersIterator(object):
    """Iterate over minibatches built from a MongoDB collection of question groups.

    Every document is expected to carry a ``'questions'`` sequence.  Documents
    are pulled from the collection in "megabatches" of up to ``cache_size``
    documents, and each megabatch is turned into minibatches by
    :meth:`minibatch_generator`.

    Args:
        collection: collection object providing ``aggregate`` (random mode) or
            ``find().skip().limit()`` (sequential mode), plus a document count.
        cache_size: number of documents requested per megabatch.
        mode: ``'random'`` draws each megabatch with ``$sample`` and keeps
            going indefinitely; ``'sequential'`` walks the collection once.

    Raises:
        ValueError: if ``mode`` is neither ``'random'`` nor ``'sequential'``.
    """

    # Random mode has no natural end, so it gives up after this many
    # consecutive megabatches that produced no minibatch at all instead of
    # querying the database forever.
    max_empty_refills = 8

    def __init__(self, collection, cache_size=256, mode='random'):
        if mode == 'random':
            self.prepare_fun = self.prepare_random_megabatch
        elif mode == 'sequential':
            self.prepare_fun = self.prepare_sequential_megabatch
        else:
            raise ValueError(
                "mode must be 'random' or 'sequential', got %r" % (mode,))
        self.col = collection  # the mongo collection we iterate on
        self.cache_size = cache_size  # megabatch size kept in RAM
        self.mode = mode
        self.cur_index = 0  # sequential mode: documents consumed so far
        self.count = self._count_documents(collection)
        self.current_batch = None
        # Start exhausted so the first __next__ call triggers a refill.
        self.batch_iterator = iter(())

    @staticmethod
    def _count_documents(collection):
        """Return the collection's document count on old and new PyMongo."""
        try:
            return collection.estimated_document_count()
        except (AttributeError, TypeError):
            # Older PyMongo has no such method: attribute access on a real
            # Collection yields a sub-collection whose call raises TypeError.
            return collection.count()

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next minibatch, loading new megabatches as needed.

        Raises:
            StopIteration: when sequential mode runs out of documents, or when
                random mode hits ``max_empty_refills`` unproductive
                megabatches in a row.
        """
        empty_refills = 0
        while True:
            minibatch = next(self.batch_iterator, None)
            if minibatch is not None:
                return minibatch
            if (self.mode == 'random'
                    and empty_refills >= self.max_empty_refills):
                raise StopIteration
            # Sequential mode signals the end of data from prepare_fun itself.
            self.current_batch = self.prepare_fun(size=self.cache_size)
            self.batch_iterator = self.minibatch_generator(self.current_batch)
            empty_refills += 1

    next = __next__  # Python 2 iterator protocol

    @staticmethod
    def _collect_questions(docs):
        """Pack the non-empty ``'questions'`` of ``docs`` into a 1-D object array.

        The array is filled element by element so that groups of different
        lengths are accepted and equal-length groups are not merged into a
        2-D array.
        """
        groups = [np.array(doc['questions'])
                  for doc in docs if len(doc['questions'])]
        megabatch = np.empty(len(groups), dtype=object)
        for i, group in enumerate(groups):
            megabatch[i] = group
        return megabatch

    def prepare_random_megabatch(self, size):
        """Return question groups from ``size`` randomly sampled documents."""
        docs = self.col.aggregate([{"$sample": {"size": size}}])
        return self._collect_questions(docs)

    def prepare_sequential_megabatch(self, size):
        """Return question groups from the next ``size`` documents in order.

        The final megabatch may hold fewer than ``size`` documents.

        Raises:
            StopIteration: when no documents are left.
        """
        docs = []
        if self.cur_index < self.count:
            docs = list(self.col.find().skip(self.cur_index).limit(size))
        if not docs:
            print("Out of data!")
            raise StopIteration
        # Advance by what was actually fetched, not by what was requested.
        self.cur_index += len(docs)
        return self._collect_questions(docs)

    @staticmethod
    def minibatch_generator(batch, min_size=8):
        """Yield one minibatch per usable question group in ``batch``.

        Each minibatch is a dict with:
            ``'sequence'``: one question picked at random from the group;
            ``'pos_samples'``: the remaining questions of the same group;
            ``'neg_samples'``: as many questions as ``'pos_samples'``, sampled
            without replacement from the other groups of ``batch``.

        A group is skipped when it has ``min_size`` or fewer positive samples
        (i.e. fewer than ``min_size + 2`` questions), or when the other groups
        cannot supply enough negatives.

        Args:
            batch: iterable of question sequences (one per group).
            min_size: exclusive lower bound on the number of positive samples.
        """
        groups = []
        for group in batch:
            try:
                groups.append(list(group))
            except TypeError:
                # Malformed entry: keep its slot, contribute no questions.
                groups.append([])

        # Flatten once; a group's own questions occupy the slice
        # [start, start + len(group)) of `flat`.
        flat = list(chain.from_iterable(groups))
        total = len(flat)

        start = 0
        for group in groups:
            size = len(group)
            own_start = start
            start += size

            n_pos = size - 1
            if not group or n_pos <= min_size:
                continue
            n_other = total - size
            if n_other < n_pos:
                continue

            # choose sequence
            anchor = random.randrange(size)
            # choose positive samples for sequence
            positives = group[:anchor] + group[anchor + 1:]
            # choose negative samples for sequence: draw positions among the
            # other groups' questions, then shift those that fall at or after
            # this group's slice past it.
            picks = random.sample(range(n_other), n_pos)
            negatives = [flat[i if i < own_start else i + size] for i in picks]

            yield {
                'sequence': group[anchor],
                'pos_samples': positives,
                'neg_samples': negatives,
            }
@madrugado
Copy link

а зачем у тебя вынесены функции из класса?

@gaphex
Copy link
Author

gaphex commented Oct 28, 2016

Поправил

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment