Skip to content

Instantly share code, notes, and snippets.

@qianyizhang
Last active June 6, 2018 22:48
Show Gist options
  • Save qianyizhang/07ee1c15cad08afb03f5de69349efc30 to your computer and use it in GitHub Desktop.
Save qianyizhang/07ee1c15cad08afb03f5de69349efc30 to your computer and use it in GitHub Desktop.
NumPy `one_hot` function: convert an integer label array into a one-hot encoded array.
import numpy as np
def one_hot(nparray, depth = 0, on_value = 1, off_value = 0):
    """Return a one-hot encoding of an integer label array.

    The result has shape ``nparray.shape + (depth,)``. For every element
    ``v`` of ``nparray``, the vector along the new last axis holds
    ``on_value`` at position ``v`` and ``off_value`` everywhere else.

    Args:
        nparray: Array-like of non-negative integer labels, of any shape
            (0-d and empty arrays included). Converted with ``np.asarray``.
        depth: Length of the new last axis. ``0`` (the default) means
            "infer it" as ``max(nparray) + 1``, or ``0`` for empty input.
        on_value: Value written at each label's position.
        off_value: Value written everywhere else.

    Returns:
        A new array of shape ``nparray.shape + (depth,)``. Its dtype is
        that of ``np.ones(...) * off_value``, i.e. float64 for an int or
        float ``off_value``.

    Raises:
        ValueError: If ``nparray`` contains a negative label, or a label
            that is not smaller than ``depth``.
        Non-integer label dtypes are not supported; NumPy rejects them
        with a TypeError or IndexError.
    """
    labels = np.asarray(nparray)
    if labels.size == 0:
        # np.min / np.max reject empty arrays, and there is nothing to set
        # to on_value, so return an all-off array (depth 0 unless given).
        return np.ones((*labels.shape, depth)) * off_value

    lowest, highest = labels.min(), labels.max()
    if lowest < 0:
        # Fancy indexing would silently wrap -1 to the last class.
        raise ValueError(
            "nparray must not contain negative indices, got min: {}".format(lowest))
    if depth == 0:
        depth = highest + 1
    if highest >= depth:
        raise ValueError(
            "the max index of nparray: {} must be smaller than depth: {}".format(highest, depth))

    out = np.ones((*labels.shape, depth)) * off_value
    # Scatter on_value along the last axis at each element's label position.
    np.put_along_axis(out, labels[..., np.newaxis], on_value, axis=-1)
    return out
def test_one_hot():
    """Check one_hot on a 2-D label array, with inferred and explicit depth."""
    # Nested list: np.array([1,2,3],[4,5,6]) would pass the second list as dtype.
    a = np.array([[1, 2, 3], [4, 5, 6]])
    expected = np.array([[[0., 1., 0., 0., 0., 0., 0.],
                          [0., 0., 1., 0., 0., 0., 0.],
                          [0., 0., 0., 1., 0., 0., 0.]],

                         [[0., 0., 0., 0., 1., 0., 0.],
                          [0., 0., 0., 0., 0., 1., 0.],
                          [0., 0., 0., 0., 0., 0., 1.]]])

    result = one_hot(a)
    # Depth is inferred as max(a) + 1 == 7.
    assert result.shape == (2, 3, 7)
    np.testing.assert_array_equal(result, expected)

    # An explicit depth adds trailing off columns; custom on/off values are used.
    padded = one_hot(a, depth=8, on_value=2, off_value=-1)
    assert padded.shape == (2, 3, 8)
    np.testing.assert_array_equal(padded[..., :7], np.where(expected == 1, 2, -1))
    np.testing.assert_array_equal(padded[..., 7], -1)
@xychenunc
Copy link

well done

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment