Created
May 7, 2023 14:11
-
-
Save gsoykan/e91dc28ce1679829aeac508fa7135dc2 to your computer and use it in GitHub Desktop.
Given a batch of masks, draws a line between each pair of start and end points in a batched manner.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def draw_line_in_mask_batched(self, mask, start_point, end_point, light_neighboring=True):
    """
    Draw one straight line per batch element into a batch of masks.

    Each line is rasterized with the integer Bresenham algorithm, with all
    batch elements advanced in lock-step. The masks are modified in place
    and the same tensor is returned.

    Args:
        mask (torch.Tensor): Batch of masks with shape (B, H, W).
        start_point (sequence of torch.Tensor): Per-coordinate tensors that
            are stacked along dim 1 into a (B, 2) tensor; column 0 is used
            as x and column 1 as y (i.e. ``[xs, ys]`` with each of shape
            (B,)).
        end_point (sequence of torch.Tensor): Same layout as
            ``start_point``.
        light_neighboring (bool): If True, the 4-connected neighbours of
            every pixel visited before the end point are also set to 1
            (neighbours that fall outside the mask are skipped). The end
            pixel itself is drawn without neighbours.

    Returns:
        torch.Tensor: ``mask`` (the same tensor, updated in place) with the
        lines drawn, shape (B, H, W).
    """
    B, H, W = mask.shape
    device = mask.device
    # Index tensor must live on the mask's device so it can be combined
    # with boolean masks computed from the coordinates.
    batch_idx = torch.arange(B, device=device)

    # Stack to (B, 2); move to the mask's device and cast to int64 so the
    # coordinates are valid indices.
    start_point = torch.stack(start_point, dim=1).to(device=device, dtype=torch.long)
    end_point = torch.stack(end_point, dim=1).to(device=device, dtype=torch.long)

    # Clamp both points into the mask: x to [0, W - 1], y to [0, H - 1].
    upper = torch.tensor([W - 1, H - 1], device=device)
    lower = torch.zeros_like(upper)
    start_point = torch.min(upper, torch.max(lower, start_point))
    end_point = torch.min(upper, torch.max(lower, end_point))

    x0, y0 = start_point[:, 0], start_point[:, 1]
    x1, y1 = end_point[:, 0], end_point[:, 1]

    # Absolute extents of each line along both axes.
    dx = torch.abs(x1 - x0)
    dy = torch.abs(y1 - y0)

    # Step direction along each axis (+1 or -1).
    sx = torch.where(x0 < x1, torch.ones_like(x0), -torch.ones_like(x0))
    sy = torch.where(y0 < y1, torch.ones_like(y0), -torch.ones_like(y0))

    # Bresenham error term and current position.
    error = dx - dy
    x, y = x0.clone(), y0.clone()

    # A line is still active while EITHER coordinate differs from its end
    # point. (Requiring both to differ would skip axis-aligned lines
    # entirely and cut other lines short.)
    active = (x != x1) | (y != y1)
    while active.any():
        # Only touch lines that are still being traced, so a finished line
        # is not affected by longer lines elsewhere in the batch.
        rows = batch_idx[active]
        cur_x, cur_y = x[active], y[active]
        mask[rows, cur_y, cur_x] = 1

        if light_neighboring:
            for off_x, off_y in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                n_x = cur_x + off_x
                n_y = cur_y + off_y
                valid = (n_x >= 0) & (n_x < W) & (n_y >= 0) & (n_y < H)
                mask[rows[valid], n_y[valid], n_x[valid]] = 1

        # Both step decisions use the error value from before this step.
        e2 = 2 * error
        step_x = active & (e2 > -dy)
        step_y = active & (e2 < dx)

        error = torch.where(step_x, error - dy, error)
        x = torch.where(step_x, x + sx, x)
        error = torch.where(step_y, error + dx, error)
        y = torch.where(step_y, y + sy, y)

        active = (x != x1) | (y != y1)

    # The loop stops on arrival, so the end pixel is drawn here.
    mask[batch_idx, y1, x1] = 1
    return mask
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment