Skip to content

Instantly share code, notes, and snippets.

@bayerj
Last active May 14, 2020 09:45
Show Gist options
  • Save bayerj/509858e15aa998c43ddb40b33c56dbbc to your computer and use it in GitHub Desktop.
Save bayerj/509858e15aa998c43ddb40b33c56dbbc to your computer and use it in GitHub Desktop.
import numpy as onp
import jax.numpy as jnp
class ArrayContainer:
def __init__(self, value):
self.value = value
def __array__(self):
return self.value
container = ArrayContainer(onp.zeros((2)))
print(onp.concatenate([container, container], -1))
jax_container = ArrayContainer(jnp.zeros((2)))
print(jnp.concatenate([jax_container, jax_container], -1))
In [20]: %edit example_array_magic_method.py
Editing... done. Executing edited code...
[0. 0. 0. 0.]
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
~/.virtualenvs/jaxvrssm2/lib/python3.7/site-packages/numpy/core/fromnumeric.py in ndim(a)
3069 try:
-> 3070 return a.ndim
3071 except AttributeError:
AttributeError: 'ArrayContainer' object has no attribute 'ndim'
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
~/tmp/example_array_magic_method.py in <module>
14
15 jax_container = ArrayContainer(jnp.zeros((2)))
---> 16 print(jnp.concatenate([jax_container, jax_container], -1))
~/.virtualenvs/jaxvrssm2/lib/python3.7/site-packages/jax/numpy/lax_numpy.py in concatenate(arrays, axis)
1883 if not len(arrays):
1884 raise ValueError("Need at least one array to concatenate.")
-> 1885 if ndim(arrays[0]) == 0:
1886 raise ValueError("Zero-dimensional arrays cannot be concatenated.")
1887 axis = _canonicalize_axis(axis, ndim(arrays[0]))
<__array_function__ internals> in ndim(*args, **kwargs)
~/.virtualenvs/jaxvrssm2/lib/python3.7/site-packages/numpy/core/fromnumeric.py in ndim(a)
3070 return a.ndim
3071 except AttributeError:
-> 3072 return asarray(a).ndim
3073
3074
~/.virtualenvs/jaxvrssm2/lib/python3.7/site-packages/numpy/core/_asarray.py in asarray(a, dtype, order)
83
84 """
---> 85 return array(a, dtype, copy=False, order=order)
86
87
ValueError: object __array__ method not producing an array
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment