Gist: jeasinema/1cba9b40451236ba2cfb507687e08834
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import numpy as np


class SpatialSoftmax(torch.nn.Module):
    def __init__(self, height, width, channel, temperature=None, data_format='NCHW'):
        super(SpatialSoftmax, self).__init__()
        self.data_format = data_format
        self.height = height
        self.width = width
        self.channel = channel

        if temperature:
            self.temperature = Parameter(torch.ones(1) * temperature)
        else:
            self.temperature = 1.

        # x varies along the width axis and y along the height axis, both in [-1, 1].
        # np.meshgrid(xs, ys) returns arrays of shape (len(ys), len(xs)) = (height, width),
        # which matches the row-major flattening of each H*W feature map.
        pos_x, pos_y = np.meshgrid(
            np.linspace(-1., 1., self.width),
            np.linspace(-1., 1., self.height)
        )
        pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float()
        self.register_buffer('pos_x', pos_x)
        self.register_buffer('pos_y', pos_y)

    def forward(self, feature):
        # Output:
        #   (N, C*2) x_0 y_0 ...
        if self.data_format == 'NHWC':
            # NHWC -> NCHW; reshape (not view) because permute makes the tensor non-contiguous.
            feature = feature.permute(0, 3, 1, 2).reshape(-1, self.height * self.width)
        else:
            feature = feature.reshape(-1, self.height * self.width)

        # One softmax per (sample, channel) over its H*W locations.
        softmax_attention = F.softmax(feature / self.temperature, dim=-1)
        # Expected (x, y) position under each channel's softmax distribution.
        expected_x = torch.sum(self.pos_x * softmax_attention, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y * softmax_attention, dim=1, keepdim=True)
        expected_xy = torch.cat([expected_x, expected_y], 1)
        feature_keypoints = expected_xy.view(-1, self.channel * 2)

        return feature_keypoints


if __name__ == '__main__':
    data = torch.zeros([1, 3, 3, 3])
    data[0, 0, 0, 1] = 10
    data[0, 1, 1, 1] = 10
    data[0, 2, 1, 2] = 10
    layer = SpatialSoftmax(3, 3, 3, temperature=1)
    print(layer(data))
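A non-square feature map is a quick way to check the `np.meshgrid` argument order, since the 3x3 demo cannot tell the two orders apart. The sketch below is a hypothetical functional version of the layer (`spatial_softmax_keypoints` is not part of the gist) and assumes NCHW input. It builds the grids as `np.meshgrid(linspace over width, linspace over height)`, so both grids have shape (H, W). For a 2x4 map with one strong peak in the top-right pixel it returns approximately (1, -1). With the order `np.meshgrid(linspace(height), linspace(width))` the grids come out as (W, H), and after flattening the same peak would be reported at (1, -1/3).

```python
import numpy as np
import torch
import torch.nn.functional as F


def spatial_softmax_keypoints(feature, temperature=1.0):
    """Hypothetical functional spatial softmax for NCHW input.

    Returns an (N, 2C) tensor laid out as x_0, y_0, x_1, y_1, ...
    """
    n, c, h, w = feature.shape
    # np.meshgrid(xs, ys) returns two arrays of shape (len(ys), len(xs)) = (H, W),
    # so pos_x varies along the width axis and pos_y along the height axis.
    pos_x, pos_y = np.meshgrid(np.linspace(-1., 1., w), np.linspace(-1., 1., h))
    pos_x = torch.from_numpy(pos_x.reshape(h * w)).float()
    pos_y = torch.from_numpy(pos_y.reshape(h * w)).float()

    # One softmax per (sample, channel) over its H*W locations.
    attention = F.softmax(feature.reshape(n * c, h * w) / temperature, dim=-1)
    expected_x = torch.sum(pos_x * attention, dim=1, keepdim=True)
    expected_y = torch.sum(pos_y * attention, dim=1, keepdim=True)
    return torch.cat([expected_x, expected_y], dim=1).view(n, 2 * c)


# Non-square check: a 2x4 map with one strong peak at row 0, column 3 (top-right).
feature = torch.zeros(1, 1, 2, 4)
feature[0, 0, 0, 3] = 100.0
print(spatial_softmax_keypoints(feature))  # approximately [[1., -1.]]
```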
Hey, I think you are correct. Here is a snippet from the spatial softmax layer in TensorFlow:

    if temperature is None:
        temp_initializer = init_ops.ones_initializer()
    else:
        temp_initializer = init_ops.constant_initializer(temperature)

It is defined here: https://github.com/tensorflow/tensorflow/blob/r1.8/tensorflow/contrib/layers/python/layers/layers.py.
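In the TF snippet the temperature is always created as a variable: `None` only selects a ones initializer instead of a constant one. A rough PyTorch equivalent is sketched below under stated assumptions: `TFStyleTemperature` is a hypothetical helper, and its `trainable` flag is assumed to mirror the `trainable` argument that the linked TF function passes when creating the variable. Whether the value is learned is then controlled by `requires_grad`, not by whether a temperature was given.

```python
import torch
from torch.nn.parameter import Parameter


class TFStyleTemperature(torch.nn.Module):
    """Hypothetical module mirroring the TF initializer logic quoted above."""

    def __init__(self, temperature=None, trainable=True):
        super().__init__()
        # TF: ones_initializer() when temperature is None, else constant_initializer(temperature).
        init = 1.0 if temperature is None else float(temperature)
        # Always a Parameter; requires_grad decides whether the optimiser can change it.
        self.temperature = Parameter(torch.full((1,), init), requires_grad=trainable)

    def forward(self, logits):
        # Scale logits before the spatial softmax.
        return logits / self.temperature


print(TFStyleTemperature().temperature)
print(TFStyleTemperature(0.5, trainable=False).temperature)
```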
Hey, all
I am very confused about the return values of the spatial softmax layer. The input of the spatial softmax is a (B, C, H, W) feature map and the output is a (B, 2C) tensor.
I guess the output of the spatial softmax should be the 2D positions (x, y) of the feature points in the original image, so the values should be positive integers. However, the output of the code is float and includes a lot of negative values.
Do I have a wrong understanding of spatial softmax? Any idea about this?
Many thanks for your help.
Best
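The layer does not return pixel indices. For each channel it turns the H*W activations into a probability distribution with a softmax and returns the expected (x, y) under that distribution, where x and y are measured on grids normalised to [-1, 1] by `np.linspace(-1., 1., ...)`. The outputs are therefore floats, and they are negative for the left/upper half of the map. The small worked example below uses plain PyTorch rather than the gist's class and reproduces channel 0 of the demo: a 3x3 map with a single activation of 10 at row 0, column 1.

```python
import torch
import torch.nn.functional as F

# A 3x3 activation map with one strong response at row 0, column 1
# (the same as channel 0 in the gist's __main__ demo).
feature = torch.zeros(3, 3)
feature[0, 1] = 10.0

# Normalised coordinate grids: x along columns (width), y along rows (height).
coords = torch.linspace(-1.0, 1.0, 3)
pos_y, pos_x = torch.meshgrid(coords, coords, indexing='ij')

# Softmax turns the 9 activations into probabilities that sum to 1.
attention = F.softmax(feature.reshape(-1), dim=0)
expected_x = (pos_x.reshape(-1) * attention).sum()
expected_y = (pos_y.reshape(-1) * attention).sum()

# expected_x is ~0 (centre column), expected_y is ~-0.9996 (top row): a soft,
# sub-pixel position, not an integer index.
print(expected_x.item(), expected_y.item())
```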
Have you figured it out?
@dachengxiaocheng
I think we can multiply the float output by H and W and round to int to get integer (pixel-like) coordinates within the (H, W) range of the conv output.
And in my opinion, negative numbers might be cumbersome to deal with, so why don't we just use the range [0, 1] instead of [-1, 1]?
I mean:

    pos_x, pos_y = np.meshgrid(
        np.linspace(0., 1., self.height),
        np.linspace(0., 1., self.width))

instead of

    pos_x, pos_y = np.meshgrid(
        np.linspace(-1., 1., self.height),
        np.linspace(-1., 1., self.width))
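Because `np.linspace` includes both endpoints, column j of a width-W grid sits at x = -1 + 2j/(W - 1). Inverting gives j = (x + 1)/2 * (W - 1), and on a [0, 1] grid simply j = x * (W - 1); the same holds for rows with H. Multiplying by W rather than W - 1 would map the right edge to index W, one past the last column. The helper below is a hypothetical sketch (`to_pixel_indices` is not in the gist) that assumes the gist's x_0, y_0, x_1, y_1, ... output layout. Note that rounding discards the sub-pixel information and has zero gradient, so it only suits post-processing, not training.

```python
import torch


def to_pixel_indices(keypoints, height, width, coord_range=(-1.0, 1.0)):
    """Hypothetical helper: map (N, 2C) keypoints laid out as x_0, y_0, x_1, y_1, ...
    to integer (column, row) indices on an H x W feature map."""
    lo, hi = coord_range
    xy = keypoints.reshape(keypoints.shape[0], -1, 2)        # (N, C, 2) as (x, y)
    unit = (xy - lo) / (hi - lo)                              # rescale to [0, 1]
    # linspace includes both endpoints, so the last column/row is W - 1 / H - 1.
    scale = torch.tensor([width - 1, height - 1], dtype=xy.dtype)
    return torch.round(unit * scale).long()                   # (N, C, 2) as (col, row)


kp = torch.tensor([[-1.0, -1.0, 1.0, 1.0, 0.0, 0.0]])
print(to_pixel_indices(kp, height=3, width=5).tolist())
```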
Thank you for the code snippet! I was using it as a reference for one of my personal projects.
However, I think there's a minor mismatch (a counter-intuitive bug):
When the 'temperature' argument is passed in as 'None', the API should make temperature a learnable parameter of the module (training adjusts, i.e. learns, its value over time). When a fixed value is passed in, temperature should not be learned; it should stay at that fixed value even while the model is being trained.
Assuming you agree with the above 'requirement' (call it API experience if you prefer :-)), the code snippet above seems to do exactly the opposite: when a temperature is passed in, it should be a fixed value, and when it is not passed in, it should be created as a parameter. So the code (matching the similar lines in the snippet) should really be this:
I am new to PyTorch as well, so correct me if you think I am wrong.
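The code that was meant to follow "should really be this:" is not preserved in this copy. The following is only a hedged sketch of the behaviour described in prose, not the commenter's code; `Temperature` is a hypothetical helper. With `None`, the temperature becomes a learnable `Parameter` initialised to 1. With a value, it is stored as a buffer, which follows `.to(device)` and is saved in `state_dict()` but is not returned by `.parameters()`, so an optimiser never updates it. Testing `is None` instead of `if temperature:` also avoids treating an explicit `0` as "not passed".

```python
import torch
from torch.nn.parameter import Parameter


class Temperature(torch.nn.Module):
    """Hypothetical helper: learnable when temperature is None, fixed otherwise."""

    def __init__(self, temperature=None):
        super().__init__()
        if temperature is None:
            # Learned during training, initialised to 1.
            self.temperature = Parameter(torch.ones(1))
        else:
            # Fixed value: a buffer, so it is saved and moved with the module
            # but is not exposed to the optimiser through .parameters().
            self.register_buffer('temperature', torch.full((1,), float(temperature)))

    def forward(self, logits):
        # Scale logits before the spatial softmax.
        return logits / self.temperature


print(len(list(Temperature().parameters())), len(list(Temperature(0.5).parameters())))
```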