Skip to content

Instantly share code, notes, and snippets.

@NegatioN
Last active April 16, 2020 09:51
Show Gist options
  • Save NegatioN/d46f17caa47535fa4940c727d1fd537d to your computer and use it in GitHub Desktop.
Save NegatioN/d46f17caa47535fa4940c727d1fd537d to your computer and use it in GitHub Desktop.
PyTorch n-hot (multi-hot) encoding function
import torch


def n_hot(y, num_classes, scatter_dim):
    """Return an n-hot (multi-hot) encoding of the class indices in ``y``.

    All dimensions of ``y`` from ``scatter_dim`` onward are flattened into a
    single axis of indices. Each index sets a 1 at that position of the class
    axis of the result. The value -1 is treated as a masking/padding value and
    produces no hot entry, so -1-padded sequences can be encoded directly.

    Args:
        y: Integer index tensor with values in ``[-1, num_classes - 1]``.
            It is used as a ``scatter`` index, so it must be int64. It is not
            modified.
        num_classes: Number of classes, i.e. the size of the class axis.
        scatter_dim: First dimension of ``y`` that holds class indices.
            Dimensions before it are kept as leading (batch) dimensions.

    Returns:
        Tensor of shape ``(*y.shape[:scatter_dim], num_classes)`` with the
        same dtype and device as ``y``.
    """
    # One extra class slot (index 0) absorbs the -1 mask value.
    nc = num_classes + 1
    lead = y.shape[:scatter_dim]
    # Shift indices by one so -1 lands on the spare slot 0. This is done
    # out-of-place on purpose: an in-place `+=` would corrupt the caller's
    # tensor. `reshape` (not `view`) also accepts non-contiguous inputs.
    index = y.reshape(*lead, -1) + 1
    zeros = torch.zeros(*lead, nc, dtype=y.dtype, device=y.device)
    res = zeros.scatter(scatter_dim, index, 1)
    # Drop the spare mask slot. The selector must live on the same device
    # as `res`, or index_select fails for GPU tensors.
    keep = torch.arange(1, nc, device=res.device)
    return res.index_select(scatter_dim, keep)
if __name__ == "__main__":
    # Demo of n_hot on two inputs.
    # A plain 2-D index tensor: each row lists the active classes of one sample.
    t = torch.as_tensor([[1, 3], [2, 4]])

    # Variable-length samples, right-padded with -1 (the value n_hot masks out).
    inp = [torch.as_tensor(x) for x in [[3, 7, 9], [3], [9, 1, 4]]]
    inp = torch.nn.utils.rnn.pad_sequence(inp, batch_first=True, padding_value=-1)

    print(n_hot(t, num_classes=10, scatter_dim=1))
    print(n_hot(inp, num_classes=10, scatter_dim=1))

    # Output:
    # tensor([[0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
    #         [0, 0, 1, 0, 1, 0, 0, 0, 0, 0]])
    # tensor([[0, 0, 0, 1, 0, 0, 0, 1, 0, 1],
    #         [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
    #         [0, 1, 0, 0, 1, 0, 0, 0, 0, 1]])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment