Custom queue runner to read Cityscapes
import threading

import cv2
import numpy as np
import tensorflow as tf

# json2instanceImg / json2labelImg come from the cityscapesScripts repository;
# adjust the import to match how those scripts are installed on your system.
from cityscapesscripts.preparation import json2instanceImg, json2labelImg


def instance_to_regression_map(instances, cids):
    """Convert an instance label map to a regression map.

    Each pixel of an instance gets the distances (in pixels) to the top, left,
    bottom and right edges of that instance's bounding box.

    Args:
        instances: instance label mask
        cids: ids of the classes to load
    """
    # TODO: for all the classes that have instances, we can compute this
    image_size = instances.shape[:2]
    reg = np.zeros(image_size + (4,), dtype=np.uint16)
    # select pixels whose label encodes an instance of one of the requested
    # classes (Cityscapes encodes instances as class_id * 1000 + instance index)
    mask = np.zeros(image_size, dtype=bool)
    for cid in cids:
        mask |= (instances >= cid * 1000) & (instances < (cid + 1) * 1000)
    instance_ids = np.unique(instances[mask])
    for iid in instance_ids:
        y, x = np.where(instances == iid)
        reg[y, x, 0] = y - np.min(y)   # distance to the top edge of the box
        reg[y, x, 1] = x - np.min(x)   # distance to the left edge
        reg[y, x, 2] = np.max(y) - y   # distance to the bottom edge
        reg[y, x, 3] = np.max(x) - x   # distance to the right edge
    return reg
def read_example(rgb_fname, json_fname, void_train_id=19):
    # train ids 11..18 are the Cityscapes "thing" classes
    # (person, rider, car, truck, bus, train, motorcycle, bicycle)
    INSTANCE_CIDS = np.arange(11, 19)
    # OpenCV loads BGR; reverse the channel axis to get RGB
    rgb = cv2.imread(rgb_fname)[:, :, ::-1]
    anno = json2instanceImg.Annotation()
    anno.fromJsonFile(json_fname)
    # instance map: trainId * 1000 + instance index, stored as a 32-bit PIL image
    instances_pil = json2instanceImg.createInstanceImage(anno, 'trainIds')
    instances = np.frombuffer(instances_pil.tobytes(), dtype=np.int32)
    instances = instances.reshape((anno.imgHeight, anno.imgWidth))
    # semantic segmentation map with trainIds
    seg_pil = json2labelImg.createLabelImage(anno, 'trainIds')
    seg = np.frombuffer(seg_pil.tobytes(), dtype=np.uint8)
    # copy so the buffer-backed array is writable before relabelling
    seg = seg.reshape((anno.imgHeight, anno.imgWidth)).copy()
    # map the 'ignore' label (255) to a dedicated void class id
    seg[seg == 255] = void_train_id
    reg = instance_to_regression_map(instances, INSTANCE_CIDS)
    return rgb, seg, reg
class CityscapesRunner(object):
    """Custom queue runner that reads Cityscapes examples in Python threads
    and feeds them into a TensorFlow FIFO queue."""

    def __init__(self, filenames, src_size, num_threads, capacity=128):
        self.filenames = filenames
        self.num_threads = num_threads
        # the lock protects self.step, which is shared between reader threads
        self.lock = threading.Lock()
        self.step = 0
        # placeholders used by the reader threads to feed one example at a time
        self.rgb = tf.placeholder(tf.uint8, [src_size[0], src_size[1], 3])
        self.seg = tf.placeholder(tf.uint8, [src_size[0], src_size[1]])
        self.reg = tf.placeholder(tf.uint16, [src_size[0], src_size[1], 4])
        self.queue = tf.FIFOQueue(capacity=capacity,
                                  dtypes=[tf.uint8, tf.uint8, tf.uint16],
                                  shapes=[[src_size[0], src_size[1], 3],
                                          [src_size[0], src_size[1]],
                                          [src_size[0], src_size[1], 4]])
        self.enqueue_op = self.queue.enqueue([self.rgb, self.seg, self.reg])

    def _data_iterator(self):
        # yield examples until all filenames have been consumed (single pass)
        while True:
            with self.lock:
                if self.step == len(self.filenames):
                    break
                idx = self.step
                self.step += 1
            rgb_fname, json_fname = self.filenames[idx]
            yield read_example(rgb_fname, json_fname)

    def _run(self, sess, coord):
        # thread body: read examples and push them into the queue
        try:
            for rgb, seg, reg in self._data_iterator():
                if coord and coord.should_stop():
                    break
                feed_dict = {
                    self.rgb: rgb,
                    self.seg: seg,
                    self.reg: reg
                }
                sess.run(self.enqueue_op, feed_dict)
        except Exception as e:
            if coord:
                coord.request_stop(e)

    def inputs(self):
        # dequeue a single (rgb, seg, reg) example
        return self.queue.dequeue()

    def create_threads(self, sess, coord=None, daemon=False, start=False):
        # mirrors tf.train.QueueRunner.create_threads
        threads = [threading.Thread(target=self._run, args=(sess, coord))
                   for i in range(self.num_threads)]
        for t in threads:
            t.daemon = daemon
            if start:
                t.start()
            if coord:
                coord.register_thread(t)
        return threads
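
A minimal usage sketch, not part of the gist itself: build the runner, start its threads under a tf.train.Coordinator, and pull dequeued examples inside a session. The filename_pairs list, the 1024x2048 source size, and the loop count are illustrative assumptions; the snippet targets the TF 1.x graph API used above.

# Usage sketch -- `filename_pairs` (a list of (rgb_png, gtFine_json) tuples),
# the source size and the loop count are assumptions for illustration.
import tensorflow as tf

runner = CityscapesRunner(filename_pairs, src_size=(1024, 2048), num_threads=4)
rgb, seg, reg = runner.inputs()  # single-example dequeue op

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = runner.create_threads(sess, coord=coord, daemon=True, start=True)
    try:
        for _ in range(10):
            image, labels, box_offsets = sess.run([rgb, seg, reg])
            # ... preprocess / feed the example to a model here ...
    finally:
        coord.request_stop()
        # unblock threads stuck on a pending enqueue, then wait for them
        sess.run(runner.queue.close(cancel_pending_enqueues=True))
        coord.join(threads)

Note that the runner makes a single pass over filenames and never closes its queue on its own, so a longer training loop would need to either reset self.step for another epoch or close the queue once the reader threads finish, as the sketch does in its cleanup step.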