Skip to content

Instantly share code, notes, and snippets.

@BlGene
Created February 17, 2017 09:38
Show Gist options
  • Save BlGene/708190838ef0835f4c349b651bb72557 to your computer and use it in GitHub Desktop.
class DataProvider(object):
    """Base class for data providers that feed batches into a caffe network.

    Subclasses implement get_epoch / get_batch_vec / get_datum; this base
    class handles label bookkeeping and copying a fetched batch into caffe
    blobs.
    """

    def __init__(self, label_dict, top_blobs=None):
        """Store the label mapping and derive top-blob metadata.

        label_dict: mapping from datum id -> label. When top_blobs is None,
            the labels must be ints and the standard ("img", "label") tops
            are used; num_classes is derived from the max label.
        top_blobs: explicit list of top-blob names, disables label inference.
        """
        # Set by subclasses when deterministic evaluation is needed.
        self.random_state = None
        self.label_dict = label_dict
        if top_blobs is None:
            # Py3-safe: dict.values() is a view and cannot be indexed.
            example = next(iter(label_dict.values()))
            if isinstance(example, int):
                self.top_blobs = ["img", "label"]
                example = (example,)
                # Labels are assumed to be 0-based class indices.
                self.num_classes = (max(self.label_dict.values()) + 1,)
            else:
                raise ValueError
            self.label_example = example
        else:
            self.top_blobs = top_blobs
            self.label_example = None

    def getDatumIds(self):
        """Return the datum ids (keys of label_dict)."""
        return self.label_dict.keys()

    def getLabel(self, did):
        """Return the label for datum id *did*."""
        return self.label_dict[did]

    def get_epoch(self):
        """Subclass hook: report the current epoch."""
        pass

    def get_batch_vec(self):
        """Subclass hook: return (data, extra) for the next batch."""
        pass

    def get_datum(self):
        """Subclass hook: return a single datum."""
        pass

    def augment_datum(self, datum):
        """Identity by default; subclasses override to add augmentation."""
        return datum

    def assign_batch(self, net_or_top):
        """Fetch the next batch and copy it into caffe blobs.

        net_or_top: either a caffe.Net (blobs addressed by name via
            self.top_blobs) or a RawBlobVec of layer tops (addressed by
            position).
        Returns the (data, extra) pair from get_batch_vec.
        """
        data, extra = self.get_batch_vec()
        if isinstance(net_or_top, caffe.Net):
            net = net_or_top
            for i, blob_name in enumerate(self.top_blobs):
                net.blobs[blob_name].data[...] = data[i]
        elif isinstance(net_or_top, caffe._caffe.RawBlobVec):
            top = net_or_top
            for i in range(len(top)):
                top[i].data[...] = data[i]
        else:
            raise ValueError
        return data, extra

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        # Fixed signature: a context-manager __exit__ receives three
        # exception arguments (defaults keep old zero-arg calls working).
        pass
class PrefetchingDataProvider():
    """Wraps a DataProvider and computes batches ahead of time in a
    multiprocessing pool, keeping up to config.PREFETCH_PROCESSES batches
    in flight.
    """

    def __init__(self, dp, debug=False):
        """dp must be picklable: it is shipped to the worker processes.

        debug: when True, batches are built synchronously in-process.
        """
        self.dp = dp
        self.queue_len = config.PREFETCH_PROCESSES
        self.debug = debug
        self.queue = deque()

        # Workers ignore SIGINT so a ctrl-c in the parent does not start
        # new work in the pool mid-shutdown.
        def init_worker():
            signal.signal(signal.SIGINT, signal.SIG_IGN)

        try:
            self.pool = Pool(processes=self.queue_len,
                             initializer=init_worker)
        except OSError:
            set_trace()

    # Needs to be a duplicate of BatchDataProvider.get_batch_vec
    def get_batch_vec(self):
        """Return the next prefetched batch, topping up the queue first."""
        if self.debug:
            # Synchronous path: build the batch in this process.
            batch_recipe = self.dp.get_batch_recipe()
            return call_create_batch(self.dp, batch_recipe)
        # Refill the prefetch queue up to queue_len outstanding batches.
        for _ in range(self.queue_len - len(self.queue)):
            batch_recipe = self.dp.get_batch_recipe()
            # Advance the RNG so successive recipes differ:
            # get_batch_recipe does *not* consume from dp.random_state.
            if self.dp.random_state:
                self.dp.random_state.rand()
            self.queue.append(self.pool.apply_async(call_create_batch,
                                                    (self.dp, batch_recipe)))
        # Block on the oldest outstanding batch.
        return self.queue.popleft().get()

    def __enter__(self):
        # Fixed: return self so "with PrefetchingDataProvider(dp) as p:"
        # binds p to the provider instead of None.
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        # pool.terminate() was causing hangs in docker (for some reason),
        # so close the pool and drain outstanding work instead.
        self.pool.close()
        # Fixed: self.queue is a deque, not a callable -- iterate it
        # directly (the original "self.queue()" raised TypeError here).
        for entry in self.queue:
            entry.get(timeout=10)
        self.pool.join()
class AugmentBasic(DataProvider):
    """Composite/mixin data provider adding basic image augmentation
    (random shift, rotation and horizontal mirror) on top of DataProvider.
    """

    def __init__(self):
        # NOTE(review): this __init__ reads self.mode_type and (when
        # augmenting) self.random_state, which must already be set by a
        # sibling class in the MRO -- confirm construction order.
        self.offset = config.IMAGE_MEAN
        self.scale_factor = config.IMAGE_SCALE
        self.image_size = config.IMAGE_SIZE
        self.augment = True
        mode = self.mode_type
        # Fixed: the elif below explicitly handles 'test', but the old
        # assert only allowed ('train', 'val'), making that branch
        # unreachable. Accept all three handled modes.
        assert mode in ('train', 'val', 'test')
        if mode == 'train':
            self.augment = True
        elif mode == 'val' or mode == 'test':
            self.augment = False
        else:
            raise ValueError
        if self.augment:
            assert(self.random_state is not None)
            self.rot = 180    # max rotation magnitude, degrees
            self.shift = 100  # max x/y shift magnitude, pixels

    def get_datum(self, did):
        """Load one image by datum id, optionally augment it, and return a
        normalized channels-first (C, H, W) array."""
        image = self.load_file(did, as_pil=True)
        if self.augment:
            mirror = self.random_state.rand() > .5
            # Uniform shift in [-shift, shift] and rotation in [-rot, rot].
            x, y = 2 * (self.random_state.rand(2) - .5) * self.shift
            rr = 2 * (self.random_state.rand() - .5) * self.rot
            try:
                image = PIL.ImageChops.offset(image, xoffset=int(x),
                                              yoffset=int(y))
            except AttributeError:
                # Older PIL exposes offset as an Image method.
                image = image.offset(xoffset=int(x), yoffset=int(y))
            image = image.rotate(rr, PIL.Image.BICUBIC)
        # Shared normalization: HWC -> CHW, subtract mean, rescale.
        arr = np.array(image).transpose(2, 0, 1)
        if self.augment and mirror:
            # Horizontal mirror (flip the width axis).
            arr = arr[:, :, ::-1]
        arr = (arr - self.offset) * self.scale_factor
        return arr
class FileDataProviderLayer(caffe.Layer):
    """
    Provide input data for VQA.
    """

    def setup(self, bottom, top):
        # Batch size comes from the python-side config module.
        self.batch_size = config.BATCH_SIZE
        # Layer-specific parameters arrive as JSON in the prototxt param_str.
        params = json.loads(self.param_str)
        self.x = params["x"]
        provider = FileDataProvider(x=self.x)
        if not config.DEBUG:
            provider = PrefetchingDataProvider(provider)
        self.dp = provider

    def reshape(self, bottom, top):
        # NOTE(review): in val/test modes the tops are reshaped to the batch
        # but the batch data is never copied in -- confirm this is intended.
        res, _ = self.dp.get_batch_vec()
        copy_data = self.mode_type not in ('val', 'test-dev', 'test')
        for i, arr in enumerate(res):
            top[i].reshape(*arr.shape)
            if copy_data:
                top[i].data[...] = arr

    def forward(self, bottom, top):
        # Data is already placed in the tops during reshape().
        pass

    def backward(self, top, propagate_down, bottom):
        # Data layer: nothing to backpropagate.
        pass
@BlGene
Copy link
Author

BlGene commented Feb 17, 2017

This code is a mess (sorry); a bunch of functions are missing, and reshape should be using assign_batch.
I can post a more complete version in case someone is seriously interested in making this a general data layer template.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment