Skip to content

Instantly share code, notes, and snippets.

@gsoykan
Created May 7, 2023 14:11
Show Gist options
  • Save gsoykan/e91dc28ce1679829aeac508fa7135dc2 to your computer and use it in GitHub Desktop.
Save gsoykan/e91dc28ce1679829aeac508fa7135dc2 to your computer and use it in GitHub Desktop.
Given a batch of masks, draws a line from the start point to the end point of each mask in a batched manner.
def draw_line_in_mask_batched(self, mask, start_point, end_point, light_neighboring=True):
    """
    Draw one straight line per batch element into ``mask`` (Bresenham).

    All lines of the batch are advanced together, one pixel per loop
    iteration; a per-element ``active`` flag freezes lines that have already
    reached their end point so they are neither moved nor redrawn.

    Args:
        mask (torch.Tensor): Batch of masks with shape (B, H, W). It is
            modified in place.
        start_point (sequence of torch.Tensor): Pair ``(xs, ys)`` of tensors
            of shape (B,) holding the start coordinates; they are stacked
            into a (B, 2) tensor internally. Coordinates are clamped to the
            mask bounds.
        end_point (sequence of torch.Tensor): Pair ``(xs, ys)`` of tensors of
            shape (B,) holding the end coordinates; clamped like
            ``start_point``.
        light_neighboring (bool): When True, the 4-connected neighbours of
            every traversed pixel that lie inside the mask are set as well.
            The end pixel itself is always set without its neighbours.

    Returns:
        torch.Tensor: The same ``mask`` tensor, with the lines set to 1.
    """
    B, H, W = mask.shape
    device = mask.device
    if mask.numel() == 0:
        # Nothing can be drawn into an empty mask.
        return mask

    # (xs, ys) -> (B, 2); integer dtype on the mask's device so the
    # coordinates can be used directly as indices.
    start_point = torch.stack(start_point, dim=1).to(device=device, dtype=torch.long)
    end_point = torch.stack(end_point, dim=1).to(device=device, dtype=torch.long)

    # Clamp both points into the valid pixel range [0, W-1] x [0, H-1].
    lower = torch.zeros(2, dtype=torch.long, device=device)
    upper = torch.tensor([W - 1, H - 1], dtype=torch.long, device=device)
    start_point = torch.min(upper, torch.max(lower, start_point))
    end_point = torch.min(upper, torch.max(lower, end_point))

    x0, y0 = start_point[:, 0], start_point[:, 1]
    x1, y1 = end_point[:, 0], end_point[:, 1]

    # Absolute extents and per-axis step direction of every line.
    dx = torch.abs(x1 - x0)
    dy = torch.abs(y1 - y0)
    sx = torch.where(x0 < x1, torch.ones_like(x0), -torch.ones_like(x0))
    sy = torch.where(y0 < y1, torch.ones_like(y0), -torch.ones_like(y0))

    # Bresenham error term.
    error = dx - dy

    # Built once, on the mask's device, so it can be combined with boolean
    # masks that live on that device.
    batch_idx = torch.arange(B, device=device)
    neighbor_offsets = ((1, 0), (-1, 0), (0, 1), (0, -1))

    x, y = x0, y0
    # A line is unfinished while EITHER coordinate still differs from its
    # target; requiring both to differ would skip axis-aligned lines entirely
    # and truncate sloped ones.
    active = (x != x1) | (y != y1)
    while active.any():
        mask[batch_idx[active], y[active], x[active]] = 1

        if light_neighboring:
            for off_x, off_y in neighbor_offsets:
                n_x = x + off_x
                n_y = y + off_y
                valid = active & (n_x >= 0) & (n_x < W) & (n_y >= 0) & (n_y < H)
                mask[batch_idx[valid], n_y[valid], n_x[valid]] = 1

        # Decide, from the error term as it was before this step, which axes
        # advance; finished lines never step.
        e2 = 2 * error
        step_x = active & (e2 > -dy)
        step_y = active & (e2 < dx)
        error = error - dy * step_x + dx * step_y
        x = x + sx * step_x
        y = y + sy * step_y

        active = (x != x1) | (y != y1)

    # The loop stops on arrival, before drawing the end pixel.
    mask[batch_idx, y1, x1] = 1
    return mask
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment