Skip to content

Instantly share code, notes, and snippets.

@andfoy
Created September 9, 2022 01:47
Show Gist options
  • Save andfoy/2b52426b4e0bddcb4c264b1f1e9eb4b7 to your computer and use it in GitHub Desktop.
Save andfoy/2b52426b4e0bddcb4c264b1f1e9eb4b7 to your computer and use it in GitHub Desktop.
CuPy random distributions broadcasting support
import numpy as np
import cupy as cp
import inspect
import random
# Map each public ``cupy.random`` callable that takes a ``size`` argument to
# the names of its distribution parameters: every parameter that precedes
# ``size`` (or ``seed``) in its signature.
distributions = {}
for mem_name in dir(cp.random):
    # Dunder/private attributes are module internals, not public samplers.
    if mem_name.startswith('_'):
        continue
    mem = getattr(cp.random, mem_name)
    if not callable(mem):
        continue
    try:
        params = inspect.signature(mem).parameters
    except (TypeError, ValueError):
        # Some callables (e.g. compiled extension types) expose no
        # introspectable signature; skip them instead of aborting the scan.
        continue
    # Only callables accepting ``size`` are treated as distributions.
    if 'size' not in params:
        continue
    distr_params = []
    for p in params:
        if p in {'size', 'seed'}:
            break
        distr_params.append(p)
    if distr_params:
        distributions[mem_name] = distr_params
# ``choice`` is excluded from the broadcasting test, presumably because its
# first argument is the population to sample from rather than a broadcastable
# distribution parameter. ``None`` default avoids a KeyError if it is absent.
distributions.pop('choice', None)
# For every collected distribution, draw random parameter arrays whose shapes
# broadcast to ``dim_sizes`` and check the sampler's output has that shape.
# ``distr_status[name]`` ends up as ``(True, None)`` on success, or
# ``(False, exc)`` when the call raised or produced the wrong shape.
distr_status = {d: (False, None) for d in distributions}
for distr_name, distr_args in distributions.items():
    num_dims = random.randint(2, 4)
    dim_sizes = tuple(random.randint(2, 3) for _ in range(num_dims))
    # Marker 0 means "full-shaped argument"; a marker k > 0 means an array of
    # ones-sized axes except axis k, which has the full extent. Exactly one
    # argument is full-shaped, so the broadcast result is ``dim_sizes``.
    perm_dims = [0] + [random.randint(1, num_dims - 1)
                       for _ in range(len(distr_args) - 1)]
    perm_dims = np.random.permutation(perm_dims)
    args = []
    for dim in perm_dims:
        size = dim_sizes
        if dim > 0:
            size = [1 for _ in range(num_dims)]
            size[dim] = dim_sizes[dim]
        args.append(cp.random.rand(*size))
    distr_fn = getattr(cp.random, distr_name)
    print(f'Testing {distr_name} {[(d, s.shape) for d, s in zip(distr_args, args)]}')
    try:
        out = distr_fn(*args)
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently mark every call as passing.
        if out.shape != dim_sizes:
            raise AssertionError(
                f'expected output shape {dim_sizes}, got {out.shape}')
        distr_status[distr_name] = (True, None)
    except Exception as e:
        # Deliberately broad: any failure is recorded for manual review.
        print('Needs inspection')
        distr_status[distr_name] = (False, e)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment