Use of numpy memory maps for generating a list of ndarrays without pickle overhead, with metadata pass-through.
# By Josiah Olson
# Context manager that generates minibatches in the background via a process pool.
# Returns float32 minibatch tensors and metadata, using the sharedmem package to
# avoid the thread-blocking compute time needed to un-pickle large minibatches
# passed through multiprocessing.Queue: ~0.3 sec for a 1 GB minibatch.
# This context manager brings blocking time down to ~300 microseconds per minibatch.
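#
# A minimal timing sketch (an added illustration, not part of the original gist)
# of the pickle overhead cited above; exact numbers depend on hardware:
#
# import pickle
# import time
# import numpy as np
#
# arr = np.zeros(250 * 1000 * 1000, dtype=np.float32)  # ~1 GB of float32
# start = time.time()
# buf = pickle.dumps(arr, protocol=pickle.HIGHEST_PROTOCOL)
# restored = pickle.loads(buf)  # a Queue pays this round trip per minibatch
# print('pickle round trip: %.3f sec' % (time.time() - start))
#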
# Usage:
#
# def batchFxn(seed, sharedArr):
#     ...  # generate an iterable of float32 tensors in the minibatch: batchArrs
#     ...  # generate a data structure containing metadata: batchMeta
#     startIdx = 0
#     batchSlices = []
#     batchShapes = []
#     for batchArr in batchArrs:
#         batchShape = batchArr.shape
#         endIdx = startIdx + np.prod(batchShape)
#         batchSlice = slice(startIdx, endIdx)
#         sharedArr[batchSlice] = batchArr.ravel()
#         startIdx = endIdx
#         batchSlices.append(batchSlice)
#         batchShapes.append(batchShape)
#     return (batchSlices, batchShapes, batchMeta)
#
# # Assume the max shape of the feature tensor is (100, 100, 10)
# # Assume the max shape of the target tensor is (100,)
# # Tell the context manager how much shared memory is needed in the
# # worst case: np.prod((100, 100, 10 + 1)) float32 values x 4 bytes each
# # (a sanity check of this arithmetic follows below)
# with BatchGenSharedmem(batchFxn, max_dim_tup=(100, 100, 10 + 1)) as BGC:
#     batchMeta, [featureData, targetData] = next(BGC)
#     ...  # do something with the minibatch tensors and metadata
#     BGC.release_mem()
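#
# A quick sanity check (an added illustration, not part of the original usage
# notes) that the padded shape above covers both tensors in the worst case:
#
# needed = np.prod((100, 100, 10)) + np.prod((100,))  # 100000 + 100 = 100100
# allocated = np.prod((100, 100, 10 + 1))             # 110000 float32 slots
# assert allocated >= needed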
import random
import sharedmem
import numpy as np
from multiprocessing import Process, Queue


class BatchGenSharedmem:
    def __init__(self, batch_fn, max_dim_tup, seed=None, num_workers=4):
        self.batch_fn = batch_fn
        self.num_workers = num_workers
        if seed is None:
            seed = random.randint(0, 4294967295)
        self.seed = seed
        self.max_dim_tup = max_dim_tup
    def __enter__(self):
        self.jobq = Queue(maxsize=(self.num_workers + 1))
        self.doneq = Queue()
        self.processes = []
        self.current_batch = 0
        self.current_next = None
        self.max_dim_nbr = np.prod(self.max_dim_tup)
        # One shared buffer per worker plus a spare, so every worker can be
        # filling a buffer while the consumer reads the spare.
        self.arrList = [sharedmem.empty(self.max_dim_nbr, dtype=np.float32)
                        for _ in range(self.num_workers + 1)]
        np.random.seed(self.seed)
        # Must use the same seed and num_workers to be reproducible
        self.seedList = np.random.randint(0, 4294967295, self.num_workers)

        def produce(processNbr):
            random.seed(self.seedList[processNbr])
            while True:
                # Block until a free shared-buffer index arrives;
                # None is the shutdown sentinel.
                i = self.jobq.get()
                if i is None:
                    break
                seed = random.randint(0, 4294967295)
                batchStats = self.batch_fn(seed, self.arrList[i])
                self.doneq.put((i, batchStats))

        for processNbr in range(self.num_workers):
            self.jobq.put(processNbr)
            p = Process(target=produce, args=(processNbr,))
            self.processes.append(p)
            p.start()
        # Queue the spare buffer index so workers stay one batch ahead.
        self.jobq.put(self.num_workers)
        return self
    def __iter__(self):
        return self

    def __next__(self):
        # Block until a worker finishes a batch, then view the shared buffer
        # as the original tensors (the reshape makes no copy).
        i, batchStats = self.doneq.get()
        batchSlices, batchShapes, batchMeta = batchStats
        sharedArr = self.arrList[i]
        batchArrs = []
        for batchSlice, batchShape in zip(batchSlices, batchShapes):
            batchArr = sharedArr[batchSlice].reshape(batchShape)
            batchArrs.append(batchArr)
        # The buffer index is not returned to jobq here; the consumer calls
        # release_mem() once it is done reading the tensors.
        self.current_next = i
        self.current_batch += 1
        return (batchMeta, batchArrs)
    def release_mem(self):
        # Hand the most recently consumed buffer back to the workers; its
        # contents may be overwritten immediately, so copy anything you need
        # to keep (see the consumer sketch after the class).
        self.jobq.put(self.current_next)
    def __exit__(self, exc_type, exc_value, traceback):
        # Drain both queues so no worker is blocked, then stop the pool.
        while not self.jobq.empty():
            self.jobq.get()
        while not self.doneq.empty():
            self.doneq.get()
        for process in self.processes:
            process.terminate()
            process.join()
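
# A minimal consumer sketch (an added illustration, not part of the original
# gist): to keep a minibatch alive past release_mem(), copy it out of the
# shared buffer first, because the buffer index is recycled back to the job
# queue and a worker may overwrite it.
#
# with BatchGenSharedmem(batchFxn, max_dim_tup=(100, 100, 11)) as BGC:
#     batchMeta, batchArrs = next(BGC)
#     kept = [arr.copy() for arr in batchArrs]  # detach from shared memory
#     BGC.release_mem()  # a worker may now refill this buffer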

# Demo: stream 20 minibatches drawn from two pre-built tensors. The worker
# closure requires fork-based multiprocessing (the default on Linux).
if __name__ == '__main__':
    import time

    d = {i: np.random.random(1000 * 1000 * 10)
             .reshape((1000, 1000, 10)).astype(np.float32)
         for i in range(2)}

    def getBatch(seed, sharedArr):
        # Pick one of the two tensors at random and pack it into shared
        # memory using the slice protocol expected by BatchGenSharedmem.
        np.random.seed(seed)
        randint = np.random.randint(0, 2)
        tmpdata = d[randint]
        batchShape = tmpdata.shape
        batchSlice = slice(0, int(np.prod(batchShape)))
        sharedArr[batchSlice] = tmpdata.ravel()
        return ([batchSlice], [batchShape], {'source': randint})

    out = []
    with BatchGenSharedmem(getBatch, (1000, 1000, 10), seed=333,
                           num_workers=4) as bg:
        startTime = time.time()
        for counter in range(20):
            startTimeSub = time.time()
            batchMeta, [minibatch] = next(bg)
            print('Time to get:', time.time() - startTimeSub)
            print('Iter Nbr:', counter)
            print('First item:', minibatch[0, 0, 0])
            print('Shape:', minibatch.shape)
            out.append(batchMeta['source'])
            bg.release_mem()
        print('Time to run all batches:', time.time() - startTime)

    print('Len output:', len(out))