
@jeasinema
Last active March 10, 2023 08:11
Spatial(Arg)Softmax for pytorch
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import numpy as np


class SpatialSoftmax(torch.nn.Module):
    def __init__(self, height, width, channel, temperature=None, data_format='NCHW'):
        super(SpatialSoftmax, self).__init__()
        self.data_format = data_format
        self.height = height
        self.width = width
        self.channel = channel

        if temperature:
            self.temperature = Parameter(torch.ones(1) * temperature)
        else:
            self.temperature = 1.

        # Coordinate grids in [-1, 1], flattened to line up with the flattened feature map.
        pos_x, pos_y = np.meshgrid(
            np.linspace(-1., 1., self.height),
            np.linspace(-1., 1., self.width)
        )
        pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float()
        self.register_buffer('pos_x', pos_x)
        self.register_buffer('pos_y', pos_y)

    def forward(self, feature):
        # Output:
        #   (N, C*2) x_0 y_0 ...
        if self.data_format == 'NHWC':
            feature = feature.transpose(1, 3).transpose(2, 3).view(-1, self.height * self.width)
        else:
            feature = feature.view(-1, self.height * self.width)

        # Softmax over the spatial locations, then take the expected (x, y) position per channel.
        softmax_attention = F.softmax(feature / self.temperature, dim=-1)
        expected_x = torch.sum(self.pos_x * softmax_attention, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y * softmax_attention, dim=1, keepdim=True)
        expected_xy = torch.cat([expected_x, expected_y], 1)
        feature_keypoints = expected_xy.view(-1, self.channel * 2)

        return feature_keypoints


if __name__ == '__main__':
    data = torch.zeros([1, 3, 3, 3])
    data[0, 0, 0, 1] = 10
    data[0, 1, 1, 1] = 10
    data[0, 2, 1, 2] = 10
    layer = SpatialSoftmax(3, 3, 3, temperature=1)
    print(layer(data))
@PradeepKadubandi

PradeepKadubandi commented Nov 22, 2019

Thank you for the code snippet! I have been referencing it in one of my personal projects.

However, I think there's a minor, counter-intuitive bug:
When the 'temperature' argument is passed as 'None', the API should treat temperature as a learnable parameter of the module (training adjusts or learns its value over time). When a fixed value is passed in, temperature should not be learned and should instead stay at that fixed value even while the model is being trained.

Assuming you agree with the above 'requirement' (call it API experience if you prefer :-)), the snippet above does exactly the opposite: when a temperature is passed in, it becomes a learnable parameter, and when it is not passed in, it stays fixed. The corresponding lines in the snippet should really be this:

if temperature:  
    self.temperature = torch.ones(1)*temperature   
else:   
    self.temperature = Parameter(torch.ones(1))   

I am new to PyTorch as well, so correct me if you think I am wrong.
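
For what it's worth, a quick way to check which variant actually registers temperature as a learnable parameter is to print the module's parameters. A minimal sketch against the SpatialSoftmax class above (the 3x3x3 shape is arbitrary):

fixed = SpatialSoftmax(3, 3, 3, temperature=1)  # temperature passed in
default = SpatialSoftmax(3, 3, 3)               # temperature left as None
print(list(fixed.parameters()))    # with the current code: contains the temperature Parameter
print(list(default.parameters()))  # with the current code: empty, temperature stays fixed at 1.0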

@hai-h-nguyen
Copy link

hai-h-nguyen commented Nov 25, 2019

Hey, I think you are correct. Here is a snippet from the spatial softmax layer in TensorFlow:

      if temperature is None:
        temp_initializer = init_ops.ones_initializer()
      else:
        temp_initializer = init_ops.constant_initializer(temperature)

It is defined here: https://github.com/tensorflow/tensorflow/blob/r1.8/tensorflow/contrib/layers/python/layers/layers.py.

@dachengxiaocheng
Copy link

dachengxiaocheng commented Feb 21, 2020

Hey, all

I am confused about the return values of the spatial softmax layer. The input is a (B, C, H, W) feature map and the output is a (B, 2C) tensor.

I would expect the output of the spatial softmax to be the 2D positions (x, y) of the feature points in the original image, i.e., positive integers. However, the output from the code is float and includes a lot of negative values.

Am I misunderstanding spatial softmax? Any ideas about this?

Many thanks for your help.

Best

@ZhihaoAIRobotic
Copy link

Have you figured it out?

@heyzude
Copy link

heyzude commented Jun 18, 2020

@dachengxiaocheng
I think we can multiply the float output by H and W and round to int to get integer (pixel-like) coordinates within the (H, W) range of the conv output.

And in my opinion, dealing with negative numbers is cumbersome, so why not just use the range [0, 1] instead of [-1, 1]?

I mean,
pos_x, pos_y = np.meshgrid(np.linspace(0., 1., self.height), np.linspace(0., 1., self.width))
instead of
pos_x, pos_y = np.meshgrid(np.linspace(-1., 1., self.height), np.linspace(-1., 1., self.width)).
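
A minimal sketch of the conversion back to pixel indices, assuming the default [-1, 1] grid and the layer/data names from the demo at the bottom of the gist (whether x indexes width or height depends on how the meshgrid in __init__ is read, so double-check against your own data):

def to_pixel(coord, size):
    # rescale a coordinate in [-1, 1] to an integer index in [0, size-1]
    return torch.round((coord + 1) / 2 * (size - 1)).long()

out = layer(data)                    # (N, C*2), laid out as x_0, y_0, x_1, y_1, ...
xy = out.view(-1, layer.channel, 2)  # (N, C, 2)
px = to_pixel(xy[..., 0], layer.width)   # per-channel index along one axis
py = to_pixel(xy[..., 1], layer.height)  # per-channel index along the other axis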
