Skip to content

Instantly share code, notes, and snippets.

@mkocabas
Created June 19, 2018 15:48
Show Gist options
  • Save mkocabas/3a5ee93fa60fcba9b80a44ade28aa55b to your computer and use it in GitHub Desktop.
Save mkocabas/3a5ee93fa60fcba9b80a44ade28aa55b to your computer and use it in GitHub Desktop.
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
class SoftArgmax(nn.Module):
    """Differentiable 2D soft-argmax over a batch of heatmaps.

    Each ``(dim 2, dim 3)`` map of the 4D input is turned into a probability
    distribution with a softmax over all of its cells, and the expected
    coordinate under that distribution is returned.  Coordinates are
    normalised to ``[0, 1]``: 0 is the first index along an axis and 1 is the
    last one (``size - 1``).

    Args:
        beta: Multiplier applied to the input before the softmax
            (inverse temperature).  The default ``1.0`` leaves the input
            untouched; larger values make the result approach the hard
            argmax.
    """

    def __init__(self, beta=1.0):
        super(SoftArgmax, self).__init__()
        self.beta = beta

    def forward(self, x):
        """Return the expected (x, y) location of every heatmap.

        Args:
            x: Tensor with 4 dimensions ``(batch, channels, dim 2, dim 3)``.
                Dim 3 is treated as the x axis and dim 2 as the y axis.

        Returns:
            Tensor of shape ``(batch, 2 * channels)`` laid out as
            ``[x_0, ..., x_{C-1}, y_0, ..., y_{C-1}]``.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError('This function is specific for 4D tensors. '
                             'Here, ndim=' + str(x.dim()))

        num_batch, num_filters, height, width = x.shape

        # Skip the multiplication for the default so the input reaches the
        # softmax exactly as given.
        logits = x if self.beta == 1.0 else x * self.beta

        # reshape (rather than view) so non-contiguous inputs are accepted.
        probs = F.softmax(
            logits.reshape(num_batch * num_filters, height * width), dim=-1)
        probs = probs.reshape(num_batch, num_filters, height, width)

        # Build the coordinate axes on the input's device so CUDA inputs do
        # not hit a device mismatch.  The dtype is at least the default float
        # type, which keeps the output dtype the same as multiplying by a
        # default-dtype grid while giving float64 inputs a float64 grid.
        grid_dtype = torch.promote_types(probs.dtype,
                                         torch.get_default_dtype())
        grid_x = torch.linspace(0.0, 1.0, width,
                                dtype=grid_dtype, device=probs.device)
        grid_y = torch.linspace(0.0, 1.0, height,
                                dtype=grid_dtype, device=probs.device)

        # Broadcasting replaces the full-size repeated weight tensors:
        # grid_x varies along dim 3, grid_y (as a column) along dim 2.
        X = (probs * grid_x).sum(dim=(2, 3))
        Y = (probs * grid_y.unsqueeze(-1)).sum(dim=(2, 3))

        return torch.cat([X, Y], dim=-1)
if __name__ == '__main__':
    # Smoke test: three h x w heatmaps, each with a single unit peak.
    # NOTE: with unit peaks the softmax over h * w cells is nearly uniform,
    # so every estimate lands close to the map centre, not on the peak.
    h, w = 64, 64
    data = torch.zeros([1, 3, h, w])
    data[0, 0, 2, 3] = 1.
    data[0, 1, 5, 5] = 1.
    data[0, 2, 8, 8] = 1.

    layer = SoftArgmax()
    kps = layer(data)
    print(kps.shape)

    # The layer returns [x_0..x_{C-1}, y_0..y_{C-1}] normalised with
    # linspace(0, 1, size), so 1.0 corresponds to pixel index size - 1.
    # Scale the x half by (w - 1) and the y half by (h - 1) to get pixels.
    num_kps = data.shape[1]
    scale = torch.tensor([w - 1.] * num_kps + [h - 1.] * num_kps)
    print(kps * scale)

    # Example on MPII ground-truth heatmaps (needs the project's
    # `pose.datasets` package and a CUDA device):
    #
    # import torch.utils.data
    # import pose.datasets as datasets
    #
    # val_loader = torch.utils.data.DataLoader(
    #     datasets.Mpii('data/mpii/mpii_annotations.json', 'data/mpii/images',
    #                   sigma=1.0, label_type='Gaussian', train=False),
    #     batch_size=2, shuffle=False,
    #     num_workers=1, pin_memory=True)
    #
    # for i, (inputs, target, meta) in enumerate(val_loader):
    #
    #     model = SoftArgmax()
    #     y = model.forward(target.cuda())
    #
    #     print(meta['tpts'])
    #     print(y * 64.0)
    #     print(y.shape, meta['tpts'].shape)
    #
    #     if i == 3:
    #         break
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment