Created
May 15, 2019 09:27
-
-
Save seenitall/e87978a0e6873f3974b1acef57d7aa15 to your computer and use it in GitHub Desktop.
Polygon interpolation functions in PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import cv2 | |
import math | |
import torch | |
###### POLYGON INTERPOLATION FUNCTIONS PYTORCH ###### | |
from torchinterp1d import Interp1d | |
def segments_torch(polygon):
    """Pair every vertex of a closed polygon with the vertex that follows it.

    The last vertex is paired with the first one, so the polygon is treated
    as closed.

    Args:
        polygon: tensor whose first dimension indexes the vertices
            (the callers in this file pass an ``(n, 2)`` tensor of x/y rows).

    Returns:
        A ``zip`` iterator of ``(vertex, next_vertex)`` row pairs. Like any
        ``zip`` object it can be consumed only once.
    """
    # Shift the vertex list by one position, wrapping the first vertex to
    # the end, so that row i of `successors` is the end point of edge i.
    successors = torch.cat([polygon[1:], polygon[:1]])
    return zip(polygon, successors)
def perimeter_torch(polygon):
    """Return the perimeter of a closed polygon.

    The perimeter is the sum of the Euclidean lengths of all edges,
    including the closing edge from the last vertex back to the first.

    Args:
        polygon: ``(n, d)`` tensor of vertex coordinates, one vertex per row.

    Returns:
        A 0-d tensor holding the total edge length.
    """
    # End point of every edge: the vertex list rotated by one position.
    successors = torch.cat([polygon[1:], polygon[:1]])
    # All edge lengths in one vectorized pass instead of one tiny tensor op
    # per edge. A sum of square roots is never negative, so no abs() needed.
    edge_lengths = torch.sqrt(((polygon - successors) ** 2).sum(dim=1))
    return edge_lengths.sum()
def quantile_int_distance_torch(polygon, quantile):
    """Return the requested quantile of the polygon's edge lengths.

    Edge lengths are the Euclidean distances between consecutive vertices,
    including the closing edge from the last vertex back to the first.

    Args:
        polygon: ``(n, d)`` tensor of vertex coordinates, one vertex per row.
        quantile: quantile in ``[0, 1]``, given as a Python/NumPy number or
            as a tensor.

    Returns:
        The quantile as computed by ``np.quantile`` (a NumPy scalar when
        ``quantile`` is a scalar).
    """
    successors = torch.cat([polygon[1:], polygon[:1]])
    edge_lengths = torch.sqrt(((polygon - successors) ** 2).sum(dim=1))

    # NumPy cannot consume tensors that require grad or live on an
    # accelerator, so convert explicitly rather than relying on implicit
    # conversion of a list of 0-d tensors.
    lengths_np = edge_lengths.detach().cpu().numpy()
    if torch.is_tensor(quantile):
        quantile = quantile.detach().cpu().numpy()

    return abs(np.quantile(lengths_np, quantile))
def n_int_distance_above_quantile_torch(polygon, int_quantile_dist):
    """Count the polygon edges that are strictly longer than a threshold.

    Edge lengths are the Euclidean distances between consecutive vertices,
    including the closing edge from the last vertex back to the first.

    Args:
        polygon: ``(n, d)`` tensor of vertex coordinates, one vertex per row.
        int_quantile_dist: length threshold; edges whose length equals the
            threshold are not counted.

    Returns:
        A 0-d integer tensor with the number of edges above the threshold.
    """
    successors = torch.cat([polygon[1:], polygon[:1]])
    edge_lengths = torch.sqrt(((polygon - successors) ** 2).sum(dim=1))
    # One vectorized comparison + reduction instead of summing per-edge
    # 0-d boolean tensors in Python.
    return (edge_lengths > int_quantile_dist).sum()
def interpolate_polygon_torch(polygon, device, K=360):
    """Resample a closed polygon to exactly ``K`` vertices.

    * An empty polygon yields ``K`` all-zero points.
    * A polygon with more than ``K`` vertices is down-sampled by keeping a
      uniformly random subset of ``K`` vertices in their original order.
    * Otherwise every original vertex is kept and new points are inserted
      at random positions along each edge. Every edge receives the same
      base number of points; the points left over after the even split go
      to the longest edges. Any points still missing are filled by
      repeating the last generated point.

    Args:
        polygon: ``(n, 2)`` tensor of vertices, or ``(1, n, 2)`` with a
            leading batch axis of size one.
        device: device the result is placed on.
        K: number of vertices in the output.

    Returns:
        Tensor of shape ``(1, K, 2)``.

    Raises:
        AssertionError: if a 3-D input has a leading dimension other than 1.
    """
    if polygon.dim() == 3:
        # Kept as an assert so callers see the same exception type as before.
        assert polygon.shape[0] == 1
        polygon = polygon[0]

    # numel() catches every empty shape, e.g. (0,), (0, 2) and (1, 0).
    if polygon.numel() == 0:
        return torch.zeros((1, K, 2), device=device)

    n_vertices = len(polygon)
    if n_vertices > K:
        # Uniformly random subset of K vertices, restored to polygon order.
        keep = torch.sort(torch.randperm(n_vertices)[:K])[0].to(polygon.device)
        new_points = polygon[keep]
        assert len(new_points) == K
        return new_points.unsqueeze(0).to(device)

    polygon = polygon.to(device)
    if not polygon.is_floating_point():
        # Interpolated coordinates are fractional.
        polygon = polygon.float()

    # Even share of new points per edge, plus the points that do not fit
    # into an even split.
    base_count, leftover = divmod(K - n_vertices, n_vertices)

    successors = torch.cat([polygon[1:], polygon[:1]])
    edge_lengths = torch.sqrt(((successors - polygon) ** 2).sum(dim=1))

    # Edges longer than the (1 - leftover / n) quantile of all edge lengths
    # are the ones that receive the extra points.
    threshold = float(
        np.quantile(
            edge_lengths.detach().cpu().numpy(),
            1.0 - leftover / n_vertices,
        )
    )
    is_long = (edge_lengths > threshold).tolist()
    n_long = min(sum(is_long), leftover)
    extra_per_long_edge = leftover // n_long if n_long > 0 else 0

    # Collect the pieces and concatenate once at the end.
    pieces = []
    for start, end, long_edge in zip(polygon, successors, is_long):
        count = base_count
        if long_edge and leftover > 0:
            count += extra_per_long_edge
            leftover -= extra_per_long_edge

        pieces.append(start.unsqueeze(0))
        if count > 0:
            # Sorted random fractions in [0, 1) keep the new points ordered
            # from `start` towards `end`.
            fractions = torch.sort(torch.rand(count))[0]
            fractions = fractions.to(device=polygon.device, dtype=polygon.dtype)
            # Parametric interpolation moves x and y together, so vertical
            # edges (equal x at both ends) are handled like any other edge.
            pieces.append(start + fractions.unsqueeze(1) * (end - start))

    new_points = torch.cat(pieces)

    # manual padding - torch.nn.functional.pad(value=) does not support 2D
    # tensors as padding values
    pad_len = K - len(new_points)
    if pad_len > 0:
        new_points = torch.cat([new_points, new_points[-1:].expand(pad_len, -1)])

    assert len(new_points) == K, (len(new_points), leftover)
    return new_points.unsqueeze(0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment