Skip to content

Instantly share code, notes, and snippets.

@gaphex
Last active November 1, 2016 12:52
Show Gist options
  • Save gaphex/72eabf8853ed74171a6ee54ef54ebc3e to your computer and use it in GitHub Desktop.
Save gaphex/72eabf8853ed74171a6ee54ef54ebc3e to your computer and use it in GitHub Desktop.
"""
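Minibatch iterator over a MongoDB collection of WikiAnswers documents,
each holding a list of questions.

Sample usage:

    itr = WikianswersIterator(col='gold', db='wiki_answers', cache_size=2048)
    for minibatch in itr:
        process(minibatch)

Each minibatch is a dict with the keys 'sequence', 'pos_samples' and
'neg_samples'.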
"""
import random
import pymongo
import numpy as np
from itertools import chain
class WikianswersIterator(object):
    """Minibatch iterator over a MongoDB collection of WikiAnswers documents.

    Each document must carry a ``questions`` list (assumption: paraphrases
    of one question -- verify against the data loader). Documents are
    pulled ``cache_size`` at a time into an in-RAM "megabatch"; every usable
    document then yields one minibatch dict:

    * ``'sequence'``    -- one of the document's questions, picked at random;
    * ``'pos_samples'`` -- the document's remaining questions;
    * ``'neg_samples'`` -- the same number of questions drawn without
      replacement from the *other* documents of the same megabatch.

    Parameters
    ----------
    host, port : address of the mongod server.
    db, col : database and collection names; both must already exist.
    cache_size : number of documents fetched per megabatch.
    mode : ``'random'`` samples megabatches with ``$sample`` and keeps going
        until it repeatedly fails to produce a minibatch; ``'sequential'``
        walks the collection once in ``_id`` order.
    min_size : a document is used only if it has more than ``min_size``
        positives (i.e. at least ``min_size + 2`` questions).
    timeout_ms : server-selection timeout used when connecting.

    Raises
    ------
    ValueError : unknown ``mode`` or missing database / collection.
    ConnectionError : the MongoDB server could not be reached.
    """

    # Random mode never runs out of data by itself, so stop after this many
    # consecutive megabatches that produced no usable minibatch (e.g. a
    # collection without large-enough documents) instead of spinning forever.
    _MAX_IDLE_REFILLS = 100

    def __init__(self, host='localhost', port=27017,
                 db='wiki_answers', col='gold',
                 cache_size=256, mode='random', min_size=8, timeout_ms=5000):
        # Validate cheap arguments before touching the network.
        if mode == 'random':
            self._prepare_fun = self.prepare_random_megabatch
        elif mode == 'sequential':
            self._prepare_fun = self.prepare_sequential_megabatch
        else:
            raise ValueError("Mode can be only 'random' or 'sequential'")
        self._mode = mode
        self._cache_size = cache_size  # documents kept in RAM per megabatch
        self._min_size = min_size
        self._cur_index = 0            # next document offset (sequential mode)
        self.current_batch = None
        # An already-exhausted iterator makes the first __next__ call fetch a
        # megabatch; no need to rely on AttributeError for "not yet started".
        self.batch_iterator = iter(())

        self._cli = self._connect_to_mongo(host=host, port=port,
                                           timeout_ms=timeout_ms)
        if self._cli is None:
            raise ConnectionError(
                "Could not connect to mongod on {}:{}".format(host, port))

        if db not in self._list_names(self._cli, 'list_database_names',
                                      'database_names'):
            raise ValueError(
                "Database with name [{}] does not exist".format(db))
        self._db = self._cli[db]

        if col not in self._list_names(self._db, 'list_collection_names',
                                       'collection_names'):
            raise ValueError(
                "Could not find collection [{}] in database [{}]".format(col, db))
        self._col = self._db[col]

        # Snapshot of the collection size; sequential mode stops once its
        # offset reaches it. Collection.count() was removed in pymongo 4.
        if hasattr(type(self._col), 'count_documents'):
            self._count = self._col.count_documents({})
        else:
            self._count = self._col.count()

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next minibatch, fetching new megabatches as needed.

        Raises StopIteration when sequential mode has consumed the whole
        collection, or when random mode hits ``_MAX_IDLE_REFILLS``
        unproductive refills in a row.
        """
        idle_refills = 0
        while True:
            try:
                return next(self.batch_iterator)
            except StopIteration:
                pass
            # Current megabatch is exhausted (or yielded nothing): refill and
            # retry instead of letting one empty megabatch end the iteration.
            if self._mode == 'random' and idle_refills >= self._MAX_IDLE_REFILLS:
                raise StopIteration
            idle_refills += 1
            # prepare_sequential_megabatch raises StopIteration at the end of
            # the collection, which ends iteration from here.
            self.current_batch = self._prepare_fun(size=self._cache_size)
            self.batch_iterator = self.minibatch_generator(
                self.current_batch, min_size=self._min_size)

    def prepare_random_megabatch(self, size):
        """Fetch ``size`` randomly sampled documents (via ``$sample``)."""
        docs = self._col.aggregate([{"$sample": {"size": size}}])
        return self._to_question_arrays(docs)

    def prepare_sequential_megabatch(self, size):
        """Fetch the next ``size`` documents in ``_id`` order.

        The final megabatch may be smaller than ``size``. Raises
        StopIteration once every document has been served.
        """
        if self._cur_index >= self._count:
            print("Out of data!")
            raise StopIteration
        # skip/limit pagination needs a deterministic order; _id is always
        # indexed, so sorting on it is cheap.
        cursor = (self._col.find()
                  .sort('_id', 1)
                  .skip(self._cur_index)
                  .limit(size))
        docs = list(cursor)
        self._cur_index += size
        return self._to_question_arrays(docs)

    @staticmethod
    def _to_question_arrays(docs):
        """Turn documents into a 1-D object array of per-document question arrays.

        Documents without questions are dropped. The array is built
        explicitly with dtype=object because ``np.array`` on ragged nested
        arrays raises ValueError on NumPy >= 1.24.
        """
        rows = [np.array(d['questions']) for d in docs if d.get('questions')]
        out = np.empty(len(rows), dtype=object)
        for i, row in enumerate(rows):
            out[i] = row
        return out

    @staticmethod
    def minibatch_generator(batch, min_size=8):
        """Yield one minibatch dict per usable document in *batch*.

        *batch* is a sequence of per-document question sequences (as built by
        the ``prepare_*_megabatch`` methods). A document is skipped unless it
        has more than ``min_size`` positives and the other documents hold
        enough questions to draw as many negatives.
        """
        # Flatten once: document bi occupies flat[starts[bi]:starts[bi + 1]].
        flat = list(chain.from_iterable(batch))
        starts = [0]
        for questions in batch:
            starts.append(starts[-1] + len(questions))

        for bi, questions in enumerate(batch):
            n_own = len(questions)
            n_pos = n_own - 1
            if n_pos <= min_size:
                continue
            seq_id = random.randrange(n_own)
            pos_samples = [questions[i] for i in range(n_own) if i != seq_id]

            # Negatives come from every *other* document. Sample indices from
            # the flat list with this document's slice cut out (indices at or
            # past its start are shifted over it) instead of materialising
            # that list per document, which made a megabatch O(N^2).
            n_other = len(flat) - n_own
            if n_other < n_pos:
                continue
            start = starts[bi]
            picks = random.sample(range(n_other), n_pos)
            neg_samples = [flat[i if i < start else i + n_own] for i in picks]

            # No try/except around the yield: catching GeneratorExit here
            # would make close()/garbage collection raise RuntimeError.
            yield {'sequence': questions[seq_id],
                   'pos_samples': pos_samples,
                   'neg_samples': neg_samples}

    @staticmethod
    def _list_names(obj, new_name, old_name):
        """Call ``obj.<new_name>()`` if available, else legacy ``obj.<old_name>()``.

        The lookup is done on the type because pymongo's ``__getattr__``
        turns unknown attribute names into sub-databases/collections.
        """
        method_name = new_name if hasattr(type(obj), new_name) else old_name
        return getattr(obj, method_name)()

    @staticmethod
    def _connect_to_mongo(port=27017, host='localhost', timeout_ms=5000):
        """Return a connected MongoClient, or None if the server is unreachable."""
        client = pymongo.MongoClient(host, port,
                                     serverSelectionTimeoutMS=timeout_ms)
        try:
            # MongoClient connects lazily; force a round-trip to fail fast.
            client.admin.command('ping')
        except pymongo.errors.ConnectionFailure:
            print('Could not connect to mongod on {}:{}'.format(host, port))
            client.close()
            return None
        print('Connected to MongoDB on {}:{}'.format(host, port))
        return client
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment