# GitHub Gist mctigger/f0c9c6c44448e4d89953f15df5888c62, created May 3, 2017 13:30.
# GitHub notice: this file may contain hidden or bidirectional Unicode text that
# could be interpreted or compiled differently than it appears. To review it,
# open the file in an editor that reveals hidden Unicode characters.
import itertools
import time
from multiprocessing import Pool

import numpy as np
from skimage.transform import AffineTransform, SimilarityTransform, warp
# Side length, in pixels, of the square images being augmented (the samples
# generated in this file are 256 x 256).
image_size = 256
# Geometric centre of the image in skimage pixel coordinates. Pixel centres sit
# at integer indices 0 .. image_size - 1, so the centre is at
# (image_size - 1) / 2 = 127.5 -- the same convention skimage.transform.rotate
# uses. image_size / 2 would put the pivot half a pixel off-centre.
center_shift = (image_size - 1) / 2
# Move the image centre to the origin, and move it back again. Sandwiching a
# rotation/scale transform between these makes it pivot on the image centre
# instead of the top-left corner.
tf_center = SimilarityTransform(translation=(-center_shift, -center_shift))
tf_uncenter = SimilarityTransform(translation=(center_shift, center_shift))
def sample_gen_random_i(shape=(256, 256, 4), n_samples=None):
    """Yield random ``(x, y)`` samples for benchmarking.

    Parameters
    ----------
    shape : tuple of int, optional
        Shape of each ``x`` (default ``(256, 256, 4)``).
    n_samples : int or None, optional
        Number of samples to yield. ``None`` (the default) yields indefinitely.

    Yields
    ------
    x : numpy.ndarray
        float64 array of ``shape`` with values drawn uniformly from [0, 1).
    y : list
        A fresh placeholder label ``[0]`` for every sample.
    """
    produced = 0
    while n_samples is None or produced < n_samples:
        yield np.random.rand(*shape), [0]
        produced += 1
def augment(sample):
    """Randomly rotate, shift and scale ``x`` about the image centre.

    Parameters
    ----------
    sample : tuple
        An ``(x, y)`` pair. ``x`` is warped; ``y`` is passed through untouched.

    Returns
    -------
    tuple
        ``(warped_x, y)``, where ``warped_x`` is the output of
        ``skimage.transform.warp``.
    """
    image, label = sample
    rng = np.random
    # Draw order matters for reproducibility under a fixed seed:
    # angle, x-shift, y-shift, then scale.
    angle = 2 * np.pi * rng.random_sample()  # radians, in [0, 2*pi)
    # NOTE(review): both shifts lie in [0, 5), never negative -- confirm the
    # one-sided shift is intended.
    shift = (5 * rng.random_sample(), 5 * rng.random_sample())
    zoom = rng.random_sample() * 0.2 + 0.9  # uniform in [0.9, 1.1)
    jitter = AffineTransform(scale=(zoom, zoom), rotation=angle, translation=shift)
    # Move the centre to the origin, jitter, then move it back.
    full_tf = tf_center + jitter + tf_uncenter
    return warp(image, full_tf), label
def _reseed_worker():
    """Reseed NumPy's global RNG in a freshly started pool worker.

    With the ``fork`` start method every worker inherits a copy of the
    parent's ``np.random`` state. Without reseeding, all workers would draw
    the same sequence of "random" augmentation parameters.
    ``np.random.seed()`` with no argument reseeds from OS entropy.
    """
    np.random.seed()


def augment_parallel_sample_gen(samples, processes=4, chunksize=10):
    """Yield ``augment(sample)`` for every sample, computed in worker processes.

    Results are yielded as soon as they are ready, so their order is arbitrary.

    Parameters
    ----------
    samples : iterable of (x, y) pairs
        May be unbounded; it is consumed lazily, one batch at a time.
    processes : int, optional
        Number of worker processes (default 4).
    chunksize : int, optional
        Tasks sent to a worker per round trip (default 10). Must be >= 1.

    Yields
    ------
    tuple
        The ``(warped_x, y)`` pairs produced by ``augment``.

    Raises
    ------
    ValueError
        If ``chunksize`` is less than 1 (raised on the first ``next()``).
    """
    if chunksize < 1:
        raise ValueError('chunksize must be at least 1')
    # Pool.imap_unordered pulls its input iterable in a background thread with
    # no back-pressure. Handing it an endless generator makes this process
    # build and queue samples without limit. Feeding bounded batches keeps at
    # most batch_size inputs in flight, and each batch still gives every
    # worker two chunks of work.
    batch_size = 2 * processes * chunksize
    source = iter(samples)
    # Leaving the with-block terminates the workers. That includes the case
    # where the consumer stops early and this generator is closed mid-stream.
    with Pool(processes, initializer=_reseed_worker) as pool:
        while True:
            batch = list(itertools.islice(source, batch_size))
            if not batch:
                return
            yield from pool.imap_unordered(augment, batch, chunksize=chunksize)
def augment_sample_gen(samples):
    """Lazily apply ``augment`` to each sample in the calling process.

    Parameters
    ----------
    samples : iterable of (x, y) pairs

    Yields
    ------
    tuple
        ``augment(sample)`` for each input sample, in input order.
    """
    yield from map(augment, samples)
def _benchmark(augmented, limit=2000):
    """Consume samples from ``augmented`` and print the throughput.

    A running ``index|rate samples / second`` line is redrawn in place with a
    carriage return. When index ``limit`` is reached, the final figure is
    printed on its own line and iteration stops.

    Parameters
    ----------
    augmented : generator
        Source of augmented samples. It is closed on exit, so a generator that
        owns a worker pool can shut it down right away instead of at
        garbage-collection time.
    limit : int, optional
        Index of the last sample to consume (``limit + 1`` samples in total).
    """
    start = time.perf_counter()
    try:
        for i, _ in enumerate(augmented):
            # perf_counter is monotonic and high-resolution. The floor guards
            # against a zero interval, which would divide by zero.
            elapsed = max(time.perf_counter() - start, 1e-9)
            # After index i, i + 1 samples have been received.
            line = '{}|{:.2f} samples / second'.format(i, (i + 1) / elapsed)
            if i >= limit:
                print(line)
                break
            print(line, end='\r')
    finally:
        augmented.close()


def main():
    """Compare single-process and multiprocessing augmentation throughput."""
    # This is slow and the single cpu core has 100% load
    print('Single Thread --> Slow')
    _benchmark(augment_sample_gen(sample_gen_random_i()))

    # This is slow and there is only light load on the cpu cores.
    # NOTE(review): every sample is a 256x256x4 float64 array (~2 MiB). It is
    # generated in this process, pickled to a worker, and the result pickled
    # back, so the parent process is a likely bottleneck -- profile to confirm.
    print('Multithreaded --> Slow')
    _benchmark(augment_parallel_sample_gen(sample_gen_random_i()))


# Required for multiprocessing. With the "spawn"/"forkserver" start methods,
# workers re-import this module and would otherwise re-run the benchmark.
if __name__ == '__main__':
    main()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.