Skip to content

Instantly share code, notes, and snippets.

@mattjj
Last active February 14, 2019 01:19
Show Gist options
Save mattjj/12f2745a7bcfe2a4393a24c7d76be9e8 to your computer and use it in GitHub Desktop.
### numpy version
# NumPy arrays are mutable, so slice assignment writes into ``x`` in place.
import numpy as onp

x = onp.zeros((10, 2))
x[3:5] = 5.  # rows 3 and 4, every column
print(x)
# [[0. 0.]
#  [0. 0.]
#  [0. 0.]
#  [5. 5.]
#  [5. 5.]
#  [0. 0.]
#  [0. 0.]
#  [0. 0.]
#  [0. 0.]
#  [0. 0.]]
x[7:10, 1:2] = 9.  # rows 7, 8, 9 of column 1 (a 3x1 slice)
print(x)
# [[0. 0.]
#  [0. 0.]
#  [0. 0.]
#  [5. 5.]
#  [5. 5.]
#  [0. 0.]
#  [0. 0.]
#  [0. 9.]
#  [0. 9.]
#  [0. 9.]]
### jax/lax version
# JAX arrays are immutable: instead of assigning into a slice, build a new
# array with lax.dynamic_update_slice and rebind the name ``x`` to it.
import jax.numpy as np
from jax import lax

x = np.zeros((10, 2))

# this part is like x[3:5] = 5.
# x[3:5] is a 2x2 slice, so the update is 2x2; the start indices [3, 0]
# give the position of the update's top-left corner inside ``x``.
update = np.array([[5., 5.],
                   [5., 5.]])
x = lax.dynamic_update_slice(x, update, [3, 0])
print(x)
# [[0. 0.]
#  [0. 0.]
#  [0. 0.]
#  [5. 5.]
#  [5. 5.]
#  [0. 0.]
#  [0. 0.]
#  [0. 0.]
#  [0. 0.]
#  [0. 0.]]

# this part is like x[7:10, 1:2] = 9.
# x[7:10, 1:2] covers 3 rows x 1 column, so the update must be 3x1.
# (A 2x1 update would only write rows 7 and 8 and leave row 9 at 0.)
update = np.array([[9.],
                   [9.],
                   [9.]])
x = lax.dynamic_update_slice(x, update, [7, 1])
print(x)
# [[0. 0.]
#  [0. 0.]
#  [0. 0.]
#  [5. 5.]
#  [5. 5.]
#  [0. 0.]
#  [0. 0.]
#  [0. 9.]
#  [0. 9.]
#  [0. 9.]]
# jit works too
from jax import jit

x = np.zeros((10, 2))


@jit
def first_update(x, start, val):
    """Return a new array equal to ``x`` with a 2x2 block set to ``val``.

    The block covers rows ``start`` and ``start + 1`` and columns 0 and 1;
    this is the jitted analogue of ``x[start:start+2, 0:2] = val``.

    Args:
        x: Array to update; it is not modified (JAX arrays are immutable).
        start: Row index of the block's top-left corner. Under ``jit`` it is
            a traced value, which ``lax.dynamic_update_slice`` accepts as a
            start index.
        val: Scalar fill value for the block.

    Returns:
        The updated array, with the same shape as ``x``.
    """
    update = np.array([[val, val],
                       [val, val]])
    return lax.dynamic_update_slice(x, update, [start, 0])


x = first_update(x, 3, 5.)
print(x)
# [[0. 0.]
#  [0. 0.]
#  [0. 0.]
#  [5. 5.]
#  [5. 5.]
#  [0. 0.]
#  [0. 0.]
#  [0. 0.]
#  [0. 0.]
#  [0. 0.]]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment