Created
February 17, 2017 09:38
-
-
Save BlGene/708190838ef0835f4c349b651bb72557 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DataProvider(object):
    """Base class for batch data providers feeding a caffe network.

    Holds a mapping from datum id -> label and the list of top blob names
    the provider fills.  Subclasses implement get_epoch / get_batch_vec /
    get_datum.
    """

    def __init__(self, label_dict, top_blobs=None):
        # for deterministic evaluation; subclasses may install a
        # numpy RandomState here
        self.random_state = None
        self.label_dict = label_dict
        if top_blobs is None:
            # Inspect one label to infer the blob layout.
            # next(iter(...)) works on both py2 lists and py3 dict views
            # (the original `label_dict.values()[0]` raises TypeError on py3).
            x = next(iter(label_dict.values()))
            if isinstance(x, int):
                self.top_blobs = ["img", "label"]
                x = (x,)
                # labels are assumed to be 0-based class indices
                self.num_classes = (max(self.label_dict.values()) + 1,)
            else:
                raise ValueError("unsupported label type: %r" % type(x))
            self.label_example = x
        else:
            self.top_blobs = top_blobs
            self.label_example = None

    def getDatumIds(self):
        # all datum ids known to this provider
        return self.label_dict.keys()

    def getLabel(self, did):
        return self.label_dict[did]

    def get_epoch(self):
        pass

    def get_batch_vec(self):
        pass

    def get_datum(self):
        pass

    def augment_datum(self, datum):
        # identity by default; augmenting subclasses override this
        return datum

    def assign_batch(self, net_or_top):
        """Fetch one batch and copy it into a caffe Net or a top-blob vector.

        Returns the (data, labels) pair from get_batch_vec so callers can
        inspect what was assigned.
        """
        data, labels = self.get_batch_vec()
        if isinstance(net_or_top, caffe.Net):
            net = net_or_top
            for i, blob_name in enumerate(self.top_blobs):
                net.blobs[blob_name].data[...] = data[i]
        elif isinstance(net_or_top, caffe._caffe.RawBlobVec):
            top = net_or_top
            for i in range(len(top)):
                top[i].data[...] = data[i]
        else:
            raise ValueError("expected caffe.Net or RawBlobVec, got %r"
                             % type(net_or_top))
        return data, labels

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # context-manager protocol: __exit__ must accept the exception
        # triple (the original zero-arg signature raised TypeError when
        # leaving a `with` block)
        pass
class PrefetchingDataProvider(object):
    """Wraps a DataProvider and prefetches batches in a process pool.

    The wrapped provider is shipped to worker processes, so it must be
    picklable.  Keeps up to config.PREFETCH_PROCESSES batches in flight.
    """

    def __init__(self, dp, debug=False):
        # keep the provider locally; it is pickled into each worker
        self.dp = dp
        self.queue_len = config.PREFETCH_PROCESSES
        self.debug = debug
        self.queue = deque()

        # workers ignore SIGINT so a ctrl-c in the parent does not leave
        # half-started workers behind
        def init_worker():
            signal.signal(signal.SIGINT, signal.SIG_IGN)

        try:
            self.pool = Pool(processes=self.queue_len,
                             initializer=init_worker)
        except OSError:
            # NOTE(review): debugging residue — drops into pdb when the
            # pool cannot be created; consider re-raising instead
            set_trace()

    # Needs to stay a behavioral duplicate of BatchDataProvider.get_batch_vec
    def get_batch_vec(self):
        """Return the next batch; synchronous in debug mode, else prefetched."""
        if self.debug:
            # bypass the pool entirely for readable stack traces
            batch_recipe = self.dp.get_batch_recipe()
            return call_create_batch(self.dp, batch_recipe)
        # top the queue back up to queue_len outstanding jobs
        for _ in range(self.queue_len - len(self.queue)):
            batch_recipe = self.dp.get_batch_recipe()
            # advance the parent's RNG once so it stays in step with the
            # workers, because get_batch_recipe does *not* consume
            # dp.random_state itself
            if self.dp.random_state:
                self.dp.random_state.rand()
            self.queue.append(self.pool.apply_async(call_create_batch,
                                                    (self.dp, batch_recipe)))
        # block on the oldest outstanding job
        return self.queue.popleft().get()

    def __enter__(self):
        # return self so `with PrefetchingDataProvider(dp) as p:` is usable
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # pool.terminate() was causing hangs in docker (for some reason),
        # so close the pool and drain pending jobs instead
        self.pool.close()
        # iterate the deque itself — the original `self.queue()` *called*
        # the deque and raised TypeError on every teardown
        for entry in self.queue:
            entry.get(timeout=10)
        self.pool.join()
