Last active
November 1, 2016 12:52
-
-
Save gaphex/72eabf8853ed74171a6ee54ef54ebc3e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Minibatch iterator over WikiAnswers question groups stored in MongoDB.

Sample usage:
    itr = WikianswersIterator(col='gold', db='wiki_answers', cache_size=2048)
    for minibatch in itr:
        process(minibatch)
"""
import random | |
import pymongo | |
import numpy as np | |
from itertools import chain | |
class WikianswersIterator(object):
    """Iterator of training minibatches built from a MongoDB collection.

    Documents are read from ``db[col]`` in megabatches of ``cache_size``
    documents (the buffer kept in RAM). Each document is expected to hold a
    ``questions`` list; every sufficiently long list produces one minibatch
    dict with keys ``'sequence'``, ``'pos_samples'`` and ``'neg_samples'``
    (see ``minibatch_generator``).

    Args:
        host, port: MongoDB server address.
        db, col: database and collection names; both must already exist.
        cache_size: number of documents fetched per megabatch (>= 1).
        mode: ``'random'`` samples megabatches with ``$sample`` and keeps
            going until repeated samples yield no usable minibatch;
            ``'sequential'`` pages through the collection once in ``_id``
            order and then stops.
        timeout_ms: server-selection timeout passed to ``MongoClient``.

    Raises:
        ValueError: bad ``mode``/``cache_size``, or missing database/collection.
        ConnectionError: the MongoDB server could not be reached.
    """

    # Upper bound on megabatch fetches within one __next__ call in random
    # mode, so a collection without usable documents cannot loop forever.
    _MAX_REFILLS_PER_STEP = 100

    def __init__(self, host='localhost', port=27017,
                 db='wiki_answers', col='gold',
                 cache_size=256, mode='random', timeout_ms=5):
        # Validate the cheap arguments before opening a network connection.
        prepare_funs = {'random': self.prepare_random_megabatch,
                        'sequential': self.prepare_sequential_megabatch}
        if mode not in prepare_funs:
            raise ValueError("Mode can be only 'random' or 'sequential'")
        if cache_size < 1:
            # limit(0) means "no limit" in MongoDB and $sample needs a
            # positive size, so a non-positive cache would misbehave.
            raise ValueError(
                "cache_size must be a positive integer, got {}".format(cache_size))
        self._mode = mode
        self._prepare_fun = prepare_funs[mode]

        self._cli = self._connect_to_mongo(host=host, port=port,
                                           timeout_ms=timeout_ms)
        if self._cli is None:
            raise ConnectionError(
                "Could not connect to mongod on {}:{}".format(host, port))
        if db not in self._call_compat(self._cli, 'list_database_names',
                                       'database_names'):
            self._cli.close()
            raise ValueError("Database with name [{}] does not exist".format(db))
        self._db = self._cli[db]
        if col not in self._call_compat(self._db, 'list_collection_names',
                                        'collection_names'):
            self._cli.close()
            raise ValueError(
                "Could not find collection [{}] in database [{}]".format(col, db))
        self._col = self._db[col]

        self._cache_size = cache_size  # documents buffered in RAM per megabatch
        self._cur_index = 0  # read offset used by sequential mode
        self._count = self._call_compat(self._col, 'count_documents', 'count', {})
        # Start exhausted so the first __next__ loads a megabatch.
        self.current_batch = None
        self.batch_iterator = iter(())

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next minibatch, loading new megabatches as needed.

        Megabatches that yield no minibatch (all rows too short) are skipped
        instead of ending the iteration.

        Raises:
            StopIteration: in sequential mode once the collection is
                exhausted; in random mode after ``_MAX_REFILLS_PER_STEP``
                megabatches in a row produced nothing.
        """
        refills = 0
        while True:
            try:
                return next(self.batch_iterator)
            except StopIteration:
                pass
            if self._mode == 'random' and refills >= self._MAX_REFILLS_PER_STEP:
                raise StopIteration
            # In sequential mode this raises StopIteration when the data runs
            # out, which ends the iteration.
            self.current_batch = self._prepare_fun(size=self._cache_size)
            self.batch_iterator = self.minibatch_generator(self.current_batch)
            refills += 1

    def prepare_random_megabatch(self, size):
        """Sample ``size`` documents via ``$sample`` and return a megabatch."""
        docs = self._col.aggregate([{"$sample": {"size": size}}])
        return self._to_megabatch(docs)

    def prepare_sequential_megabatch(self, size):
        """Return the next ``size`` documents (in ``_id`` order) as a megabatch.

        The final megabatch may hold fewer than ``size`` documents.

        Raises:
            StopIteration: when all documents counted at construction time
                have been read, or the page comes back empty.
        """
        if self._cur_index >= self._count:
            print("Out of data!")
            raise StopIteration
        # MongoDB does not guarantee an order for unsorted queries, and
        # skip/limit paging needs a stable one.
        cursor = self._col.find().sort('_id', 1)
        docs = list(cursor.skip(self._cur_index).limit(size))
        if not docs:
            print("Out of data!")
            raise StopIteration
        self._cur_index += size
        return self._to_megabatch(docs)

    @staticmethod
    def _to_megabatch(docs):
        """Pack the non-empty ``questions`` lists of ``docs`` into a 1-D object array.

        Each element is a NumPy array holding one document's questions.
        The object array is built explicitly because the lists are usually
        of different lengths. ``np.array`` on ragged nested sequences is an
        error in NumPy >= 1.24, and it gives a 2-D array when all lengths
        happen to match. Documents with a missing or empty ``questions``
        field are skipped.
        """
        rows = [np.array(doc['questions']) for doc in docs if doc.get('questions')]
        megabatch = np.empty(len(rows), dtype=object)
        for i, row in enumerate(rows):
            megabatch[i] = row
        return megabatch

    @staticmethod
    def minibatch_generator(batch, min_size=8, rng=None):
        """Yield one training example per sufficiently long row of ``batch``.

        For each row with more than ``min_size + 1`` items, one item is picked
        at random as ``'sequence'`` and the row's remaining items become
        ``'pos_samples'``. ``'neg_samples'`` is an equally sized random sample,
        drawn without replacement, of items from all *other* rows. Rows that
        are too short, or whose negative pool is too small, are skipped.

        Args:
            batch: sequence of rows (lists or 1-D arrays), e.g. a megabatch.
            min_size: a row needs strictly more than ``min_size`` positives.
            rng: object with ``randrange``/``sample`` (e.g. ``random.Random``);
                defaults to the ``random`` module.
        """
        rng = random if rng is None else rng
        rows = list(batch)
        # All items in row order; a row's negatives are this pool minus the
        # row's own slice [row_start, row_end).
        pool = list(chain.from_iterable(rows))
        row_end = 0
        for row in rows:
            row_start, row_end = row_end, row_end + len(row)
            n_pos = len(row) - 1
            if len(row) == 0 or n_pos <= min_size:
                continue
            seq_id = rng.randrange(len(row))
            pos_samples = [q for i, q in enumerate(row) if i != seq_id]
            negatives = pool[:row_start] + pool[row_end:]
            if len(negatives) < n_pos:
                continue
            yield {
                'sequence': row[seq_id],
                'pos_samples': pos_samples,
                'neg_samples': rng.sample(negatives, n_pos),
            }

    @staticmethod
    def _call_compat(obj, new_api, old_api, *new_args):
        """Call ``obj.<new_api>(*new_args)`` if available, else ``obj.<old_api>()``.

        Bridges PyMongo versions (e.g. ``list_database_names`` vs. the removed
        ``database_names``). The check is made on the type because PyMongo's
        attribute-style access returns sub-databases/collections for unknown
        names, so ``getattr(obj, ...)`` would never report "missing".
        """
        if callable(getattr(type(obj), new_api, None)):
            return getattr(obj, new_api)(*new_args)
        return getattr(obj, old_api)()

    @staticmethod
    def _connect_to_mongo(port=27017, host='localhost', timeout_ms=5):
        """Return a connected ``pymongo.MongoClient``, or ``None`` if unreachable.

        ``MongoClient`` connects lazily, so a ``ping`` is issued to force
        server selection and surface connection failures here. A client that
        failed to connect is closed before returning ``None``.
        """
        # NOTE(review): 5 ms is a very tight server-selection timeout; it is
        # kept as the default for compatibility but may fail on remote hosts.
        client = pymongo.MongoClient(host, port, serverSelectionTimeoutMS=timeout_ms)
        try:
            client.admin.command('ping')
        except pymongo.errors.ConnectionFailure:
            print('Could not connect to mongod on {}:{}'.format(host, port))
            client.close()
            return None
        print('Connected to MongoDB on {}:{}'.format(host, port))
        return client
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment