Skip to content

Instantly share code, notes, and snippets.

@seenitall
Created May 15, 2019 09:27
Show Gist options
  • Save seenitall/e87978a0e6873f3974b1acef57d7aa15 to your computer and use it in GitHub Desktop.
Save seenitall/e87978a0e6873f3974b1acef57d7aa15 to your computer and use it in GitHub Desktop.
Polygon interpolation functions in PyTorch
import numpy as np
import cv2
import math
import torch
###### POLYGON INTERPOLATION FUNCTIONS PYTORCH ######
from torchinterp1d import Interp1d
def segments_torch(polygon):
    """Pair every vertex of a closed polygon with the vertex after it.

    The final vertex is paired with the first one, so the ring is closed.

    Args:
        polygon: tensor whose rows are vertices.

    Returns:
        A lazy ``zip`` iterator of ``(start_row, end_row)`` tensor pairs,
        one per vertex.
    """
    successors = torch.cat((polygon[1:], polygon[:1]))
    return zip(polygon, successors)
def perimeter_torch(polygon):
    """Return the perimeter of a closed polygon as a 0-d tensor.

    Args:
        polygon: ``(N, D)`` tensor of vertex coordinates (``D == 2`` for
            planar polygons). The last vertex is joined back to the first.
            Integer coordinates are promoted to the default float dtype.

    Returns:
        0-d tensor holding the sum of Euclidean edge lengths; zero for an
        empty polygon.
    """
    if polygon.numel() == 0:
        return torch.zeros((), device=polygon.device)
    # edges[i] runs from vertex i to vertex i + 1 (wrapping to vertex 0).
    edges = torch.cat([polygon[1:], polygon[:1]]) - polygon
    if not edges.is_floating_point():
        # Promote before squaring so integer inputs neither truncate nor overflow.
        edges = edges.to(torch.get_default_dtype())
    # One vectorised reduction instead of a Python loop over 0-d tensors;
    # lengths are non-negative, so no abs() is needed.
    return torch.sqrt((edges ** 2).sum(dim=1)).sum()
def quantile_int_distance_torch(polygon, quantile):
    """Return a quantile of a closed polygon's edge lengths.

    Args:
        polygon: ``(N, 2)`` tensor of vertex coordinates on any device; the
            last vertex is joined back to the first.
        quantile: quantile(s) in ``[0, 1]`` as a float, sequence or tensor;
            forwarded to ``np.quantile``.

    Returns:
        The result of ``np.quantile`` over the edge lengths (a NumPy scalar
        for a scalar ``quantile``).
    """
    # edges[i] runs from vertex i to vertex i + 1 (wrapping to vertex 0).
    edges = torch.cat([polygon[1:], polygon[:1]]) - polygon
    if not edges.is_floating_point():
        edges = edges.to(torch.get_default_dtype())
    lengths = torch.sqrt((edges ** 2).sum(dim=1))
    # NumPy cannot read CUDA tensors or tensors that require grad, so hand it an
    # explicit detached CPU copy instead of relying on implicit conversion.
    lengths_np = lengths.detach().cpu().numpy()
    if isinstance(quantile, torch.Tensor):
        quantile = quantile.detach().cpu().numpy()
    # Edge lengths are non-negative, so the quantile is too (no abs() needed).
    return np.quantile(lengths_np, quantile)
def n_int_distance_above_quantile_torch(polygon, int_quantile_dist):
    """Count the edges of a closed polygon that are longer than a threshold.

    Args:
        polygon: ``(N, 2)`` tensor of vertex coordinates; the last vertex is
            joined back to the first.
        int_quantile_dist: length threshold (number, NumPy scalar or 0-d
            tensor), e.g. the value returned by ``quantile_int_distance_torch``.

    Returns:
        0-d integer tensor: the number of edges whose length is strictly
        greater than ``int_quantile_dist`` (zero for an empty polygon).
    """
    if polygon.numel() == 0:
        return torch.zeros((), dtype=torch.long, device=polygon.device)
    # edges[i] runs from vertex i to vertex i + 1 (wrapping to vertex 0).
    edges = torch.cat([polygon[1:], polygon[:1]]) - polygon
    if not edges.is_floating_point():
        edges = edges.to(torch.get_default_dtype())
    lengths = torch.sqrt((edges ** 2).sum(dim=1))
    return (lengths > int_quantile_dist).sum()
def interpolate_polygon_torch(polygon, device, K=360):
    """Resample a closed polygon to exactly ``K`` points.

    Args:
        polygon: tensor of vertex coordinates with shape ``(N, 2)`` or
            ``(1, N, 2)`` (presumably ``(x, y)`` per row — confirm with
            callers). The ring is closed: the last vertex connects back to
            the first.
        device: device the result is placed on (anything ``Tensor.to``
            accepts).
        K: number of points in the output.

    Returns:
        Tensor of shape ``(1, K, 2)`` on ``device``:

        * empty polygon -> all zeros;
        * more than ``K`` vertices -> a uniformly random subset of ``K``
          vertices, kept in their original order (no interpolation);
        * otherwise every vertex is kept, and each edge gets
          ``(K - N) // N`` new points at uniformly random, sorted positions
          along it. The ``(K - N) % N`` leftover points go to the edges longer
          than the ``1 - leftover / N`` length quantile. Any that still cannot
          be placed (e.g. when all edges tie) are appended as copies of the
          final point. Integer coordinates are promoted to float on this path.
    """
    if polygon.dim() == 3:
        assert polygon.shape[0] == 1
        polygon = polygon[0]
    # Catches (0,), (0, 2), (1, 0), ... — the old shape-tuple test missed most of these.
    if polygon.numel() == 0:
        return torch.zeros((1, K, 2), device=device)

    n_vertices = len(polygon)
    if n_vertices > K:
        # randperm gives an unbiased sample without replacement (multinomial over
        # arange weights could never pick vertex 0); sorting keeps ring order.
        keep = torch.sort(torch.randperm(n_vertices)[:K])[0]
        return polygon[keep].unsqueeze(0).to(device)

    polygon = polygon.to(device)
    if not polygon.is_floating_point():
        # Interpolated coordinates are fractional; avoid integer truncation.
        polygon = polygon.float()

    # edges[i] runs from vertex i to vertex i + 1 (wrapping to vertex 0).
    edges = torch.cat([polygon[1:], polygon[:1]]) - polygon
    lengths = torch.sqrt((edges ** 2).sum(dim=1))

    per_edge = (K - n_vertices) // n_vertices
    leftover = (K - n_vertices) % n_vertices

    extra_counts = [0] * n_vertices
    if leftover > 0:
        # Edges above this quantile are roughly the `leftover` longest ones.
        # The quantile runs on a detached CPU copy so CUDA/grad inputs work.
        threshold = float(np.quantile(lengths.detach().cpu().numpy(),
                                      1.0 - leftover / n_vertices))
        is_long = (lengths > threshold).tolist()
        n_long = min(sum(is_long), leftover)
        bonus = leftover // n_long if n_long > 0 else 0
        for i, long_edge in enumerate(is_long):
            if long_edge and leftover > 0:
                extra_counts[i] = bonus
                leftover -= bonus

    pieces = []
    for vertex, edge, extra in zip(polygon, edges, extra_counts):
        # Sorted uniform samples so the new points are ordered along the edge.
        t = torch.sort(torch.rand(per_edge + extra))[0].to(device=device, dtype=polygon.dtype)
        pieces.append(vertex.unsqueeze(0))
        # Direct linear interpolation along the edge. Solving y = f(x) instead
        # collapses every point of a vertical edge onto its start vertex.
        pieces.append(vertex + t.unsqueeze(1) * edge)
    # Single concatenation instead of growing the tensor inside the loop.
    new_points = torch.cat(pieces)

    if leftover > 0:
        # Points the heuristic could not assign are padded with the last point.
        new_points = torch.cat([new_points, new_points[-1:].repeat(leftover, 1)])

    assert len(new_points) == K, (len(new_points), leftover)
    return new_points.unsqueeze(0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment