Skip to content

Instantly share code, notes, and snippets.

@maedoc
Last active December 11, 2023 13:22
Show Gist options
  • Save maedoc/a90915acb101cf8ab60a3e321b2ffb08 to your computer and use it in GitHub Desktop.
Save maedoc/a90915acb101cf8ab60a3e321b2ffb08 to your computer and use it in GitHub Desktop.
Compare visiting every point of an N-dimensional grid with JAX scans vs. pre-generating the indices with NumPy meshgrid
import os
# use cpu for this
os.environ['CUDA_VISIBLE_DEVICES'] = ''
########################### numpy style
import numpy as np
# Grid parameters: n points per axis and `size` axes, so n**size grid points.
n = 5
size = 4
# the coordinate values 0..n-1 along a single axis
basis = np.r_[:n]
# the same axis repeated once per dimension
bases = [basis]*size
# meshgrid returns `size` coordinate arrays; stacking them and flattening the
# grid axes gives one row per dimension and one column per grid point
multiplets = np.array(np.meshgrid(*bases)).reshape(size, -1)
# `nbytes >> 10` is an integer division by 1024, i.e. whole kibibytes
print('shape', multiplets.shape, multiplets.nbytes >> 10, 'kB')
############################ numpy style 2
#idx = np.r_[:n**size]
#mp2 = np.array(np.unravel_index(idx, (n,)*size))[::-1]
#np.testing.assert_equal(mp2, multiplets)
############################ numba style
import numba
@numba.njit
def _inplace_make_multiplets(idx, n, size, _i, _shape, _shape2):
    """Fill ``idx`` in place with the multi-index of every grid point.

    ``idx`` is viewed as a 2-D array with ``size`` columns; row ``i`` receives
    the C-order unravelling of the flat position ``i`` against the axis
    extents in ``_shape`` (the equivalent of
    ``np.unravel_index(i, _shape)``).

    Parameters
    ----------
    idx : integer array whose last axis has length ``size``; written in place.
    n : points per axis. Unused: the extents are read from ``_shape`` so that
        axes of different lengths also work. Kept for signature compatibility.
    size : number of grid dimensions (length of each multi-index).
    _i, _shape2 : scratch buffers from the original signature; unused, kept so
        existing call sites do not change.
    _shape : 1-D integer array with the extent of each of the ``size`` axes.

    Returns
    -------
    None; the result is stored in ``idx``.
    """
    # NOTE(review): the writes below reach `idx` only if reshape returns a
    # view, which should hold for a contiguous array - confirm for the caller.
    idx_ = idx.reshape((-1, size))
    for i in range(idx_.shape[0]):
        # peel off one axis at a time, last axis first (it varies fastest)
        remainder = i
        for j in range(size - 1, -1, -1):
            extent = _shape[j]
            idx_[i, j] = remainder % extent
            remainder //= extent
# Disabled driver for the numba variant: it sits inside a bare string literal,
# so none of it is executed. It would not run as written if re-enabled: the
# call passes three arguments to a six-parameter function, and the trailing
# 1/0 raises ZeroDivisionError.
"""
idx = np.zeros((n,)*size + (size,), 'i')
_shape = np.array(idx.shape[:-1])
_shape2 = np.array(idx.shape[:-1])
_i = np.zeros(size, 'i')
_inplace_make_multiplets(idx, n, size)
idx = idx.T[::-1]
assert idx.shape == multiplets.shape
np.testing.assert_equal(idx, multiplets)
1/0
"""
############################ jax style
import jax
import jax.numpy as jp
# let's visit all elements of the n^size grid with scans
def make_op(n, size):
    """Build the scan step that walks a counter over an ``n**size`` grid.

    The returned ``op(carry, x)`` expects ``carry = (idx, count)``. It ignores
    both the incoming ``idx`` and the scanned element ``x``, increments
    ``count``, and returns ``((new_idx, new_count), None)`` where ``new_idx``
    is ``new_count`` unravelled against a grid of ``size`` axes of length
    ``n``.
    """
    # extent of every axis of the multiplet grid
    grid_shape = jp.array([n] * size)

    def op(carry, x):
        # the previous index is not needed: the counter alone fixes the next one
        _previous, visited = carry
        visited = visited + 1
        # unravel the running count w/ the shape of the multiplet grid
        position = jp.array(jp.unravel_index(visited, grid_shape))
        return (position, visited), None

    return op
# Step function for the n**size grid, plus the initial carry (index, counter).
op = make_op(n, size)
idx = jp.zeros((size,), 'i')
carry = idx, 0
# visiting a single dim is one scan: basis has n entries, so op runs n times
(idx, count), _ = jax.lax.scan(op, carry, basis)
assert count == n
# with n == 5 and size == 4, flat position 5 unravels to [0, 0, 1, 0]
assert np.allclose(idx, jp.r_[0, 0, 1, 0])
# a second scan resumes from the carry returned by the first one
(idx, count), _ = jax.lax.scan(op, (idx,count), basis)
assert count == 2*n, count
assert np.allclose(idx, jp.r_[0, 0, 2, 0])
# one scan per dim: the outer scan's step is itself a scan over basis, so op
# runs n*n times starting again from the initial carry
(idx, count), _ = jax.lax.scan(
lambda c,x: jax.lax.scan(op, c, basis),
carry, basis)
assert count == n**2, count
assert np.allclose(idx, jp.r_[0, 1, 0, 0])
# N dim (⌐⊙_⊙)
def rec_scan(dim, op):
    """Wrap ``op`` in ``dim`` nested ``jax.lax.scan`` loops over ``basis``.

    Returns a function ``(c, x) -> jax.lax.scan(...)`` that can itself serve
    as a scan step, which is what makes the nesting possible. ``x`` is
    ignored. ``basis`` is the module-level array, looked up when the returned
    function is called. The two innermost levels scan with ``unroll=10``,
    every outer level with ``unroll=1``.

    Raises ``AssertionError`` when ``dim`` is not positive.
    """
    assert dim > 0
    # innermost level scans the step function itself, outer levels scan the
    # next level down
    if dim == 1:
        _op = op
    else:
        _op = rec_scan(dim - 1, op)

    def scan_level(c, x):
        if dim <= 2:
            unroll = 10
        else:
            unroll = 1
        return jax.lax.scan(_op, c, basis, unroll=unroll)

    return scan_level
# Compile `size` nested scans into one function and run it from the initial
# carry; the scanned element argument is unused, hence None.
visitor = jax.jit(rec_scan(size, op))
(idx, count), _ = visitor(carry, None)
# every grid point was visited exactly once
assert count == n**size, count
# the final count n**size is one past the last flat position; the assert shows
# the index then reads n-1 on every axis (presumably jp.unravel_index clips
# out-of-range positions - confirm against the JAX docs)
assert np.allclose(idx, jp.ones(size)*(n-1))
############################ jax style, flat
def make_op2(n, size):
    """Return the step function used by the flat (single-scan) visitor.

    ``op(carry, x)`` unpacks ``carry`` as ``(idx, count)``, discards the old
    ``idx`` and the scanned value ``x``, and yields
    ``((unravelled(count + 1), count + 1), None)``, where the unravelling is
    done against ``size`` axes of ``n`` points each.
    """
    # shape of the multiplet grid: `size` axes, n points along each
    dims = jp.array([n] * size)

    def op(carry, x):
        _, steps_taken = carry
        # count how many grid points have been visited so far
        steps_taken = steps_taken + 1
        # the index follows directly from the count
        coords = jp.unravel_index(steps_taken, dims)
        return (jp.array(coords), steps_taken), None

    return op
# Flat variant: one scan over all n**size flat positions instead of one
# nested scan per dimension.
op2 = make_op2(n, size)
idx2 = jp.zeros((size,), 'i')
# start from the freshly zeroed idx2, not the `idx` left over from the nested
# run above (the step function overwrites the index, so the result is the same)
carry = idx2, 0
visitor2 = jax.jit(lambda c: jax.lax.scan(op2, c, jp.r_[:n**size], unroll=10))
(idx2, count2), _ = visitor2(carry)
# on failure report the flat counter, not the nested run's `count`
assert count2 == n**size, count2
assert np.allclose(idx2, jp.ones(size)*(n-1))
# time it
import time
# Benchmark the three approaches; each figure printed is the mean wall-clock
# time per call, in milliseconds, over `reps` repetitions.
reps = 100
for n in [5, 25]:
    for size in [3, 4, 5]:
        # generate indices
        # n, size and basis are rebound at module level here; rec_scan reads
        # the module-level `basis` when its scans are traced.
        basis = np.r_[:n]
        bases = [basis]*size
        # perf_counter is the clock meant for measuring short durations
        tik = time.perf_counter()
        for i in range(reps):
            multiplets = np.array(np.meshgrid(*bases)).reshape(size, -1)
        tiktok = time.perf_counter() - tik
        print(f'{n} {size} meshgrid {1e3*tiktok/reps:0.3f} ms/')

        # scan visitor
        visitor = jax.jit(rec_scan(size, make_op(n, size)))
        carry = jp.zeros((size, ), 'i'), 0
        # warm-up call so compilation is not part of the timed loop
        visitor(carry, None)
        tik = time.perf_counter()
        for i in range(reps):
            (idx, count), _ = visitor(carry, None)
            # evaluating the comparison needs the computed value, so each
            # call has finished before the next one starts
            assert count == multiplets.shape[1], (count, multiplets.shape)
        tiktok = time.perf_counter() - tik
        print(f'{n} {size} jax visitor {1e3*tiktok/reps:0.3f} ms/')

        # scan visitor flat
        op2 = make_op2(n, size)
        idx2 = jp.zeros((size,), 'i')
        # use the freshly zeroed idx2 rather than the `idx` left over from
        # the nested visitor above
        carry = idx2, 0
        visitor2 = jax.jit(lambda c: jax.lax.scan(op2, c, jp.r_[:n**size], unroll=10))
        # warm-up call, as above
        (idx2, count2), _ = visitor2(carry)
        tik = time.perf_counter()
        for i in range(reps):
            (idx2, count2), _ = visitor2(carry)
            assert count2 == multiplets.shape[1], (count2, multiplets.shape)
        tiktok = time.perf_counter() - tik
        print(f'{n} {size} jax flat {1e3*tiktok/reps:0.3f} ms/')
@maedoc
Copy link
Author

maedoc commented Dec 8, 2023

Output on a 2018 Xeon (columns: n, size, method, mean time per call in ms over 100 runs):

5 3 meshgrid 0.026 ms/
5 3 jax visitor 0.023 ms/
5 3 jax flat 0.024 ms/
5 4 meshgrid 0.040 ms/
5 4 jax visitor 0.024 ms/
5 4 jax flat 0.026 ms/
5 5 meshgrid 0.053 ms/
5 5 jax visitor 0.026 ms/
5 5 jax flat 0.026 ms/
25 3 meshgrid 0.056 ms/
25 3 jax visitor 0.025 ms/
25 3 jax flat 0.026 ms/
25 4 meshgrid 2.023 ms/
25 4 jax visitor 0.040 ms/
25 4 jax flat 0.083 ms/
25 5 meshgrid 162.546 ms/
25 5 jax visitor 0.488 ms/
25 5 jax flat 1.495 ms/

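Flattening the nested scans into a single scan has only a marginal impact: the timings match on the small grids, and the flat version is at most about 3x slower (n=25, size=5), which is small next to the gap to meshgrid.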

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment