Created June 15, 2020 15:37
import math

import imgviz
import numpy as np
import torch


class PositionalEmbeddingSinCos(torch.nn.Module):
    """Fixed 2D sine-cosine positional embedding computed from pixel coordinates."""

    def __init__(
        self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
    ):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self._num_pos_feats = num_pos_feats
        self._temperature = temperature
        self._normalize = normalize
        self._scale = scale

    def forward(self, input):
        batch_size, _, height, width = input.shape
        # Cumulative sums over a mask of ones give 1-based y/x pixel coordinates.
        mask = torch.ones(
            (batch_size, height, width), dtype=torch.bool, device=input.device
        )
        y_embed = mask.cumsum(dim=1, dtype=torch.float32)
        x_embed = mask.cumsum(dim=2, dtype=torch.float32)
        if self._normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self._scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self._scale
        dim_t = torch.arange(
            self._num_pos_feats, dtype=torch.float32, device=input.device
        )
        dim_t = self._temperature ** (2 * (dim_t // 2) / self._num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sin (even channels) and cos (odd channels) per coordinate.
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        # (batch_size, 2 * num_pos_feats, height, width)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
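
# Shape sketch (illustrative example, not part of the original gist):
#   emb = PositionalEmbeddingSinCos(num_pos_feats=64, normalize=True)
#   pos = emb(torch.zeros(1, 3, 480, 640))  # -> (1, 128, 480, 640)
# The output has 2 * num_pos_feats channels: num_pos_feats for the y coordinate
# followed by num_pos_feats for the x coordinate.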


class PositionalEmbedding(torch.nn.Module):
    """Learned positional embedding with one table per row and one per column."""

    def __init__(self, height, width, num_pos_feats=64):
        super().__init__()
        self.embed_y = torch.nn.Embedding(
            num_embeddings=height, embedding_dim=num_pos_feats
        )
        self.embed_x = torch.nn.Embedding(
            num_embeddings=width, embedding_dim=num_pos_feats
        )

    def forward(self, input):
        batch_size, _, height, width = input.shape
        y_embed = self.embed_y(torch.arange(height, device=input.device))
        y_embed = y_embed[:, None].repeat_interleave(width, dim=1)
        x_embed = self.embed_x(torch.arange(width, device=input.device))
        x_embed = x_embed[None, :].repeat_interleave(height, dim=0)
        # (height, width, 2 * num_pos_feats) -> (batch_size, 2 * num_pos_feats, height, width)
        pos = torch.cat([y_embed, x_embed], dim=2)
        pos = pos.permute(2, 0, 1)
        pos = pos[None].repeat_interleave(batch_size, dim=0)
        return pos
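
# Shape sketch (illustrative, not part of the original gist). The table sizes are
# fixed at construction, so the input's spatial size must match height/width:
#   emb = PositionalEmbedding(height=480, width=640, num_pos_feats=64)
#   pos = emb(torch.zeros(1, 3, 480, 640))  # -> (1, 128, 480, 640)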


def main():
    img = imgviz.io.imread("./data/glasses.jpg")
    x = img.transpose(2, 0, 1)[None]  # HWC -> 1CHW
    x = torch.as_tensor(x, dtype=torch.float32)
    x = x / 255 * 2 - 1  # scale pixel values to [-1, 1]
    height, width = img.shape[:2]

    x_pos = PositionalEmbeddingSinCos()(x)
    # x_pos = PositionalEmbedding(height=x.shape[2], width=x.shape[3])(x)

    # Map the many-channel embedding to RGB for visualization.
    np.random.seed(0)
    x_pos_viz = imgviz.nchannel2rgb(x_pos[0].detach().numpy().transpose(1, 2, 0))
    viz = imgviz.tile([img, x_pos_viz])
    imgviz.io.pyglet_imshow(viz)
    imgviz.io.pyglet_run()


if __name__ == "__main__":
    main()
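
For context, here is a minimal sketch of one common way such a positional embedding is consumed: it is added to a feature map, which is then flattened into a sequence for a transformer encoder. The feature shapes and encoder settings below are illustrative assumptions, not part of the gist.

import torch

feats = torch.randn(2, 128, 30, 40)  # hypothetical batch of CNN feature maps
pos = PositionalEmbeddingSinCos(num_pos_feats=64, normalize=True)(feats)  # (2, 128, 30, 40)

# Add the embedding to the features and flatten to (sequence, batch, channel),
# the layout torch.nn.TransformerEncoder expects by default.
src = (feats + pos).flatten(2).permute(2, 0, 1)  # (1200, 2, 128)

encoder_layer = torch.nn.TransformerEncoderLayer(d_model=128, nhead=8)
encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=2)
out = encoder(src)  # (1200, 2, 128)

Adding the embedding once at the input, as above, is the simpler textbook variant; DETR-style models instead re-add it to the queries and keys inside each attention layer.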