Standard positional encoder for NeRF
import torch
from torch import nn


class PositionalEncoder(nn.Module):
    r"""
    Sine-cosine positional encoder for input points.
    """
    def __init__(
        self,
        d_input: int,
        n_freqs: int,
        log_space: bool = False
    ):
        super().__init__()
        self.d_input = d_input
        self.n_freqs = n_freqs
        self.log_space = log_space
        # Each input dimension maps to itself plus a sin/cos pair per frequency
        self.d_output = d_input * (1 + 2 * self.n_freqs)
        self.embed_fns = [lambda x: x]

        # Define frequencies in either linear or log scale
        if self.log_space:
            freq_bands = 2. ** torch.linspace(0., self.n_freqs - 1, self.n_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0., 2. ** (self.n_freqs - 1), self.n_freqs)

        # Alternate sin and cos for each frequency band; `freq=freq` binds the
        # current value so each lambda captures its own frequency
        for freq in freq_bands:
            self.embed_fns.append(lambda x, freq=freq: torch.sin(x * freq))
            self.embed_fns.append(lambda x, freq=freq: torch.cos(x * freq))

    def forward(
        self,
        x: torch.Tensor
    ) -> torch.Tensor:
        r"""
        Apply positional encoding to input.
        """
        return torch.cat([fn(x) for fn in self.embed_fns], dim=-1)
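
A minimal usage sketch (the batch size and hyperparameter values below are illustrative, not part of the gist): encoding a batch of 3D sample points with 10 frequency bands, the value the NeRF paper uses for spatial coordinates.

# Hypothetical usage example: 3D input points, 10 frequency bands
encoder = PositionalEncoder(d_input=3, n_freqs=10, log_space=True)
points = torch.rand(1024, 3)   # batch of 3D points in [0, 1)
encoded = encoder(points)      # shape: (1024, 3 * (1 + 2 * 10)) = (1024, 63)
assert encoded.shape == (1024, encoder.d_output)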