@ebenolson
Created January 4, 2016 14:15
# Context manager to generate batches in the background via a process pool
# Usage:
#
# def batch(seed):
#     ....  # generate minibatch
#     return minibatch
#
# with BatchGenCM(batch) as bg:
#     minibatch = next(bg)
#     ....  # do something with minibatch

import uuid
import os
import pickle
import hashlib
import numpy as np
from multiprocessing import Process, Queue


class BatchGenCM:
    def __init__(self, batch_fn, seed=None, num_workers=8):
        self.batch_fn = batch_fn
        self.num_workers = num_workers
        if seed is None:
            seed = np.random.randint(4294967295)
        self.seed = str(seed)
        self.id = uuid.uuid4()

    def __enter__(self):
        self.jobq = Queue(maxsize=self.num_workers)
        self.doneq = Queue()
        self.processes = []
        self.current_batch = 0
        self.finished_batches = []

        def produce():
            while True:
                n = self.jobq.get()
                if n is None:  # sentinel: shut this worker down
                    break
                # Derive a deterministic per-batch seed from the base seed
                seed = hashlib.md5(self.seed + str(n)).hexdigest()
                seed = int(seed, 16) % 4294967295
                batch = self.batch_fn(seed)
                # Hand the batch to the consumer through shared memory
                with open('/run/shm/{}-{}'.format(self.id, n), 'w') as ofile:
                    pickle.dump(batch, ofile, protocol=pickle.HIGHEST_PROTOCOL)
                self.doneq.put(n)

        # Prime the job queue and start one producer process per worker
        for i in range(self.num_workers):
            self.jobq.put(i)
            p = Process(target=produce)
            self.processes.append(p)
            p.start()

        return self

    def __iter__(self):
        return self

    def next(self):
        # Block until batch number `current_batch` has been produced
        n = self.current_batch
        while n not in self.finished_batches:
            i = self.doneq.get()
            self.finished_batches.append(i)

        fn = '/run/shm/{}-{}'.format(self.id, n)
        batch = pickle.load(open(fn))
        os.system('rm {}'.format(fn))

        # Queue the next job so workers stay num_workers batches ahead
        self.jobq.put(n + self.num_workers)
        self.current_batch += 1
        return batch

    def __exit__(self, exc_type, exc_value, traceback):
        # Send one shutdown sentinel per worker, then drain leftovers
        for _ in range(self.num_workers):
            self.jobq.put(None)
        for process in self.processes:
            process.join()
        while not self.doneq.empty():
            _ = next(self)
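
For concreteness, a minimal usage sketch following the docstring above (the batch function and its shapes are hypothetical; the __main__ guard is good hygiene when starting processes):

import numpy as np

def batch(seed):
    # Hypothetical minibatch: reproducible random data for a given seed
    rng = np.random.RandomState(seed)
    return rng.randn(32, 784).astype('float32')

if __name__ == '__main__':
    with BatchGenCM(batch, seed=1234, num_workers=4) as bg:
        for _ in range(10):
            minibatch = next(bg)
            print(minibatch.shape)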
@crobertob

Hello,

I was wondering whether this class can be used with Python 2.7. I am trying a very simple example and I get the following error:

Traceback (most recent call last):
  File "C:/Users/crobe/Google Drive/DataMiningGroup/Code/batchgen.py", line 88, in <module>
    with BatchGenCM(batch, seed=1, num_workers=2) as bg:
  File "C:/Users/crobe/Google Drive/DataMiningGroup/Code/batchgen.py", line 54, in __enter__
    p.start()
  File "C:\Anaconda2\lib\multiprocessing\process.py", line 130, in start
    self._popen = Popen(self)
  File "C:\Anaconda2\lib\multiprocessing\forking.py", line 277, in __init__
    dump(process_obj, to_child, HIGHEST_PROTOCOL)
  File "C:\Anaconda2\lib\multiprocessing\forking.py", line 199, in dump
    ForkingPickler(file, protocol).dump(obj)
  File "C:\Anaconda2\lib\pickle.py", line 224, in dump
    self.save(obj)
  File "C:\Anaconda2\lib\pickle.py", line 331, in save
    self.save_reduce(obj=obj, *rv)
  File "C:\Anaconda2\lib\pickle.py", line 425, in save_reduce
    save(state)
  File "C:\Anaconda2\lib\pickle.py", line 286, in save
    f(self, obj) # Call unbound method with explicit self
  File "C:\Anaconda2\lib\pickle.py", line 655, in save_dict
    self._batch_setitems(obj.iteritems())
  File "C:\Anaconda2\lib\pickle.py", line 687, in _batch_setitems
    save(v)
  File "C:\Anaconda2\lib\pickle.py", line 286, in save
    f(self, obj) # Call unbound method with explicit self
  File "C:\Anaconda2\lib\pickle.py", line 754, in save_global
    (obj, module, name))
pickle.PicklingError: Can't pickle <function produce at 0x0000000003591A58>: it's not found as __main__.produce

It seems that the error occurs in the multiprocessing library, but I'm not sure if there is a way to fix it. Could you please help me?

Thank you!
Roberto
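
The traceback is a Windows issue rather than a Python 2.7 issue: on Windows, multiprocessing spawns a fresh interpreter and must pickle the Process target, and a function defined inside __enter__ (like produce here) cannot be pickled. A minimal sketch of the usual workaround, moving the worker to module level and passing in everything it needs (the signature below is illustrative, not part of the gist):

import hashlib
import os
import pickle

# Module-level worker: picklable on Windows because it is importable by name
def produce(jobq, doneq, batch_fn, base_seed, out_dir, run_id):
    while True:
        n = jobq.get()
        if n is None:
            break
        h = hashlib.md5((base_seed + str(n)).encode('utf-8')).hexdigest()
        batch = batch_fn(int(h, 16) % 4294967295)
        with open(os.path.join(out_dir, '{}-{}'.format(run_id, n)), 'wb') as ofile:
            pickle.dump(batch, ofile, protocol=pickle.HIGHEST_PROTOCOL)
        doneq.put(n)

# BatchGenCM.__enter__ would then start it as, e.g.:
#     p = Process(target=produce,
#                 args=(self.jobq, self.doneq, self.batch_fn,
#                       self.seed, '/run/shm', self.id))

Note that batch_fn itself must also be picklable (a plain module-level function), and /run/shm does not exist on Windows, so out_dir would need to point at a temp directory there.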

@thenomemac

I made changes to add Python 3 support; feel free to use them:

# Modified 2016-06-30 by Josiah Olson to add python3 support
# Context manager to generate batches in the background via a process pool
# Usage:
#
# def batch(seed):
#    .... # generate minibatch
#    return minibatch
#
# with BatchGenCM(batch) as bg:
#    minibatch = next(bg)
#    .... # do something with minibatch

import uuid
import os
import pickle
import hashlib
import numpy as np
from multiprocessing import Process, Queue


class BatchGenCM:
    def __init__(self, batch_fn, seed=None, num_workers=8):
        self.batch_fn = batch_fn
        self.num_workers = num_workers
        if seed is None:
            seed = np.random.randint(4294967295)
        self.seed = str(seed)
        self.id = uuid.uuid4()

    def __enter__(self):
        self.jobq = Queue(maxsize=self.num_workers)
        self.doneq = Queue()
        self.processes = []
        self.current_batch = 0
        self.finished_batches = []

        def produce():
            while True:
                n = self.jobq.get()
                if n is None:
                    break
                seed = hashlib.md5((self.seed + str(n)).encode('utf-8')).hexdigest()
                seed = int(seed, 16) % 4294967295
                batch = self.batch_fn(seed)
                with open('/run/shm/{}-{}'.format(self.id, n), 'wb') as ofile:
                    pickle.dump(batch, ofile, protocol=pickle.HIGHEST_PROTOCOL)
                self.doneq.put(n)

        for i in range(self.num_workers):
            self.jobq.put(i)
            p = Process(target=produce)
            self.processes.append(p)
            p.start()

        return self

    def __iter__(self):
        return self

    def __next__(self):
        n = self.current_batch
        while n not in self.finished_batches:
            i = self.doneq.get()
            self.finished_batches.append(i)

        fn = '/run/shm/{}-{}'.format(self.id, n)
        batch = pickle.load(open(fn, 'rb'))
        os.system('rm {}'.format(fn))

        self.jobq.put(n + self.num_workers)
        self.current_batch += 1
        return batch

    def __exit__(self, exc_type, exc_value, traceback):
        for _ in range(self.num_workers):
            self.jobq.put(None)
        for process in self.processes:
            process.join()
        while not self.doneq.empty():
            _ = self.__next__()
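
A related portability note: /run/shm is a Linux-specific tmpfs mount. If you want the scratch path chosen portably, one option (a sketch; batch_path is a hypothetical helper, not part of the gist) is:

import os
import tempfile

SCRATCH = tempfile.gettempdir()  # a writable temp directory on any platform

def batch_path(run_id, n):
    # Mirrors the '/run/shm/{id}-{n}' naming used above
    return os.path.join(SCRATCH, '{}-{}'.format(run_id, n))

On Linux, pointing this at /dev/shm keeps the original in-memory behavior.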

@zhaobozb

Hi Ebenolson,

I have the same question as @baumgach: how can batch(seed) iterate over all the samples in one epoch without choosing duplicates?
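
One possible approach (a sketch, not from the gist; it assumes you modify produce to pass the raw batch index n to batch_fn alongside the seed): fix one permutation per epoch and give each batch a disjoint slice of it:

import numpy as np

N_SAMPLES = 50000   # hypothetical dataset size, divisible by BATCH_SIZE
BATCH_SIZE = 125
EPOCH_SEED = 42     # fixed, so every worker builds the same permutation

def batch(seed, n):
    # Identical permutation in every worker; batch n takes a disjoint slice,
    # so one pass over n = 0 .. N_SAMPLES/BATCH_SIZE - 1 visits each sample once
    perm = np.random.RandomState(EPOCH_SEED).permutation(N_SAMPLES)
    start = (n * BATCH_SIZE) % N_SAMPLES
    idx = perm[start:start + BATCH_SIZE]
    # `seed` remains free to drive per-batch augmentation
    return idx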

@rslprpr commented Oct 6, 2016

For a large dataset stored in HDF5, how do you use this batch generator without loading all the data into memory?
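
One way (a sketch assuming h5py; the path and dataset name are hypothetical): open the file inside the batch function and index only the rows you need, so only that slice is read from disk:

import h5py
import numpy as np

H5_PATH = 'data.h5'   # hypothetical HDF5 file containing a dataset 'X'
BATCH_SIZE = 32

def batch(seed):
    rng = np.random.RandomState(seed)
    with h5py.File(H5_PATH, 'r') as f:
        dset = f['X']
        # h5py requires fancy indices in increasing order
        idx = np.sort(rng.choice(len(dset), size=BATCH_SIZE, replace=False))
        return dset[idx]  # reads only the selected rows into memory

Because each worker process opens the file independently, the full dataset is never pickled or loaded; only the minibatch itself passes through /run/shm.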
