Last active
February 6, 2018 13:09
-
-
Save psycharo-zz/93f88a1bc07eeeee00e84d66fbf63082 to your computer and use it in GitHub Desktop.
permutohedral lattice filtering in numpy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def np_splat(inputs, offsets, weights, nbs):
    """Splat per-point input values onto the lattice vertices.

    Each of the first ``N`` rows of ``inputs`` is multiplied by each of its
    ``F + 1`` weights, and every product is accumulated into the lattice row
    selected by the matching entry of ``offsets``. Repeated indices are summed.

    Args:
        inputs: 2-D array of shape ``(N, V)``.
        offsets: integer array; columns ``0..F`` of its first ``N`` rows are
            used as vertex indices.
        weights: array whose second axis has length ``F + 1``; its first ``N``
            rows are used.
        nbs: only ``nbs.shape[0]`` (the vertex count ``M``) is read here.

    Returns:
        Float array of shape ``(M + 2, V)``. Every vertex index is shifted by
        one, so vertex ``k`` lands in row ``k + 1`` and an offset of ``-1``
        lands in row 0 (presumably a sentinel for "no vertex" -- confirm with
        the code that builds ``offsets``).
    """
    n_points, n_channels = inputs.shape
    n_vertices = nbs.shape[0]
    n_simplex = weights.shape[1]  # F + 1 vertices per enclosing simplex

    # Per-point outer product: (N, F+1, 1) * (N, 1, V) -> (N, F+1, V).
    # Flattening makes row i*(F+1)+j the contribution of point i to its j-th vertex.
    per_point = weights[:n_points, :, np.newaxis] * inputs[:n_points, np.newaxis, :]
    contributions = per_point.reshape(-1, n_channels)

    # Same row-major order as `contributions`, shifted by one (see Returns).
    targets = offsets[:n_points, :n_simplex].reshape(-1) + 1

    lattice = np.zeros((n_vertices + 2, n_channels))
    # np.add.at is unbuffered, so every repeated target index accumulates.
    np.add.at(lattice, targets, contributions)
    return lattice
def np_blur(inputs, values_in, offsets, weights, nbs):
    """Blur lattice values with one pass per lattice direction.

    For each direction ``j`` in ``0..F``, every interior row (``1:-1``) of
    the lattice gets ``0.5 * (n1 + n2)`` added. Here ``n1`` and ``n2`` are the
    rows ``nbs[:, j, 0] + 1`` and ``nbs[:, j, 1] + 1``, which uses the same +1
    row shift as the splatting step. Both neighbour gathers are fancy-index
    copies taken before the in-place add, so each pass reads only the values
    left by the previous pass.

    NOTE(review): the original author asked whether "the last update" is
    ignored. As written, all ``F + 1`` passes are applied and the result of
    the final pass is returned.

    Args:
        inputs: must be 2-D (its shape is unpacked); otherwise unused.
        values_in: lattice array with ``M + 2`` rows; not modified.
        offsets: not read.
        weights: only ``weights.shape[1]`` (``F + 1``) is read.
        nbs: integer array indexed as ``nbs[:, j, 0]`` / ``nbs[:, j, 1]``;
            ``nbs.shape[0]`` must equal the number of interior rows.

    Returns:
        A blurred copy of ``values_in``.
    """
    _n_points, _n_channels = inputs.shape  # kept: enforces a 2-D ``inputs``
    _n_vertices = nbs.shape[0]
    n_directions = weights.shape[1]  # F + 1

    blurred = values_in.copy()
    interior = slice(1, -1)
    for direction in range(n_directions):
        # Both gathers run (and copy) before the in-place update below.
        left, right = (blurred[nbs[:, direction, side] + 1] for side in (0, 1))
        blurred[interior] += 0.5 * (left + right)
    return blurred
def np_slice(inputs, values, offsets, weights, nbs):
    """Interpolate lattice values back onto the input points (slicing).

    For each of the first ``N`` points, gathers the ``F + 1`` rows of
    ``values`` at ``offsets[i, :F+1] + 1`` and returns their
    ``weights``-weighted sum, scaled by ``alpha = 1 / (1 + 2**-F)``.

    Args:
        inputs: array whose first axis has length ``N``; only ``N`` is read.
        values: 2-D lattice array using the same +1 row shift as the splat
            step (vertex ``k`` lives in row ``k + 1``). Its channel count is
            taken from ``values`` itself, so it may differ from the channel
            count of ``inputs``.
        offsets: integer array; columns ``0..F`` of its first ``N`` rows are
            used as vertex indices.
        weights: array whose second axis has length ``F + 1``; its first ``N``
            rows are used.
        nbs: not read.

    Returns:
        Array of shape ``(N, values.shape[1])``.
    """
    n_points = inputs.shape[0]
    n_simplex = weights.shape[1]  # F + 1
    F = n_simplex - 1
    # Channel count comes from the lattice being sliced, not from `inputs`.
    # Previously `inputs.shape[1]` was used, which broke whenever they differed.
    n_channels = values.shape[1]
    # Scaling constant; presumably compensates for the unnormalised blur
    # kernel -- TODO confirm against the reference implementation.
    alpha = 1.0 / (1.0 + 2.0**(-F))
    idxs = offsets[:n_points, :n_simplex].reshape((-1,)) + 1
    w = weights[:n_points, :, np.newaxis]
    v = values[idxs].reshape((n_points, n_simplex, n_channels))
    return np.sum(alpha * w * v, axis=1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment