Created
August 10, 2022 22:43
-
-
Save petered/5bb97f751bb59475fdfaa60d771e4663 to your computer and use it in GitHub Desktop.
An algorithm for computing a series of bounding boxes from a boolean mask
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
def tf_mask_to_boxes(mask, insert_fake_first_box: bool = False):
    """
    Convert a boolean mask to a series of bounding boxes around each segment.

    Warning: VERY SLOW (slower than pure python version of same algorithm)

    The mask is swept twice with nested ``tf.scan`` calls: once row by row to
    find, for every active cell, the smallest column index reachable from it
    (the box's left edge), and once on the transposed mask to find the smallest
    row index (the box's top edge).  Cells where an active run ends both
    horizontally and vertically are treated as bottom-right corners, and one
    box is emitted per corner.

    :param mask: A (HxW) boolean mask
    :param insert_fake_first_box: TFLite seems to have a bug
        (see issue) https://github.com/tensorflow/tensorflow/issues/57084 where
        it cannot return zero-boxes. So this flag inserts a fake box of all zeros first.
    :returns: A (Nx4) int32 array of (Left, Top, Right, Bottom) box bounds, where
        Right and Bottom are one past the last covered column / row.  When
        ``insert_fake_first_box`` is set, an extra all-zero row is prepended.
    """
    # The logical ops below need a bool tensor; casting keeps masks that were
    # previously accepted through Python truthiness (e.g. 0/1 ints) working.
    mask = tf.cast(mask, tf.bool)

    # Sentinel meaning "no active segment here".  It is max(H, W), which is
    # larger than any valid row or column index.  tf.shape (rather than the
    # static .shape) keeps this usable when dimensions are unknown at trace time.
    max_ix = tf.reduce_max(tf.shape(mask))

    def drag_index_right(last_ix, this_col_inputs):
        """Scan step for one cell: carry the start index in from the left and from above."""
        this_mask_cell, above_cell_index, this_col_ix = this_col_inputs
        # A cell is active if the mask is set there, or if both its left and
        # upper neighbours are active.
        neighbours_active = tf.logical_and(
            tf.not_equal(last_ix, max_ix),
            tf.not_equal(above_cell_index, max_ix),
        )
        still_active = tf.logical_or(this_mask_cell, neighbours_active)
        smallest_ix = tf.minimum(tf.minimum(last_ix, this_col_ix), above_cell_index)
        # tf ops instead of Python `or` / `and` / `if` so the step does not
        # depend on eager execution or AutoGraph rewriting of this callback.
        return tf.where(still_active, smallest_ix, max_ix)

    def drag_row_down(above_row, this_mask_row):
        """Scan step for one row: run drag_index_right across it, given the row above."""
        return tf.scan(
            drag_index_right,
            elems=(this_mask_row, above_row, tf.range(tf.shape(this_mask_row)[0])),
            initializer=max_ix,
        )

    def compute_indices_down(mask_):
        """Return (start-index grid, mask of cells where an active run ends along axis 1)."""
        horizontal_index_grid = tf.scan(
            drag_row_down,
            elems=mask_,
            # The row "above" the first row is entirely inactive.
            initializer=tf.fill((tf.shape(mask_)[1],), value=max_ix),
        )
        active_mask = horizontal_index_grid != max_ix
        # A run stops at an active cell whose right neighbour is inactive; an
        # active cell in the last column always stops.
        stop_mask = tf.concat(
            [active_mask[:, :-1] & ~active_mask[:, 1:], active_mask[:, -1][:, None]],
            axis=1,
        )
        return horizontal_index_grid, stop_mask

    x_ixs, x_stop_mask = compute_indices_down(mask)
    # Same sweep on the transposed mask gives the vertical start indices and
    # stops; transpose the results back to (H, W) layout.
    y_ixs, y_stop_mask = (tf.transpose(t) for t in compute_indices_down(tf.transpose(mask)))

    # Bottom-right corners: cells where a run ends in both directions.
    corner_mask = x_stop_mask & y_stop_mask
    yx_stops_ixs = tf.cast(tf.where(corner_mask), tf.int32)  # (N, 2) as (row, col)
    x_starts = tf.gather_nd(x_ixs, yx_stops_ixs)
    y_starts = tf.gather_nd(y_ixs, yx_stops_ixs)
    ltrb_boxes = tf.concat(
        [
            x_starts[:, None],
            y_starts[:, None],
            yx_stops_ixs[:, 1][:, None] + 1,  # right edge, exclusive
            yx_stops_ixs[:, 0][:, None] + 1,  # bottom edge, exclusive
        ],
        axis=1,
    )
    if insert_fake_first_box:
        ltrb_boxes = tf.concat([tf.zeros((1, 4), dtype=tf.int32), ltrb_boxes], axis=0)
    return ltrb_boxes
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment