Use of numpy memory maps (via sharedmem) for generating a list of ndarrays without pickle overhead, with metadata pass-through
# By Josiah Olson
# Context manager to generate minibatches in the background via a process pool.
# Returns float32 minibatch tensors and metadata, using numpy sharedmem to avoid
# the thread-blocking compute time needed to de-pickle large minibatches passed
# through multiprocessing.Queue: ~0.3 sec for a 1 GB minibatch.
# This context manager brings the blocking time down to ~300 microseconds per minibatch.
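#
# Only the small (batchSlices, batchShapes, batchMeta) tuple is pickled through
# a Queue; the tensor bytes are written by the workers directly into
# preallocated sharedmem buffers, so the consumer never de-pickles minibatch data.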
# Usage:
#
# def batchFxn(seed, sharedArr):
#     ...  # generate iterable of float32 tensors in minibatch: batchArrs
#     ...  # generate data structure containing metadata: batchMeta
#     startIdx = 0
#     batchSlices = []
#     batchShapes = []
#     for batchArr in batchArrs:
#         batchShape = batchArr.shape
#         endIdx = startIdx + np.product(batchShape)
#         batchSlice = slice(startIdx, endIdx)
#         sharedArr[batchSlice] = batchArr.ravel()
#         startIdx = endIdx
#         batchSlices.append(batchSlice)
#         batchShapes.append(batchShape)
#     return (batchSlices, batchShapes, batchMeta)
#
# # Assume the max dimension of the feature tensor is (100, 100, 10)
# # Assume the max dimension of the target tensor is (100,)
# # Tell the context manager how much shared memory is needed in the
# # worst case: np.product((100, 100, 10 + 1)) x 4 bytes
# with BatchGenSharedmem(batchFxn, max_dim_tup=(100, 100, 10 + 1)) as BGC:
#     batchMeta, [featureData, targetData] = next(BGC)
#     ...  # do something with the minibatch tensors and metadata
#     BGC.release_mem()
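#
# Each call to __next__ returns views into one of the (num_workers + 1) shared
# buffers; copy the data out or finish using it before calling release_mem(),
# which hands that buffer back to the workers to be overwritten. If
# release_mem() is never called, the pipeline stalls once every buffer has
# been consumed.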
import random
import sharedmem
import numpy as np
from multiprocessing import Process, Queue


class BatchGenSharedmem:
    def __init__(self, batch_fn, max_dim_tup, seed=None, num_workers=4):
        self.batch_fn = batch_fn
        self.num_workers = num_workers
        if seed is None:
            seed = random.randint(0, 4294967295)
        self.seed = seed
        self.max_dim_tup = max_dim_tup

    def __enter__(self):
        self.jobq = Queue(maxsize=(self.num_workers + 1))
        self.doneq = Queue()
        self.processes = []
        self.current_batch = 0
        self.current_next = None
        self.max_dim_nbr = np.product(self.max_dim_tup)
        # One shared float32 buffer per worker plus one spare for the batch
        # currently being consumed
        self.arrList = [sharedmem.empty(self.max_dim_nbr, dtype=np.float32)
                        for _ in range(self.num_workers + 1)]
        np.random.seed(self.seed)
        # Must use the same seed and num_workers to be reproducible
        self.seedList = np.random.randint(0, 4294967295, self.num_workers)

        def produce(processNbr):
            # Worker loop: pull a buffer index off the job queue, fill that
            # shared buffer via batch_fn, and report the slices/shapes/metadata
            random.seed(self.seedList[processNbr])
            while True:
                i = self.jobq.get()
                if i is None:
                    break
                seed = random.randint(0, 4294967295)
                batchStats = self.batch_fn(seed, self.arrList[i])
                self.doneq.put((i, batchStats))

        for processNbr in range(self.num_workers):
            self.jobq.put(processNbr)
            p = Process(target=produce, args=(processNbr,))
            self.processes.append(p)
            p.start()
        self.jobq.put(self.num_workers)
        return self

    def __iter__(self):
        return self

    def __next__(self):
        # Block until a worker reports a finished buffer, then rebuild the
        # tensors as views into shared memory (no pickling of the data itself)
        i, batchStats = self.doneq.get()
        batchSlices, batchShapes, batchMeta = batchStats
        sharedArr = self.arrList[i]
        batchArrs = []
        for batchSlice, batchShape in zip(batchSlices, batchShapes):
            batchArr = sharedArr[batchSlice].reshape(batchShape)
            batchArrs.append(batchArr)
        self.current_next = i
        self.current_batch += 1
        return (batchMeta, batchArrs)

    def release_mem(self):
        # Hand the buffer of the most recently returned minibatch back to the
        # workers so it can be refilled
        self.jobq.put(self.current_next)

    def __exit__(self, exc_type, exc_value, traceback):
        # Drain the queues and stop the worker processes
        while not self.jobq.empty():
            self.jobq.get()
        while not self.doneq.empty():
            self.doneq.get()
        for process in self.processes:
            process.terminate()
            process.join()
# Demo: serve two pre-generated 1000x1000x10 float32 tensors through the
# shared-memory batch generator and time how long each next() call blocks.
if __name__ == '__main__':
    import time

    d = {i: np.random.random(1000 * 1000 * 10).reshape((1000, 1000, 10)).astype(np.float32)
         for i in range(2)}

    def getBatch(seed, sharedArr):
        # Pick one of the two cached tensors, copy it into the shared buffer,
        # and describe it with a slice, a shape, and a metadata dict
        np.random.seed(seed)
        randint = np.random.randint(0, 2, 1)[0]
        batchArr = d[randint]
        batchShape = batchArr.shape
        batchSlice = slice(0, np.product(batchShape))
        sharedArr[batchSlice] = batchArr.ravel()
        return ([batchSlice], [batchShape], {'source': int(randint)})

    out = []
    with BatchGenSharedmem(getBatch, (1000, 1000, 10), seed=333, num_workers=4) as bg:
        startTime = time.time()
        for counter in range(20):
            startTimeSub = time.time()
            batchMeta, [minibatch] = next(bg)
            print('Time to get:', time.time() - startTimeSub)
            print('Iter Nbr:', counter)
            print('Metadata:', batchMeta)
            print('First item:', minibatch[0, 0, 0])
            print('Shape:', minibatch.shape)
            # Copy before releasing: after release_mem() a worker may overwrite
            # this shared buffer
            out.append(minibatch.copy())
            bg.release_mem()
        print('Time to run all batches:', time.time() - startTime)

    print('Len output:', len(out))