Last active
October 5, 2020 17:02
-
-
Save BenedictWilkins/d07d8dcfc0aa2963d6686619fae2d090 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Compute a one-hot encoding of a multi-dimensional NumPy array using NumPy
advanced indexing.
If anyone has ideas for a more efficient version, please let me know!

Example 1 (the one-hot dimension is the last axis):
    import numpy as np
    x = np.random.randint(0, 3, size=(2, 2))
    y = onehot(x, (2, 2, 3))
    print(x)
    print(y)

Example 2 (the one-hot dimension is the middle axis):
    import numpy as np
    x = np.random.randint(0, 3, size=(2, 2))
    y = onehot(x, (2, 3, 2))
    print(x)
    print(y)
"""
# Needed at module level: onehot's default ``dtype=np.uint8`` is evaluated
# when the ``def`` statement runs, not when the function is called.
import numpy as np


def _full_index(idx, axis):
    """Build an advanced-indexing tuple with ``idx`` placed at position ``axis``.

    Every other position holds a sparse, broadcastable ``arange`` over the
    matching dimension of ``idx``. Indexing an array of rank ``idx.ndim + 1``
    with the result therefore selects, for every element position of ``idx``,
    the entry whose ``axis``-th coordinate equals that element's value.
    """
    # np.indices(..., sparse=True) returns a tuple; convert it to a list so the
    # label array can be inserted. (np.ogrid, used previously, also returns a
    # tuple on recent NumPy releases, which has no ``.insert``.)
    grid = list(np.indices(idx.shape, sparse=True))
    grid.insert(axis, idx)
    return tuple(grid)


def onehot(x, shape, dtype=np.uint8):
    """Return a one-hot encoding of the integer label array ``x``.

    The result has exactly one more dimension than ``x``. The position of the
    new (one-hot) axis is inferred from ``shape``: it is the first position
    at which ``shape`` and ``x.shape`` disagree, and removing that axis from
    ``shape`` must give back ``x.shape`` exactly. When several positions
    would fit (e.g. ``x.shape == (2, 2)`` and ``shape == (2, 2, 2)``), the
    first mismatch wins, which for that example is the last axis.

    Args:
        x: Integer array-like of class labels (anything ``np.asarray`` accepts).
        shape: Shape of the output, with ``len(shape) == x.ndim + 1``. The
            size of the one-hot axis is the number of classes.
        dtype: dtype of the returned array (default ``np.uint8``).

    Returns:
        A new ``np.ndarray`` of the given ``shape`` and ``dtype``. It is 1
        where the one-hot-axis coordinate equals the label of ``x`` at the
        remaining coordinates, and 0 elsewhere.

    Raises:
        ValueError: if ``shape`` is not ``x.shape`` with exactly one extra
            dimension inserted.
        IndexError: raised by NumPy indexing if a label is >= the number of
            classes. Negative labels wrap around, following NumPy's
            negative-index semantics.
    """
    # Indexing approach based on https://stackoverflow.com/a/46103129/ @Divakar
    x = np.asarray(x)
    shape = tuple(shape)
    xshape = tuple(x.shape)

    if len(shape) != len(xshape) + 1:
        raise ValueError(
            f"one-hot shape {shape} must have exactly one more dimension "
            f"than x.shape {xshape}"
        )

    # One-hot axis = first mismatching position. If x.shape is a prefix of
    # shape, the new axis is the last one.
    axis = next(
        (i for i, (a, b) in enumerate(zip(shape, xshape)) if a != b),
        len(xshape),
    )

    # Exact check: every dimension except the one-hot axis must match x.shape.
    if shape[:axis] + shape[axis + 1:] != xshape:
        raise ValueError(
            f"one-hot shape {shape} is not compatible with x.shape {xshape}: "
            "they must agree in every dimension except one"
        )

    r = np.zeros(shape, dtype=dtype)
    r[_full_index(x, axis)] = 1
    return r
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment