Created April 23, 2022 17:47
-
-
Save Mason-McGough/29c7535d7bee24f5005c2e554239e4de to your computer and use it in GitHub Desktop.
Stratified sampling for NeRF
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def sample_stratified(
    rays_o: torch.Tensor,
    rays_d: torch.Tensor,
    near: float,
    far: float,
    n_samples: int,
    perturb: Optional[bool] = True,
    inverse_depth: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Sample points along each ray from regularly-spaced depth bins.

    Args:
        rays_o: Ray origins. The last dimension holds the point coordinates
            (presumably 3 — confirm with callers); all leading dimensions
            are treated as batch dimensions of rays.
        rays_d: Ray directions, broadcast-compatible with ``rays_o``.
        near: Depth of the first bin edge.
        far: Depth of the last bin edge.
        n_samples: Number of samples (bins) per ray.
        perturb: If truthy, draw each sample uniformly inside its bin,
            independently for every ray. If falsy (including ``None``),
            use the fixed bin positions.
        inverse_depth: If True, space the bins linearly in inverse depth
            (disparity) instead of linearly in depth.

    Returns:
        pts: Sample points of shape ``(*rays_o.shape[:-1], n_samples, D)``,
            computed as ``rays_o + rays_d * z``.
        z_vals: Sample depths of shape ``(*rays_o.shape[:-1], n_samples)``.
    """
    # Evenly spaced interpolation weights in [0, 1].
    t_vals = torch.linspace(0., 1., n_samples, device=rays_o.device)
    if not inverse_depth:
        # Sample linearly between `near` and `far`.
        z_vals = near * (1. - t_vals) + far * t_vals
    else:
        # Sample linearly in inverse depth (disparity).
        z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * t_vals)

    # Broadcast the shared bin positions to every ray *before* jittering,
    # so each ray receives its own independent random offsets.
    z_vals = z_vals.expand(list(rays_o.shape[:-1]) + [n_samples])

    if perturb:
        # Bin edges: midpoints between neighbouring samples, closed off by
        # the first/last sample so the bins tile [z_first, z_last].
        mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = torch.cat([mids, z_vals[..., -1:]], dim=-1)
        lower = torch.cat([z_vals[..., :1], mids], dim=-1)
        # One uniform draw per ray and per bin (stratified sampling).
        t_rand = torch.rand(z_vals.shape, device=z_vals.device)
        z_vals = lower + (upper - lower) * t_rand

    # Apply scale from `rays_d` and offset from `rays_o` to samples.
    # pts: (*rays_o.shape[:-1], n_samples, D)
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
    return pts, z_vals
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.