Skip to content

Instantly share code, notes, and snippets.

@benanne
Created April 4, 2012 11:52
Show Gist options
  • Select an option

  • Save benanne/2300591 to your computer and use it in GitHub Desktop.

Select an option

Save benanne/2300591 to your computer and use it in GitHub Desktop.
One-hot encoding and maximum mask in Theano
import theano
import theano.tensor as T
def one_hot(t, r=None):
    """
    Encode the integer-valued tensor t as one-hot vectors.

    t is a symbolic tensor of dimension d whose entries are expected to lie
    in range(r). The result is a tensor of dimension d + 1: its last axis has
    length r and holds a 1 at the index equal to the corresponding entry of
    t, and 0 everywhere else.

    If r is not supplied, it defaults to max(t) + 1, computed symbolically.
    """
    num_classes = T.max(t) + 1 if r is None else r
    # Row of candidate class indices 0..r-1, left-padded with t.ndim
    # broadcastable axes so it lines up with the trailing axis added to t.
    class_ids = T.shape_padleft(T.arange(num_classes), t.ndim)
    # t gains one trailing broadcastable axis; the element-wise comparison
    # then broadcasts to t's shape followed by the length-r class axis.
    expanded = T.shape_padright(t, 1)
    return T.eq(class_ids, expanded)
def max_mask(t, axis):
    """
    Given a tensor t and an axis, return a 0/1 mask tensor with the same
    shape as t that marks the position of the maximum along the given axis.

    The mask is derived from T.argmax, so for each slice along `axis`
    exactly one position is set to 1 — the index argmax returns. If several
    entries tie for the maximum, only that one is marked.

    axis may be negative, counting from the last dimension as in NumPy.
    """
    # Normalize a negative axis. The dimshuffle permutation below is built
    # with range(axis), which would be empty/wrong for a negative value.
    if axis < 0:
        axis += t.ndim
    a = T.argmax(t, axis=axis)
    a_oh = one_hot(a, t.shape[axis])
    # one_hot appends the one-hot dimension last; move it back to the
    # position of the axis we reduced over. list(...) is required on
    # Python 3, where range objects cannot be concatenated with +.
    reordered_dims = (list(range(axis)) + [a_oh.ndim - 1]
                      + list(range(axis, a_oh.ndim - 1)))
    return a_oh.dimshuffle(reordered_dims)
# TODO: generalise max_mask to accept multiple axes
if __name__ == '__main__':
    # Self-test for one_hot and max_mask. print(...) with a single argument
    # behaves identically on Python 2 and 3, unlike the `print x` statement.
    import numpy as np

    # test one_hot
    a = np.array([0, 1, 2, 3, 4, 5], dtype=theano.config.floatX)
    b = np.array([9, 2, 0, 7, 4, 5, 1], dtype=theano.config.floatX)
    x = T.vector('x')
    f1 = theano.function([x], one_hot(x))
    af1 = f1(a)
    bf1 = f1(b)
    assert af1.shape == (6, 6)
    assert bf1.shape == (7, 10)
    # Each row must contain a single 1 at the column given by the input value.
    assert np.array_equal(af1, np.eye(6))
    assert np.array_equal(bf1, np.eye(10)[b.astype(int)])
    print(af1)
    print(bf1)

    # test max_mask
    a = np.array([[2, 3, 1], [5, 0, 2]], dtype=theano.config.floatX)
    y = T.matrix('y')
    f2 = theano.function([y], max_mask(y, 0))
    f3 = theano.function([y], max_mask(y, 1))
    mask0 = f2(a)
    mask1 = f3(a)
    # Column maxima are 5, 3, 2 (rows 1, 0, 1); row maxima are 3, 5
    # (columns 1, 0).
    assert np.array_equal(mask0, [[0, 1, 0], [1, 0, 1]])
    assert np.array_equal(mask1, [[0, 1, 0], [1, 0, 0]])
    print(a)
    print(mask0)
    print(mask1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment