Compare visiting N dimensions with JAX scans vs pre-generated indices from NumPy meshgrid
import os
# use cpu for this
os.environ['CUDA_VISIBLE_DEVICES'] = ''
########################### numpy style
import numpy as np
n = 5
size = 4
basis = np.r_[:n]
bases = [basis]*size
multiplets = np.array(np.meshgrid(*bases)).reshape(size, -1)
print('shape', multiplets.shape, multiplets.nbytes >> 10, 'kB')
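# a minimal worked example of the meshgrid construction above, using n=2 and size=2
# instead of the values used in this file; each column is one multiplet (grid point)
# (the name _mini is just a throwaway for this check)
_mini = np.array(np.meshgrid([0, 1], [0, 1])).reshape(2, -1)
np.testing.assert_equal(_mini, [[0, 1, 0, 1], [0, 0, 1, 1]])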
############################ numpy style 2
#idx = np.r_[:n**size]
#mp2 = np.array(np.unravel_index(idx, (n,)*size))[::-1]
#np.testing.assert_equal(mp2, multiplets)
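# a sketch of why the commented check above is presumably left disabled: meshgrid's
# default 'xy' indexing swaps the first two axes, so the unravel_index construction
# lines up with an 'ij'-indexed meshgrid instead (mp_ij is an illustrative name)
mp2 = np.array(np.unravel_index(np.r_[:n**size], (n,)*size))
mp_ij = np.array(np.meshgrid(*bases, indexing='ij')).reshape(size, -1)
np.testing.assert_equal(mp2, mp_ij)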
############################ numba style
import numba

# in-place construction: fill each row of idx with the row-major coordinates
# of its flat index, entirely inside a compiled loop
@numba.njit
def _inplace_make_multiplets(idx, shape):
    size = shape.shape[0]
    for i in range(idx.shape[0]):
        rem = i
        # peel off coordinates from the last (fastest varying) dim to the first
        for j in range(size - 1, -1, -1):
            idx[i, j] = rem % shape[j]
            rem //= shape[j]

idx = np.zeros((n**size, size), 'i8')
_inplace_make_multiplets(idx, np.array((n,)*size))
idx = idx.T.copy()
# meshgrid's default 'xy' indexing swaps the first two axes, so reorder to match
idx[[0, 1]] = idx[[1, 0]]
assert idx.shape == multiplets.shape
np.testing.assert_equal(idx, multiplets)
############################ jax style
import jax
import jax.numpy as jp
# let's visit all elements of the n^size grid with scans
def make_op(n, size):
    multiplet_dims = jp.array([n for i in range(size)])
    def op(carry, x):
        idx, count = carry
        # count how many grid points visited
        count = count + 1
        # next idx is just unravel count w/ shape of multiplet grid
        idx = jp.array(jp.unravel_index(count, multiplet_dims))
        return (idx, count), None
    return op
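# a quick sanity check of a single op step, with n=5 and size=4 as above: the
# counter goes 0 -> 1 and unravel_index(1, (5, 5, 5, 5)) is (0, 0, 0, 1)
# (_op_check, _idx1, _c1 are throwaway names for this check)
_op_check = make_op(n, size)
(_idx1, _c1), _ = _op_check((jp.zeros(size, 'i'), 0), None)
assert _c1 == 1 and np.allclose(_idx1, jp.r_[0, 0, 0, 1])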
op = make_op(n, size)
idx = jp.zeros((size,), 'i')
carry = idx, 0
# visiting a single dim is one scan
(idx, count), _ = jax.lax.scan(op, carry, basis)
assert count == n
assert np.allclose(idx, jp.r_[0, 0, 1, 0])
(idx, count), _ = jax.lax.scan(op, (idx, count), basis)
assert count == 2*n, count
assert np.allclose(idx, jp.r_[0, 0, 2, 0])
# one scan per dim
(idx, count), _ = jax.lax.scan(
    lambda c, x: jax.lax.scan(op, c, basis),
    carry, basis)
assert count == n**2, count
assert np.allclose(idx, jp.r_[0, 1, 0, 0])
# N dim (⌐⊙_⊙)
def rec_scan(dim, op):
    assert dim > 0
    _op = op if dim == 1 else rec_scan(dim - 1, op)
    return lambda c, x: jax.lax.scan(_op, c, basis, unroll=10 if dim <= 2 else 1)
visitor = jax.jit(rec_scan(size, op))
(idx, count), _ = visitor(carry, None)
assert count == n**size, count
assert np.allclose(idx, jp.ones(size)*(n-1))
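# the last assert leans on jp.unravel_index clipping out-of-range flat indices
# (an assumption about jax semantics, not spelled out above): after n**size
# increments the counter is one past the last valid index, and clipping lands it
# on the final grid point; a direct check of that assumption:
assert np.allclose(jp.array(jp.unravel_index(n**size, (n,)*size)), jp.ones(size)*(n-1))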
############################ jax style, flat
def make_op2(n, size):
    # same op as above; only the scanned sequence below changes
    multiplet_dims = jp.array([n for i in range(size)])
    def op(carry, x):
        idx, count = carry
        # count how many grid points visited
        count = count + 1
        # next idx is just unravel count w/ shape of multiplet grid
        idx = jp.array(jp.unravel_index(count, multiplet_dims))
        return (idx, count), None
    return op
op2 = make_op2(n, size)
idx2 = jp.zeros((size,), 'i')
carry = idx2, 0
# scan once over the flat range of all n**size grid points
visitor2 = jax.jit(lambda c: jax.lax.scan(op2, c, jp.r_[:n**size], unroll=10))
(idx2, count2), _ = visitor2(carry)
assert count2 == n**size, count2
assert np.allclose(idx2, jp.ones(size)*(n-1))
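# a hedged variant of the flat scan, not used in the timings below: the scanned
# value x is already the flat index, so the carried counter isn't strictly needed
# (make_op3 is an illustrative name)
def make_op3(n, size):
    dims = tuple(n for i in range(size))
    def op(carry, x):
        idx = jp.array(jp.unravel_index(x, dims))
        return (idx, carry[1] + 1), None
    return op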
# time it
import time
for n in [5, 25]:
    for size in [3, 4, 5]:
        # generate indices
        basis = np.r_[:n]
        bases = [basis]*size
        tik = time.time()
        for i in range(100):
            multiplets = np.array(np.meshgrid(*bases)).reshape(size, -1)
        tiktok = time.time() - tik
        print(f'{n} {size} meshgrid {1e3*tiktok/100:0.3f} ms/')
        # scan visitor
        visitor = jax.jit(rec_scan(size, make_op(n, size)))
        carry = jp.zeros((size,), 'i'), 0
        visitor(carry, None)  # warm up / compile before timing
        tik = time.time()
        for i in range(100):
            (idx, count), _ = visitor(carry, None)
            assert count == multiplets.shape[1], (count, multiplets.shape)
        tiktok = time.time() - tik
        print(f'{n} {size} jax visitor {1e3*tiktok/100:0.3f} ms/')
        # scan visitor flat
        op2 = make_op2(n, size)
        idx2 = jp.zeros((size,), 'i')
        carry = idx2, 0
        visitor2 = jax.jit(lambda c: jax.lax.scan(op2, c, jp.r_[:n**size], unroll=10))
        (idx2, count2), _ = visitor2(carry)  # warm up / compile before timing
        tik = time.time()
        for i in range(100):
            (idx2, count2), _ = visitor2(carry)
            assert count2 == multiplets.shape[1], (count2, multiplets.shape)
        tiktok = time.time() - tik
        print(f'{n} {size} jax flat {1e3*tiktok/100:0.3f} ms/')
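# a hedged footnote on methodology: jax dispatch is asynchronous, so a timing loop
# that never consumes its result can under-count; the asserts on count above force
# a sync each iteration, and block_until_ready() is the explicit way to do the same
# (_idx_sync, _count_sync are throwaway names for this check)
(_idx_sync, _count_sync), _ = visitor2(carry)
_idx_sync.block_until_ready()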
On a 2018 Xeon the output shows the flat iteration has a marginal impact.