# Source: GitHub Gist by @wkentaro, created June 15, 2020 15:37
# (gist wkentaro/6c7e4538f2d7b5ff968101b76e6bcc08).
import math
import imgviz
import numpy as np
import torch
class PositionalEmbeddingSinCos(torch.nn.Module):
    """Fixed (non-learned) 2D sine/cosine positional embedding.

    Only the shape and device of the input are used; its values are ignored.
    For an input of shape ``(batch, channels, height, width)`` the output has
    shape ``(batch, 2 * num_pos_feats, height, width)`` and dtype float32.
    The first ``num_pos_feats`` channels encode the row (y) position and the
    last ``num_pos_feats`` channels encode the column (x) position.

    Args:
        num_pos_feats: Number of channels per axis. Must be even, otherwise
            the sine and cosine halves have different sizes and
            ``torch.stack`` fails in ``forward``.
        temperature: Base of the geometric progression of frequencies.
        normalize: If True, positions are divided by the last position along
            each axis and multiplied by ``scale``.
        scale: Multiplier applied to normalized positions. Defaults to
            ``2 * math.pi``. May only be given together with
            ``normalize=True``.

    Raises:
        ValueError: If ``scale`` is passed while ``normalize`` is False.
    """

    def __init__(
        self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
    ):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self._num_pos_feats = num_pos_feats
        self._temperature = temperature
        self._normalize = normalize
        self._scale = scale

    def forward(self, input):
        """Return the positional embedding matching ``input``'s spatial size.

        Args:
            input: 4-D tensor ``(batch, channels, height, width)``.

        Returns:
            Float32 tensor ``(batch, 2 * num_pos_feats, height, width)`` on
            the same device as ``input``.
        """
        batch_size, _, height, width = input.shape

        # Build the mask on the input's device: ``dim_t`` below lives on
        # ``input.device``, so a CPU-only mask would make the divisions fail
        # with a device mismatch for non-CPU inputs.
        mask = torch.ones(
            (batch_size, height, width), dtype=torch.bool, device=input.device
        )
        # Cumulative sums of an all-ones mask give 1-based row/column indices.
        y_embed = mask.cumsum(dim=1, dtype=torch.float32)
        x_embed = mask.cumsum(dim=2, dtype=torch.float32)

        if self._normalize:
            # Divide by the last (largest) index so positions fall in
            # (0, scale]; eps guards the division.
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self._scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self._scale

        # Consecutive channel pairs (2k, 2k + 1) share one frequency.
        dim_t = torch.arange(
            self._num_pos_feats, dtype=torch.float32, device=input.device
        )
        dim_t = self._temperature ** (2 * (dim_t // 2) / self._num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t

        # Interleave sin (even channels) and cos (odd channels): stacking on
        # a new trailing axis and flattening yields sin, cos, sin, cos, ...
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)

        # (B, H, W, 2F) -> (B, 2F, H, W), y features first.
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
class PositionalEmbedding(torch.nn.Module):
    """Learned 2D positional embedding with one table per spatial axis.

    Args:
        height: Number of row positions the y-embedding table holds.
        width: Number of column positions the x-embedding table holds.
        num_pos_feats: Embedding size per axis; the output has
            ``2 * num_pos_feats`` channels.
    """

    def __init__(self, height, width, num_pos_feats=64):
        super().__init__()
        self.embed_y = torch.nn.Embedding(
            num_embeddings=height, embedding_dim=num_pos_feats
        )
        self.embed_x = torch.nn.Embedding(
            num_embeddings=width, embedding_dim=num_pos_feats
        )

    def forward(self, input):
        """Return the learned embedding tiled to ``input``'s batch and size.

        Only the shape of ``input`` is used. Its height and width must not
        exceed the ``height``/``width`` given at construction, otherwise the
        embedding lookup indexes out of range.

        Args:
            input: 4-D tensor ``(batch, channels, height, width)``.

        Returns:
            Tensor ``(batch, 2 * num_pos_feats, height, width)`` with the y
            embedding in the first half of the channels and the x embedding
            in the second half.
        """
        batch_size, _, height, width = input.shape

        # Index tensors must live on the same device as the embedding
        # weights; a bare ``torch.arange(n)`` is always on CPU and breaks
        # once the module has been moved to another device.
        device = self.embed_y.weight.device

        # (H, F) -> (H, 1, F) -> (H, W, F): each row's vector for every column.
        y_embed = self.embed_y(torch.arange(height, device=device))
        y_embed = y_embed[:, None].repeat_interleave(width, dim=1)

        # (W, F) -> (1, W, F) -> (H, W, F): each column's vector for every row.
        x_embed = self.embed_x(torch.arange(width, device=device))
        x_embed = x_embed[None, :].repeat_interleave(height, dim=0)

        # (H, W, 2F) -> (2F, H, W) -> (B, 2F, H, W)
        pos = torch.cat([y_embed, x_embed], dim=2)
        pos = pos.permute(2, 0, 1)
        pos = pos[None].repeat_interleave(batch_size, dim=0)
        return pos
def main():
    """Visualize the sin/cos positional embedding next to a sample image.

    Reads ``./data/glasses.jpg``, computes ``PositionalEmbeddingSinCos`` for
    its spatial size, converts the multi-channel embedding to RGB and shows
    the image and the embedding side by side in a pyglet window.
    """
    img = imgviz.io.imread("./data/glasses.jpg")

    # HWC image -> 1 x C x H x W float tensor.
    x = img.transpose(2, 0, 1)[None]
    x = torch.as_tensor(x, dtype=torch.float32)
    # Maps [0, 255] to [-1, 1]; assumes an 8-bit image -- TODO confirm.
    x = x / 255 * 2 - 1

    # The learned variant can be visualized instead with
    # ``PositionalEmbedding(height=x.shape[2], width=x.shape[3])(x)``.
    x_pos = PositionalEmbeddingSinCos()(x)

    # Presumably makes the channel-to-RGB projection reproducible -- verify
    # against imgviz.nchannel2rgb.
    np.random.seed(0)
    x_pos_viz = imgviz.nchannel2rgb(
        x_pos[0].detach().numpy().transpose(1, 2, 0)
    )

    viz = imgviz.tile([img, x_pos_viz])
    imgviz.io.pyglet_imshow(viz)
    imgviz.io.pyglet_run()
# Run the visualization demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment