# Frankot-Chellappa height-from-normals reconstruction. This is a method of a
# larger texture-processing class: it assumes `import torch` at module level
# and the assertion messages `err_size` / `err_normal_ch` defined on the class.
@classmethod
def normals_to_height(cls,
        normal_map:torch.Tensor,
        self_tiling:bool=False,
        rescaled:bool=False,
        eps:float=torch.finfo(torch.float32).eps) -> (torch.Tensor, torch.Tensor):
    """
    Compute height from normals with the Frankot-Chellappa algorithm.

    :param normal_map: Normal map tensor sized [N, C=3, H, W] or [C=3, H, W]
        as unit vectors of surface normals.
    :param self_tiling: Treat surface as self-tiling.
    :param rescaled: Accept unit vector tensor in [0, 1] value range.
    :return: Height tensor sized [N, C=1, H, W] or [C=1, H, W] in [0, 1] range
        and height scale tensor sized [N, 1, 1, 1] or [1, 1, 1] in [0, inf) range
        (broadcastable against the height tensor).
    """
    ndim = len(normal_map.size())
    assert ndim == 3 or ndim == 4, cls.err_size
    nobatch = ndim == 3
    if nobatch: normal_map = normal_map.unsqueeze(0)
    assert normal_map.size(1) == 3, cls.err_normal_ch
    # Map [0, 1]-encoded normals back to the [-1, 1] range.
    if rescaled: normal_map = normal_map * 2. - 1.
    device = normal_map.device
    N, _, H, W = normal_map.size()
    res_disp, res_scale = [], []
    for i in range(N):
        vec = normal_map[i]
        # The x/y normal components stand in for the surface gradients;
        # the z component is assumed close to 1 and is not used.
        nx, ny = vec[0], vec[1]
        if not self_tiling:
            # Mirror the gradient fields into a 2H x 2W tile with matching
            # sign flips, so the FFT sees a periodic signal and the
            # reconstruction has no wrap-around seams.
            nxt = torch.cat([nx, -torch.flip(nx, dims=[1])], dim=1)
            nxb = torch.cat([torch.flip(nx, dims=[0]), -torch.flip(nx, dims=[0, 1])], dim=1)
            nx = torch.cat([nxt, nxb], dim=0)
            nyt = torch.cat([ny, torch.flip(ny, dims=[1])], dim=1)
            nyb = torch.cat([-torch.flip(ny, dims=[0]), -torch.flip(ny, dims=[0, 1])], dim=1)
            ny = torch.cat([nyt, nyb], dim=0)
        # Build the centered frequency grid, then shift it into FFT layout.
        r, c = nx.shape
        rg = (torch.arange(r) - (r // 2 + 1)).float() / (r - r % 2)
        cg = (torch.arange(c) - (c // 2 + 1)).float() / (c - c % 2)
        u, v = torch.meshgrid(cg, rg, indexing='xy')
        u = torch.fft.ifftshift(u.to(device))
        v = torch.fft.ifftshift(v.to(device))
        # Integrate the gradient field in the frequency domain:
        # Z(u, v) = -j * (u * Gx + v * Gy) / (u^2 + v^2).
        gx = torch.fft.fft2(-nx)
        gy = torch.fft.fft2(ny)
        nom = (-1j * u * gx) + (-1j * v * gy)
        denom = (u**2) + (v**2) + eps
        zf = nom / denom
        zf[0, 0] = 0.0  # zero the DC term; absolute height is unrecoverable
        z = torch.real(torch.fft.ifft2(zf))
        # Normalize to [0, 1] and keep the span as the height scale; the clamp
        # guards against division by zero on a flat (constant-normal) input.
        zspan = (torch.max(z) - torch.min(z)).clamp(min=eps)
        disp, scale = (z - torch.min(z)) / zspan, float(zspan)
        if not self_tiling: disp = disp[:H, :W]  # crop off the mirrored padding
        res_disp.append(disp.unsqueeze(0).unsqueeze(0))
        res_scale.append(torch.tensor(scale).unsqueeze(0))
    res_disp = torch.cat(res_disp, dim=0)
    res_scale = torch.cat(res_scale, dim=0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).to(res_disp.device)
    if nobatch:
        res_disp = res_disp.squeeze(0)
        res_scale = res_scale.squeeze(0)
    return res_disp, res_scale / 10.  # scale reduced by a fixed constant divisor
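
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the gist above). It assumes the method is
# pasted into a hypothetical `TextureOps` class; the class name, the error
# strings, and the synthetic test surface are all illustrative, not canonical.
import math
import torch

class TextureOps:
    err_size = 'expected a [C, H, W] or [N, C, H, W] tensor'
    err_normal_ch = 'expected 3 channels (x, y, z)'
    # ... paste normals_to_height (defined above) into this class body ...

# Synthesize normals from a known shallow height field
# h(x, y) = 0.1 * cos(pi x) * cos(pi y), using n ∝ (-dh/dx, dh/dy, 1) to match
# the sign convention the method applies (gx = fft2(-nx), gy = fft2(ny));
# Y-channel conventions vary between normal-map tools, so flip ny if needed.
H, W = 128, 128
ys, xs = torch.meshgrid(torch.linspace(-1., 1., H),
                        torch.linspace(-1., 1., W), indexing='ij')
height = 0.1 * torch.cos(math.pi * xs) * torch.cos(math.pi * ys)
dh_dy, dh_dx = torch.gradient(height)
normals = torch.stack([-dh_dx, dh_dy, torch.ones_like(height)], dim=0)
normals = normals / normals.norm(dim=0, keepdim=True)

disp, scale = TextureOps.normals_to_height(normals)
print(disp.shape)   # torch.Size([1, 128, 128]), values normalized to [0, 1]
print(scale.shape)  # torch.Size([1, 1, 1]); multiply disp by scale to undo
                    # the normalization (up to the constant divisor)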