Skip to content

Instantly share code, notes, and snippets.

@psycharo-zz
Last active February 6, 2018 13:09
Show Gist options
  • Save psycharo-zz/93f88a1bc07eeeee00e84d66fbf63082 to your computer and use it in GitHub Desktop.
Save psycharo-zz/93f88a1bc07eeeee00e84d66fbf63082 to your computer and use it in GitHub Desktop.
Permutohedral lattice filtering in NumPy
def np_splat(inputs, offsets, weights, nbs):
    """Splat per-point signal values onto the lattice vertices.

    Each of the ``N`` rows of ``inputs`` is multiplied by its ``F + 1``
    weights and accumulated into the vertex rows listed in the matching
    row of ``offsets``.

    Args:
        inputs: array of shape ``(N, V)``: ``N`` points with ``V`` channels.
        offsets: integer array with at least ``N`` rows and ``F + 1``
            columns. ``offsets[i, k]`` selects the vertex that point ``i``
            contributes to with weight ``weights[i, k]``. Values are
            presumably in ``[0, M)``; ``-1`` would land in row 0.
        weights: array with at least ``N`` rows. Its column count defines
            ``F + 1``.
        nbs: neighbour table. Only ``nbs.shape[0]`` (the vertex count ``M``)
            is used here.

    Returns:
        Array of shape ``(M + 2, V)``. Vertex index ``m`` is stored at row
        ``m + 1``, and rows ``0`` and ``M + 1`` receive no contributions
        from in-range offsets. The dtype is float64 for real (float, int,
        bool) inputs and complex for complex inputs.
    """
    N, V = inputs.shape
    F = weights.shape[1] - 1
    M = nbs.shape[0]
    # Per-point outer product: (N, F+1, 1) * (N, 1, V) -> (N, F+1, V).
    # This is the same set of scalar products a batched matmul with inner
    # dimension 1 would compute, without the matmul detour.
    weighted_inputs = weights[:N, :, np.newaxis] * inputs[:N, np.newaxis, :]
    # Explicit target shape: a -1 is ambiguous when the array is zero-size.
    weighted_inputs = weighted_inputs.reshape((N * (F + 1), V))
    # Shift by one: vertex m lives in row m + 1 of the output.
    idxs = offsets[:N, :F + 1].reshape(-1) + 1
    # Stay float64 for real inputs (as before) but widen to complex when
    # needed; a float64 accumulator cannot receive complex values.
    dtype = np.result_type(weighted_inputs.dtype, np.float64)
    values = np.zeros((M + 2, V), dtype=dtype)
    # np.add.at is unbuffered, so repeated vertex indices accumulate
    # (sum up duplicates) instead of overwriting each other.
    np.add.at(values, idxs, weighted_inputs)
    return values
def np_blur(inputs, values_in, offsets, weights, nbs):
    """Blur lattice values along each of the ``F + 1`` lattice directions.

    For every direction ``j`` in turn, each vertex row is incremented by
    half the sum of its two neighbours along that direction.

    Args:
        inputs: unused; accepted so all stages share one signature.
        values_in: array with ``M + 2`` rows (vertex ``m`` at row ``m + 1``).
            It is not modified.
        offsets: unused; accepted so all stages share one signature.
        weights: only its column count (``F + 1``) is used.
        nbs: integer array indexed as ``nbs[m, j, 0]`` / ``nbs[m, j, 1]``.
            These are the two neighbour vertex indices of vertex ``m`` along
            direction ``j``; ``-1`` reads row 0 after the ``+1`` shift.
            Presumably shape ``(M, F + 1, 2)``.

    Returns:
        A new array with the same shape and dtype as ``values_in``. Rows 0
        and ``M + 1`` are read but never written.
    """
    F = weights.shape[1] - 1
    values = values_in.copy()
    # NOTE (original author): "we actually ignore the last update?" --
    # unresolved. As written, the first and last rows are only read,
    # never updated; confirm this is intended.
    for j in range(F + 1):
        # Fancy indexing copies, so n1/n2 are snapshots taken before this
        # direction's update. Each direction therefore sees the result of
        # the previous one (sequential, per-direction passes).
        n1 = values[nbs[:, j, 0] + 1]
        n2 = values[nbs[:, j, 1] + 1]
        values[1:-1] += 0.5 * (n1 + n2)
    return values
def np_slice(inputs, values, offsets, weights, nbs):
    """Interpolate lattice values back to the ``N`` input points.

    Args:
        inputs: only its row count ``N`` is used.
        values: array of shape ``(M + 2, C)`` with vertex ``m`` at row
            ``m + 1``. ``C`` need not equal the width of ``inputs``; a 1-D
            array is treated as a single channel.
        offsets: integer array with at least ``N`` rows and ``F + 1``
            columns, giving each point's vertex indices.
        weights: array with at least ``N`` rows. Its column count defines
            ``F + 1``.
        nbs: unused; accepted so all stages share one signature.

    Returns:
        Array of shape ``(N, C)``. Each row is the weighted sum of the
        values at the point's ``F + 1`` vertices, scaled by
        ``1 / (1 + 2**-F)``.
    """
    N = inputs.shape[0]
    F = weights.shape[1] - 1
    # Global scale factor; presumably the permutohedral-lattice
    # normalisation of Adams et al. (2010) -- confirm against the paper.
    alpha = 1.0 / (1.0 + 2.0**(-F))
    idxs = offsets[:N, :F + 1].reshape(-1) + 1
    # Take the channel count from `values` rather than `inputs`, so the
    # sliced signal may have a different width than `inputs`.
    n_channels = values.shape[1] if values.ndim > 1 else 1
    w = weights[:N, :, np.newaxis]
    v = values[idxs].reshape((N, F + 1, n_channels))
    return np.sum(alpha * w * v, axis=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment