Test of Python multiprocessing to stream temporally coherent batches of event histograms from Prophesee event recordings.
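The core of the script below is a producer/consumer hand-off: each worker owns a multiprocessing.Array holding max_q_size histogram slots plus a Queue of ready slot indices, and the array's Lock throttles the worker so it never runs more than one chunk ahead of the reader. Here is a minimal, self-contained sketch of that hand-off; it assumes the 'fork' start method (Linux), and dummy_stream, SLOTS and SHAPE are illustrative stand-ins for the ChronoVideo reader and the script's parameters, not part of the original gist.

# Minimal sketch of the shared-array / ready-queue hand-off (assumes the 'fork'
# start method, as on Linux; dummy_stream is a stand-in for the event reader).
import multiprocessing as mp
import numpy as np

SLOTS = 4                      # number of slots in the shared ring buffer
SHAPE = (2, 240, 320)          # (polarity, height, width), as in the full script

def dummy_stream(ready_q, lock, slots):
    j = 0
    while True:
        lock.acquire()                      # wait until the consumer has freed the previous slot
        slots[j] = np.random.rand(*SHAPE)   # "produce" one chunk into slot j
        ready_q.put(j)                      # tell the consumer which slot is ready
        j = (j + 1) % SLOTS

if __name__ == '__main__':
    ready_q = mp.Queue(maxsize=SLOTS)
    shared = mp.Array('f', int(np.prod((SLOTS,) + SHAPE)), lock=mp.Lock())
    slots = np.frombuffer(shared.get_obj(), dtype='f').reshape((SLOTS,) + SHAPE)
    worker = mp.Process(target=dummy_stream,
                        args=(ready_q, shared.get_lock(), slots), daemon=True)
    worker.start()
    for _ in range(10):
        j = ready_q.get()                   # block until a slot is ready
        chunk = slots[j].copy()             # copy it out of shared memory
        shared.get_lock().release()         # let the worker overwrite that region again
        print('got slot', j, float(chunk.mean()))
    worker.terminate()

Because the consumer releases the lock only after copying a slot out of shared memory, each stream is read strictly in order, which is what keeps successive batches temporally coherent. The full script follows.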
from __future__ import print_function
import glob
import sys
import time
import multiprocessing as mp, numpy as np, random
from prophesee_utils.td_video import ChronoVideo
import prophesee_utils.td_processing as tdp
import prophesee_utils.vis_utils as vis
import cv2
from numba import njit as jit


@jit
def histogram2d(xypt, histo, max_val=1, shift=0, reset=True):
    """
    Accumulates events into a per-polarity 2D histogram.

    :param xypt: structured array of events with 'x', 'y' and 'p' fields
    :param histo: output array of shape (polarities, height, width), values clamped to [0, 1]
    :param max_val: number of events needed to saturate a bin
    :param shift: right bit-shift applied to x/y (spatial downsampling)
    :param reset: zero the histogram before accumulating
    :return: None, histo is modified in place
    """
    if reset:
        histo[...] = 0
    increment = 1.0 / max_val
    for i in range(xypt.shape[0]):
        x, y, p = xypt['x'][i] >> shift, xypt['y'][i] >> shift, xypt['p'][i]
        histo[p, y, x] = min(histo[p, y, x] + increment, 1)


if __name__ == '__main__':
    path = '/mnt/hdd1/detection_dataset10/train/'
    files = glob.glob(path + '*_td.dat')

    batchsize = 8
    num_threads = 4
    num_videos = 4
    num_videos_per_thread = batchsize // num_threads
    delta_t = 10000        # time span of events loaded per chunk
    max_q_size = 4         # slots in each worker's shared ring buffer

    # one queue per worker; it carries (slot index, reset flag) pairs for slots
    # of the shared array that are ready to be consumed by the main process
    readyQs = [mp.Queue(maxsize=max_q_size) for i in range(num_videos)]

    array_dim = (2, 240, 320)      # (polarity, height, width)
    batch = np.zeros((num_videos, array_dim[0], array_dim[1], array_dim[2]), dtype=np.float32)

    def frame_stream(i, m, n, files, shape):
        """Worker: stream one group of videos and write histograms into shared memory."""
        random.shuffle(files)
        video_num = 0
        video = ChronoVideo(files[0])
        im = np.zeros(shape, dtype=np.float32)
        q = readyQs[i]
        j = 0
        print('Queue Size: ', len(n))
        while True:
            reset = video.done
            if video.done:
                # current file is exhausted: move on to the next one and flag the reset
                video_num = (video_num + 1) % len(files)
                video = ChronoVideo(files[video_num])
            events = video.load_delta_t(delta_t)
            # the lock is acquired here and released by the main process once it has
            # copied the slot, so the worker never runs more than one chunk ahead
            m.acquire()
            histogram2d(events, n[j], max_val=8)
            q.put((j, reset))
            j = (j + 1) % max_q_size

    array_dim2 = (max_q_size, 2, 240, 320)
    # Create tuples of (multiprocessing.Array, numpy.ndarray) referencing the same underlying buffers
    m_arrays = (mp.Array('f', int(np.prod(array_dim2)), lock=mp.Lock()) for _ in range(num_videos))
    arrays = [(m, np.frombuffer(m.get_obj(), dtype='f').reshape(array_dim2)) for m in m_arrays]

    # split the file list evenly across the workers so each one streams its own videos
    size = len(files) // num_videos
    grouped_files = [files[i * size:(i + 1) * size] for i in range(num_videos)]
    procs = [mp.Process(target=frame_stream, args=(i, m, n, f, array_dim))
             for i, ((m, n), f) in enumerate(zip(arrays, grouped_files))]
    [p.start() for p in procs]

    print('Start Streaming')
    for _ in range(10000):
        start = time.time()
        for i in range(num_videos):
            # block until worker i has a slot ready, copy it into the batch,
            # then release the lock so the worker can keep producing
            j, reset = readyQs[i].get()
            if reset:
                print(i, ' was reset')
            m, arr = arrays[i]
            batch[i] = arr[j]
            m.release()
        runtime = float(time.time() - start)
        sys.stdout.write('\rtime: %f' % runtime)
        sys.stdout.flush()
        # display the batch
        for i in range(num_videos):
            im = vis.count_image(batch[i], max_value=1)
            cv2.imshow('img#' + str(i), im)
            cv2.waitKey(1)

    [p.terminate() for p in procs]