class AugmentBasic(DataProvider):
    """Mixin adding basic image augmentation: random shift, rotation, mirror.

    Expects the concrete subclass to provide self.mode_type,
    self.random_state, and self.load_file(did, as_pil=True)
    # TODO(review): confirm these attributes against the subclasses
    """

    def __init__(self):
        # NOTE: deliberately does not call DataProvider.__init__;
        # label_dict etc. must be set up by a sibling class in the MRO
        self.offset = config.IMAGE_MEAN
        self.scale_factor = config.IMAGE_SCALE
        # Composite data provider
        self.image_size = config.IMAGE_SIZE
        self.augment = True
        mode = self.mode_type
        # 'test' added to the assert: the elif below already handles it,
        # but the original assert made that branch unreachable
        assert mode in ('train', 'val', 'test')
        if mode == 'train':
            self.augment = True
        elif mode in ('val', 'test'):
            self.augment = False
        else:
            raise ValueError("unknown mode: %r" % mode)
        if self.augment:
            # augmentation draws from the RNG, so it must exist
            assert self.random_state is not None
            self.rot = 180    # max rotation in degrees (+/-)
            self.shift = 100  # max shift in pixels (+/-)

    def get_datum(self, did):
        """Load image `did`, optionally augment it, and return a CHW array."""
        image = self.load_file(did, as_pil=True)
        if self.augment:
            mirror = self.random_state.rand() > .5
            # uniform shift in [-shift, shift] and rotation in [-rot, rot]
            x, y = 2 * (self.random_state.rand(2) - .5) * self.shift
            rr = 2 * (self.random_state.rand() - .5) * self.rot
            try:
                image = PIL.ImageChops.offset(image, xoffset=int(x), yoffset=int(y))
            except AttributeError:
                # older PIL exposes offset on the image object itself
                image = image.offset(xoffset=int(x), yoffset=int(y))
            image = image.rotate(rr, PIL.Image.BICUBIC)
        # HWC -> CHW; horizontal mirror is applied after the transpose
        arr = np.array(image).transpose(2, 0, 1)
        if self.augment and mirror:
            arr = arr[:, :, ::-1]
        # mean-subtract then scale
        arr = (arr - self.offset) * self.scale_factor
        return arr
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class FileDataProviderLayer(caffe.Layer):
    """
    Caffe python layer that feeds batches from a FileDataProvider.
    """

    def setup(self, bottom, top):
        # config from python code
        self.batch_size = config.BATCH_SIZE
        # layer parameters arrive from the prototxt param_str as JSON
        params = json.loads(self.param_str)
        self.x = params["x"]
        dp = FileDataProvider(x=self.x)
        if not config.DEBUG:
            # wrap in the prefetcher unless debugging (simpler stack traces)
            dp = PrefetchingDataProvider(dp)
        self.dp = dp

    def reshape(self, bottom, top):
        # TODO(review): this should probably delegate to
        # DataProvider.assign_batch instead of copying data by hand.
        # NOTE(review): self.mode_type is never set in this class —
        # presumably supplied by a subclass; confirm before relying on it.
        res, _ = self.dp.get_batch_vec()
        eval_mode = self.mode_type in ('val', 'test-dev', 'test')
        for i, f in enumerate(res):
            top[i].reshape(*f.shape)
            # train mode also writes the data here; eval modes only reshape
            # (the original duplicated the entire loop in both branches)
            if not eval_mode:
                top[i].data[...] = f

    def forward(self, bottom, top):
        # data was already written in reshape(); nothing to do
        pass

    def backward(self, top, propagate_down, bottom):
        # data layer: no gradients to propagate
        pass
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This code is a mess (sorry) — a bunch of functions are missing, and
reshape
should be using assign_batch
. I can post a more complete version in case someone is seriously interested in making this a general data layer template.