Created April 23, 2022 17:50
-
-
Save Mason-McGough/f96c2704562593c5af2d79cae320a2e2 to your computer and use it in GitHub Desktop.
Hierarchical sampling for NeRF
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def sample_hierarchical(
    rays_o: torch.Tensor,
    rays_d: torch.Tensor,
    z_vals: torch.Tensor,
    weights: torch.Tensor,
    n_samples: int,
    perturb: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Apply hierarchical sampling to the rays.

    The midpoints between neighbouring coarse depths in ``z_vals`` become bin
    edges, and the interior ``weights`` become (unnormalized) bin
    probabilities. ``sample_pdf`` then draws ``n_samples`` extra depths from
    that distribution. The extra depths are merged with the coarse ones,
    sorted, and turned into points along each ray.

    Args:
        rays_o: Ray origins, coordinates on the last axis.
        rays_d: Ray directions, coordinates on the last axis.
        z_vals: Coarse depths along each ray, samples on the last axis.
        weights: Per-depth weights aligned with ``z_vals`` on the last axis.
        n_samples: Number of extra depths to draw per ray.
        perturb: Forwarded to ``sample_pdf``. True draws random CDF query
            values; False uses evenly spaced ones.

    Returns:
        A tuple ``(pts, z_vals_combined, new_z_samples)`` containing:
        - the points at every merged depth (``[N_rays, N_samples + n_samples, 3]``
          for ``[N_rays, 3]`` rays);
        - the sorted merged depths;
        - the newly drawn depths, detached from the autograd graph.
    """
    # N coarse depths give N - 1 midpoints, i.e. N - 2 bins. Interior depth i
    # lies between midpoints i - 1 and i, so the first and last weights have
    # no bin of their own and are dropped.
    edges = 0.5 * (z_vals[..., :-1] + z_vals[..., 1:])
    fine_z = sample_pdf(edges, weights[..., 1:-1], n_samples, perturb=perturb)
    # Keep gradients from flowing back through the resampling step.
    fine_z = fine_z.detach()

    # Merge coarse and fine depths and keep them ordered along each ray.
    merged = torch.sort(torch.cat((z_vals, fine_z), dim=-1), dim=-1).values

    # Broadcast origin/direction over the sample axis and scale by depth.
    origins = rays_o.unsqueeze(-2)
    directions = rays_d.unsqueeze(-2)
    pts = origins + directions * merged.unsqueeze(-1)
    return pts, merged, fine_z
def sample_pdf(
    bins: torch.Tensor,
    weights: torch.Tensor,
    n_samples: int,
    perturb: bool = False
) -> torch.Tensor:
    r"""
    Apply inverse transform sampling to a weighted set of points.

    ``weights`` are treated as unnormalized probabilities of the intervals
    between consecutive entries of ``bins``. The last axis of ``bins`` must
    therefore be one longer than that of ``weights``, and the leading
    dimensions must match. A piecewise-linear CDF is built over the bin
    edges and inverted at ``n_samples`` query values per row.

    Args:
        bins: Bin edges along the last axis (expected to be sorted ascending).
        weights: Weight per bin along the last axis (expected non-negative).
            ``1e-5`` is added to every weight, so all-zero rows fall back to
            a uniform distribution.
        n_samples: Number of samples to draw per row.
        perturb: If False, query the CDF at ``n_samples`` evenly spaced values
            in [0, 1] (deterministic). If True, query at uniform random values
            in [0, 1).

    Returns:
        Tensor of shape ``bins.shape[:-1] + (n_samples,)`` with the sampled
        positions. Each sample lies between the two bin edges of the CDF
        segment its query value falls in.
    """
    # Normalize weights to get a PDF. The epsilon keeps every bin reachable
    # and avoids a zero denominator for all-zero rows.
    padded = weights + 1e-5
    pdf = padded / torch.sum(padded, dim=-1, keepdim=True)  # [n_rays, n_bins]

    # Convert PDF to CDF, prepending 0 so there is one CDF value per bin edge.
    cdf = torch.cumsum(pdf, dim=-1)  # [n_rays, n_bins]
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1)  # [n_rays, n_bins + 1]

    # Build query values in the CDF's own dtype/device. Older PyTorch
    # releases make torch.searchsorted reject mismatched dtypes, and newer
    # ones would silently promote float32 queries against a float64 CDF.
    sample_shape = list(cdf.shape[:-1]) + [n_samples]
    if not perturb:
        u = torch.linspace(0., 1., n_samples, dtype=cdf.dtype, device=cdf.device)
        u = u.expand(sample_shape)  # [n_rays, n_samples]
    else:
        u = torch.rand(sample_shape, dtype=cdf.dtype, device=cdf.device)  # [n_rays, n_samples]

    # The expanded linspace is a strided view; hand searchsorted a
    # contiguous copy.
    u = u.contiguous()
    # right=True: cdf[inds - 1] <= u < cdf[inds] for in-range queries.
    inds = torch.searchsorted(cdf, u, right=True)  # [n_rays, n_samples]

    # Clamp so queries at or beyond cdf[-1] still index valid edges.
    below = torch.clamp(inds - 1, min=0)
    above = torch.clamp(inds, max=cdf.shape[-1] - 1)
    inds_g = torch.stack([below, above], dim=-1)  # [n_rays, n_samples, 2]

    # Gather CDF values and bin edges at both ends of each segment.
    matched_shape = list(inds_g.shape[:-1]) + [cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(-2).expand(matched_shape), dim=-1,
                         index=inds_g)
    bins_g = torch.gather(bins.unsqueeze(-2).expand(matched_shape), dim=-1,
                          index=inds_g)

    # Linearly interpolate within the segment. Near-empty segments (including
    # the clamped below == above case) use a denominator of 1, which keeps t
    # finite and smaller than the true segment width.
    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
    return samples  # [n_rays, n_samples]
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment