from functools import partial

import jax
import jax.numpy as np
import numpy as onp

def slice_in_dim(operand, start_index, limit_index, stride=1, axis=0):
    """Convenience wrapper around `jax.lax.slice` applying to only one dimension."""
    start_indices = [0] * operand.ndim
    limit_indices = list(operand.shape)
    strides = [1] * operand.ndim

    # translate `None` to explicit bounds
    len_axis = operand.shape[axis]
    start_index = start_index if start_index is not None else 0
    limit_index = limit_index if limit_index is not None else len_axis

    # translate negative indices
    if start_index < 0:
        start_index = start_index + len_axis
    if limit_index < 0:
        limit_index = limit_index + len_axis

    axis = int(axis)
    start_indices[axis] = int(start_index)
    limit_indices[axis] = int(limit_index)
    strides[axis] = int(stride)

    return jax.lax.slice(operand, start_indices, limit_indices, strides)
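
# Sanity check (a sketch, not part of the original gist): the wrapper should
# agree with NumPy basic slicing, including `None` bounds and negative
# indices. `_demo` is a hypothetical array used only for this check.
_demo = onp.arange(24).reshape(2, 3, 4)
onp.testing.assert_allclose(
    slice_in_dim(np.asarray(_demo), 1, -1, axis=2),
    _demo[:, :, 1:-1],
)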

@partial(jax.jit, static_argnums=1)
def gradient_along_axis_swapaxes(a, axis):
    # Move `axis` to the front so plain indexing addresses it.
    a_swap = np.swapaxes(a, 0, axis)
    a_grad = np.concatenate((
        (a_swap[1] - a_swap[0])[np.newaxis],      # one-sided difference at the left edge
        (a_swap[2:] - a_swap[:-2]) * 0.5,         # central differences in the interior
        (a_swap[-1] - a_swap[-2])[np.newaxis],    # one-sided difference at the right edge
    ), axis=0)
    return np.swapaxes(a_grad, 0, axis)
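
# Worked example (a sketch, not part of the original gist): on [0, 1, 4, 9]
# the one-sided edge differences give 1 and 5, and the central differences
# give (4-0)/2 = 2 and (9-1)/2 = 4, matching onp.gradient.
onp.testing.assert_allclose(
    gradient_along_axis_swapaxes(np.array([0., 1., 4., 9.]), 0),
    onp.gradient(onp.array([0., 1., 4., 9.])),
)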

@partial(jax.jit, static_argnums=1)
def gradient_along_axis_sliced(a, axis):
    # Bind the operand and axis once; `sliced(i, j)` then reads a[i:j] along `axis`.
    sliced = partial(slice_in_dim, a, axis=axis)
    a_grad = np.concatenate((
        sliced(1, 2) - sliced(0, 1),              # left edge
        (sliced(2, None) - sliced(0, -2)) * 0.5,  # interior
        sliced(-1, None) - sliced(-2, -1),        # right edge
    ), axis)
    return a_grad
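
# Cross-check (a sketch, not part of the original gist): the two per-axis
# variants compute the same arithmetic and should agree. `_small` and `_ax`
# are hypothetical names introduced only for this check.
_small = jax.random.normal(jax.random.PRNGKey(1), (4, 5))
for _ax in range(_small.ndim):
    onp.testing.assert_allclose(
        gradient_along_axis_swapaxes(_small, _ax),
        gradient_along_axis_sliced(_small, _ax),
    )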

@jax.jit
def gradient_swap(a):
    a_grad = [gradient_along_axis_swapaxes(a, ax) for ax in range(a.ndim)]
    return a_grad

@jax.jit
def gradient_sliced(a):
    a_grad = [gradient_along_axis_sliced(a, ax) for ax in range(a.ndim)]
    return a_grad
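
# Shape check (a small sketch, not part of the original gist): each driver
# returns one gradient array per axis, each with the input's shape. `_g` is
# a hypothetical name used only here.
_g = gradient_sliced(np.ones((3, 4)))
assert len(_g) == 2 and all(g_i.shape == (3, 4) for g_i in _g)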

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (100, 100, 100))

onp.testing.assert_allclose(onp.gradient(x), gradient_swap(x))
onp.testing.assert_allclose(onp.gradient(x), gradient_sliced(x))

%timeit jax.device_get(gradient_swap(x))
%timeit jax.device_get(gradient_sliced(x))

# CPU (i7 8550U)
# 14.3 ms ± 2.2 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
# 8.89 ms ± 274 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# GPU (GTX 1080 Ti)
# 6.35 ms ± 1.24 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
# 4.69 ms ± 54 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
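
# Timing alternative (a sketch, not from the original gist): blocking on the
# device buffers instead of fetching them excludes the host transfer from the
# measurement.
%timeit [g.block_until_ready() for g in gradient_swap(x)]
%timeit [g.block_until_ready() for g in gradient_sliced(x)]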