Note: related JAX tests. Each cell below builds a jnp.ones array whose size comes from a traced argument, indexes it, traces with jax.make_jaxpr, lowers to MLIR via jaxpr_to_mlir / inject_functions, and compiles and runs it with qjit.
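Note: assumed setup for the cells below (a sketch, not from the log). The catalyst import paths are inferred from the tracebacks further down, and enabling jax_dynamic_shapes is an assumption based on the size-polymorphic jaxprs (f32[a]).
import jax
import jax.numpy as jnp
from catalyst import qjit
from catalyst.utils.jax_extras import jaxpr_to_mlir    # assumed path, per jax_extras.py in the tracebacks
from catalyst.utils.gen_mlir import inject_functions   # assumed path, per gen_mlir.py in the tracebacks
jax.config.update("jax_dynamic_shapes", True)          # assumed: needed to trace f32[a] shapes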
# 1-D array of dynamic size sz, static index o[0]
func, mlir, jaxpr = None, None, None
def func(sz:int):
o = jnp.ones(sz, jnp.float32)
return o[0]
jaxpr = jax.make_jaxpr(func)(3)
print(jaxpr)
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)
inject_functions(mlir, ctx)
print(mlir)
qjit(str(mlir))(3)
{ lambda ; a:i64[]. let
b:f32[a] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 a
_:i64[] = convert_element_type[new_dtype=int64 weak_type=False] a
c:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
d:f32[] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1,)
unique_indices=True
] b c
in (d,) }
module @test {
func.func public @catalyst.entry_point(%arg0: tensor<i64>) -> tensor<f32> attributes {llvm.emit_c_interface} {
%0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%1 = stablehlo.convert %arg0 : (tensor<i64>) -> tensor<i32>
%2 = stablehlo.reshape %1 : (tensor<i32>) -> tensor<1xi32>
%3 = stablehlo.dynamic_broadcast_in_dim %0, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
%4 = stablehlo.convert %arg0 : tensor<i64>
%5 = stablehlo.constant dense<0> : tensor<i64>
%6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor<i64>) -> tensor<1xi64>
%7 = "stablehlo.gather"(%3, %6) {dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xi64>) -> tensor<f32>
return %7 : tensor<f32>
}
func.func @setup() {
"quantum.init"() : () -> ()
return
}
func.func @teardown() {
"quantum.finalize"() : () -> ()
return
}
}
array(4.68817794e-310)
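Note: the expected value here is 1.0 (o is jnp.ones), so the denormal 4.68e-310 suggests the compiled gather reads uninitialized memory instead of the broadcast constant. Plain-JAX check of the expected value (a sketch, not from the log):
print(jnp.ones(3, jnp.float32)[0])   # 1.0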
# 2-D array of dynamic size (sz, sz), dynamic row index o[idx, 0]
func, mlir, jaxpr = None, None, None
def func(sz:int, idx:int):
o = jnp.ones((sz,sz), jnp.float32)
return o[idx,0]
jaxpr = jax.make_jaxpr(func)(3,0)
print(jaxpr)
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)
inject_functions(mlir, ctx)
print(mlir)
qjit(str(mlir))(3,0)
{ lambda ; a:i64[] b:i64[]. let
c:f32[a,a] = broadcast_in_dim[broadcast_dimensions=() shape=(None, None)] 1.0
a a
d:i64[] = convert_element_type[new_dtype=int64 weak_type=False] a
e:bool[] = lt b 0
f:i64[] = add b d
g:i64[] = select_n e b f
_:i64[] = convert_element_type[new_dtype=int64 weak_type=False] a
h:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] g
i:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
j:i64[2] = concatenate[dimension=0] h i
k:f32[] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=True
] c j
in (k,) }
module @test {
func.func public @catalyst.entry_point(%arg0: tensor<i64>, %arg1: tensor<i64>) -> tensor<f32> attributes {llvm.emit_c_interface} {
%0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%1 = stablehlo.convert %arg0 : (tensor<i64>) -> tensor<i32>
%2 = stablehlo.reshape %1 : (tensor<i32>) -> tensor<1xi32>
%3 = stablehlo.convert %arg0 : (tensor<i64>) -> tensor<i32>
%4 = stablehlo.reshape %3 : (tensor<i32>) -> tensor<1xi32>
%5 = stablehlo.concatenate %2, %4, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
%6 = stablehlo.dynamic_broadcast_in_dim %0, %5, dims = [] : (tensor<f32>, tensor<2xi32>) -> tensor<?x?xf32>
%7 = stablehlo.convert %arg0 : tensor<i64>
%8 = stablehlo.constant dense<0> : tensor<i64>
%9 = stablehlo.compare LT, %arg1, %8, SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1>
%10 = stablehlo.add %arg1, %7 : tensor<i64>
%11 = stablehlo.select %9, %10, %arg1 : tensor<i1>, tensor<i64>
%12 = stablehlo.convert %arg0 : tensor<i64>
%13 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor<i64>) -> tensor<1xi64>
%14 = stablehlo.constant dense<0> : tensor<i64>
%15 = stablehlo.broadcast_in_dim %14, dims = [] : (tensor<i64>) -> tensor<1xi64>
%16 = stablehlo.concatenate %13, %15, dim = 0 : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
%17 = "stablehlo.gather"(%6, %16) {dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xi64>) -> tensor<f32>
return %17 : tensor<f32>
}
func.func @setup() {
"quantum.init"() : () -> ()
return
}
func.func @teardown() {
"quantum.finalize"() : () -> ()
return
}
}
array(5.26354425e-315)
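Note: same symptom as the 1-D case; the expected value is again 1.0 (jnp.ones((3, 3))[0, 0]), but the compiled result is the denormal 5.26e-315.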
# 1-D array of dynamic size sz, static slice o[0:2]
func, mlir, jaxpr = None, None, None
def func(sz:int):
o = jnp.ones(sz, jnp.float32)
return o[0:2]
jaxpr = jax.make_jaxpr(func)(3)
print(jaxpr)
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)
inject_functions(mlir, ctx)
print(mlir)
qjit(str(mlir))(3)
Traceback (most recent call last):
Cell In[72], line 1
jaxpr = jax.make_jaxpr(func)(3)
File /workspace/modules/jax/jax/_src/traceback_util.py:177 in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File /workspace/modules/jax/jax/_src/api.py:2462 in make_jaxpr_f
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(
File /workspace/modules/jax/jax/_src/profiler.py:340 in wrapper
return func(*args, **kwargs)
File /workspace/modules/jax/jax/_src/interpreters/partial_eval.py:2239 in trace_to_jaxpr_dynamic2
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
File /workspace/modules/jax/jax/_src/interpreters/partial_eval.py:2254 in trace_to_subjaxpr_dynamic2
ans = fun.call_wrapped(*in_tracers_)
File /workspace/modules/jax/jax/_src/linear_util.py:191 in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
Cell In[71], line 3 in func
return o[0:2]
File /workspace/modules/jax/jax/_src/numpy/array_methods.py:728 in op
return getattr(self.aval, f"_{name}")(self, *args)
File /workspace/modules/jax/jax/_src/numpy/array_methods.py:341 in _getitem
return lax_numpy._rewriting_take(self, item)
File /workspace/modules/jax/jax/_src/numpy/lax_numpy.py:4323 in _rewriting_take
return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
File /workspace/modules/jax/jax/_src/numpy/lax_numpy.py:4332 in _gather
indexer = _index_to_gather(shape(arr), idx) # shared with _scatter_update
File /workspace/modules/jax/jax/_src/numpy/lax_numpy.py:4589 in _index_to_gather
raise IndexError(msg)
IndexError: Cannot use NumPy slice indexing on an array dimension whose size is not statically known (Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>). Try using lax.dynamic_slice/dynamic_update_slice
None
Traceback (most recent call last):
Cell In[74], line 1
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)
AttributeError: 'NoneType' object has no attribute 'out_avals'
Traceback (most recent call last):
Cell In[75], line 1
inject_functions(mlir, ctx)
File /workspace/modules/catalyst/frontend/catalyst/utils/gen_mlir.py:58 in inject_functions
module.body.operations[0].attributes["llvm.emit_c_interface"] = ir.UnitAttr.get(context=ctx)
AttributeError: 'NoneType' object has no attribute 'body'
None
Traceback (most recent call last):
File /workspace/modules/catalyst/frontend/catalyst/compiler.py:379 in run_from_ir
compiler_output = run_compiler_driver(
RuntimeError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
None
^
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
None
^
Failed to parse module as LLVM source
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
Cell In[77], line 1
qjit(str(mlir))(3)
File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:666 in __call__
function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:641 in _ensure_real_arguments_and_formal_parameters_are_compatible
function = self.compile()
File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:574 in compile
shared_object, llvm_ir, inferred_func_data = self.compiler.run_from_ir(
File /workspace/modules/catalyst/frontend/catalyst/compiler.py:389 in run_from_ir
raise CompileError(*e.args) from e
CompileError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
None
^
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
None
^
Failed to parse module as LLVM source
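Note: the IndexError above suggests lax.dynamic_slice. A possible rewrite of the o[0:2] case along those lines (a sketch; the slice size 2 is static here, but whether it traces and lowers cleanly with dynamic shapes enabled is not verified in this log):
from jax import lax
def func(sz: int):
    o = jnp.ones(sz, jnp.float32)
    return lax.dynamic_slice(o, (0,), (2,))   # dynamic operand, static slice size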
# 2-D array of dynamic size (sz, sz), static row index o[0]
func, mlir, jaxpr = None, None, None
def func(sz:int):
o = jnp.ones((sz,sz), jnp.float32)
return o[0]
jaxpr = jax.make_jaxpr(func)(3)
print(jaxpr)
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)
inject_functions(mlir, ctx)
print(mlir)
qjit(str(mlir))(3)
{ lambda ; a:i64[]. let
b:f32[a,a] = broadcast_in_dim[broadcast_dimensions=() shape=(None, None)] 1.0
a a
_:i64[] = convert_element_type[new_dtype=int64 weak_type=False] a
c:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
d:f32[a] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(0,), start_index_map=(0,))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
unique_indices=True
] b c
e:f32[a] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(None,)] d a
in (e,) }
Traceback (most recent call last):
File /workspace/modules/jax/jax/_src/core.py:744 in __getattr__
attr = getattr(self.aval, name)
AttributeError: 'ShapedArray' object has no attribute 'type'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File /opt/venv/bin/ipython3:8
sys.exit(start_ipython())
File /opt/venv/lib/python3.10/site-packages/IPython/__init__.py:128 in start_ipython
return launch_new_instance(argv=argv, **kwargs)
File /opt/venv/lib/python3.10/site-packages/traitlets/config/application.py:1043 in launch_instance
app.start()
File /opt/venv/lib/python3.10/site-packages/IPython/terminal/ipapp.py:318 in start
self.shell.mainloop()
File /opt/venv/lib/python3.10/site-packages/IPython/terminal/interactiveshell.py:888 in mainloop
self.interact()
File /opt/venv/lib/python3.10/site-packages/IPython/terminal/interactiveshell.py:881 in interact
self.run_cell(code, store_history=True)
File /opt/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3006 in run_cell
result = self._run_cell(
File /opt/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3061 in _run_cell
result = runner(coro)
File /opt/venv/lib/python3.10/site-packages/IPython/core/async_helpers.py:129 in _pseudo_sync_runner
coro.send(None)
File /opt/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3266 in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File /opt/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3445 in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File /opt/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3505 in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
Cell In[82], line 1
jaxpr = jax.make_jaxpr(func)(3)
Cell In[81], line 3 in func
return o[0]
File /workspace/modules/jax/jax/_src/numpy/array_methods.py:728 in op
return getattr(self.aval, f"_{name}")(self, *args)
File /workspace/modules/jax/jax/_src/numpy/array_methods.py:341 in _getitem
return lax_numpy._rewriting_take(self, item)
File /workspace/modules/jax/jax/_src/numpy/lax_numpy.py:4323 in _rewriting_take
return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
File /workspace/modules/jax/jax/_src/numpy/lax_numpy.py:4350 in _gather
y = lax.gather(
JaxStackTraceBeforeTransformation: AttributeError: DynamicJaxprTracer has no attribute type
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
Cell In[84], line 1
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)
File /workspace/modules/catalyst/frontend/catalyst/utils/jax_extras.py:299 in jaxpr_to_mlir
module, context = custom_lower_jaxpr_to_module(
File /workspace/modules/catalyst/frontend/catalyst/utils/jax_extras.py:367 in custom_lower_jaxpr_to_module
lower_jaxpr_to_fun(
File /workspace/modules/jax/jax/_src/interpreters/mlir.py:1216 in lower_jaxpr_to_fun
out_vals, tokens_out = jaxpr_subcomp(
File /workspace/modules/jax/jax/_src/interpreters/mlir.py:1433 in jaxpr_subcomp
ans = rule(rule_ctx, *rule_inputs, **eqn.params)
File /workspace/modules/jax/jax/_src/lax/slicing.py:1827 in _gather_lower
slice_sizes = mlir.eval_dynamic_shape_as_tensor(ctx, slice_sizes)
File /workspace/modules/jax/jax/_src/interpreters/mlir.py:672 in eval_dynamic_shape_as_tensor
return shape_tensor(eval_dynamic_shape(ctx, shape))
File /workspace/modules/jax/jax/_src/interpreters/mlir.py:96 in shape_tensor
ds = map(lower_dim, sizes)
File /workspace/modules/jax/jax/_src/interpreters/mlir.py:93 in lower_dim
if d.type != i32_type:
File /workspace/modules/jax/jax/_src/core.py:746 in __getattr__
raise AttributeError(
AttributeError: DynamicJaxprTracer has no attribute type
Traceback (most recent call last):
Cell In[85], line 1
inject_functions(mlir, ctx)
File /workspace/modules/catalyst/frontend/catalyst/utils/gen_mlir.py:58 in inject_functions
module.body.operations[0].attributes["llvm.emit_c_interface"] = ir.UnitAttr.get(context=ctx)
AttributeError: 'NoneType' object has no attribute 'body'
None
Traceback (most recent call last):
File /workspace/modules/catalyst/frontend/catalyst/compiler.py:379 in run_from_ir
compiler_output = run_compiler_driver(
RuntimeError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
None
^
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
None
^
Failed to parse module as LLVM source
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
Cell In[87], line 1
qjit(str(mlir))(3)
File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:666 in __call__
function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:641 in _ensure_real_arguments_and_formal_parameters_are_compatible
function = self.compile()
File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:574 in compile
shared_object, llvm_ir, inferred_func_data = self.compiler.run_from_ir(
File /workspace/modules/catalyst/frontend/catalyst/compiler.py:389 in run_from_ir
raise CompileError(*e.args) from e
CompileError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
None
^
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
None
^
Failed to parse module as LLVM source
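Note: this failure happens before any MLIR is produced; the gather equation's slice_sizes already contains a tracer (slice_sizes=(1, Traced<...>)), and _gather_lower then calls eval_dynamic_shape_as_tensor on it, which expects IR values and fails on d.type. A possible workaround sketch using dynamic_slice_in_dim plus squeeze instead of basic indexing (an assumption; not verified against this build):
from jax import lax
def func(sz: int):
    o = jnp.ones((sz, sz), jnp.float32)
    row = lax.dynamic_slice_in_dim(o, 0, 1, axis=0)   # shape (1, sz)
    return jnp.squeeze(row, axis=0)                   # shape (sz,)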
# 2-D array of dynamic size (sz, sz), full slice o[:]
func, mlir, jaxpr = None, None, None
def func(sz:int):
o = jnp.ones((sz,sz), jnp.float32)
return o[:]
jaxpr = jax.make_jaxpr(func)(3)
print(jaxpr)
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)
inject_functions(mlir, ctx)
print(mlir)
qjit(str(mlir))(3)
{ lambda ; a:i64[]. let
b:f32[a] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 a
c:f32[a] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(None,)] b a
in (c,) }
module @test {
func.func public @catalyst.entry_point(%arg0: tensor<i64>) -> tensor<?xf32> attributes {llvm.emit_c_interface} {
%0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%1 = stablehlo.convert %arg0 : (tensor<i64>) -> tensor<i32>
%2 = stablehlo.reshape %1 : (tensor<i32>) -> tensor<1xi32>
%3 = stablehlo.dynamic_broadcast_in_dim %0, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
%4 = stablehlo.convert %arg0 : (tensor<i64>) -> tensor<i32>
%5 = stablehlo.reshape %4 : (tensor<i32>) -> tensor<1xi32>
%6 = stablehlo.dynamic_broadcast_in_dim %3, %5, dims = [0] : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
return %6 : tensor<?xf32>
}
func.func @setup() {
"quantum.init"() : () -> ()
return
}
func.func @teardown() {
"quantum.finalize"() : () -> ()
return
}
}
Traceback (most recent call last):
File /workspace/modules/catalyst/frontend/catalyst/compiler.py:379 in run_from_ir
compiler_output = run_compiler_driver(
RuntimeError: Compilation failed:
catalyst_module:9:10: error: failed to legalize operation 'mhlo.dynamic_broadcast_in_dim'
%6 = stablehlo.dynamic_broadcast_in_dim %3, %5, dims = [0] : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
^
catalyst_module:9:10: note: see current operation: %17 = "mhlo.dynamic_broadcast_in_dim"(%11, %16) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
Invalid pass with name 'FinalizingBufferize' failed
While processing 'mlir::detail::OpToOpPassAdaptor' pass of the 'BufferizationPass' pipeline
Failed to lower MLIR module
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
Cell In[57], line 1
qjit(str(mlir))(3)
File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:666 in __call__
function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:641 in _ensure_real_arguments_and_formal_parameters_are_compatible
function = self.compile()
File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:574 in compile
shared_object, llvm_ir, inferred_func_data = self.compiler.run_from_ir(
File /workspace/modules/catalyst/frontend/catalyst/compiler.py:389 in run_from_ir
raise CompileError(*e.args) from e
CompileError: Compilation failed:
catalyst_module:9:10: error: failed to legalize operation 'mhlo.dynamic_broadcast_in_dim'
%6 = stablehlo.dynamic_broadcast_in_dim %3, %5, dims = [0] : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
^
catalyst_module:9:10: note: see current operation: %17 = "mhlo.dynamic_broadcast_in_dim"(%11, %16) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
Invalid pass with name 'FinalizingBufferize' failed
While processing 'mlir::detail::OpToOpPassAdaptor' pass of the 'BufferizationPass' pipeline
Failed to lower MLIR module
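Note: in this case tracing and MLIR generation succeed; the failure is in Catalyst's bufferization pipeline, where the mhlo.dynamic_broadcast_in_dim producing the tensor<?xf32> result cannot be legalized (FinalizingBufferize). Also note the printed jaxpr is one-dimensional (f32[a]), which does not match the two-dimensional func above; the printout may come from an earlier 1-D variant of the cell (the traceback references In[57]).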
# 3-D array of dynamic size (sz, sz, sz), mixed index o[0, :, 0]
func, mlir, jaxpr = None, None, None
def func(sz:int):
o = jnp.ones((sz,sz,sz), jnp.float32)
return o[0,:,0]
jaxpr = jax.make_jaxpr(func)(3)
print(jaxpr)
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)
inject_functions(mlir, ctx)
print(mlir)
qjit(str(mlir))(3)
{ lambda ; a:i64[]. let
b:f32[a,a,a] = broadcast_in_dim[
broadcast_dimensions=()
shape=(None, None, None)
] 1.0 a a a
_:i64[] = convert_element_type[new_dtype=int64 weak_type=False] a
_:i64[] = convert_element_type[new_dtype=int64 weak_type=False] a
c:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
d:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
e:i64[2] = concatenate[dimension=0] c d
f:f32[a] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(0, 2), start_index_map=(0, 2))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 1)
unique_indices=True
] b e
g:f32[a] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(None,)] f a
in (g,) }
Traceback (most recent call last):
File /workspace/modules/jax/jax/_src/core.py:744 in __getattr__
attr = getattr(self.aval, name)
AttributeError: 'ShapedArray' object has no attribute 'type'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File /opt/venv/bin/ipython3:8
sys.exit(start_ipython())
File /opt/venv/lib/python3.10/site-packages/IPython/__init__.py:128 in start_ipython
return launch_new_instance(argv=argv, **kwargs)
File /opt/venv/lib/python3.10/site-packages/traitlets/config/application.py:1043 in launch_instance
app.start()
File /opt/venv/lib/python3.10/site-packages/IPython/terminal/ipapp.py:318 in start
self.shell.mainloop()
File /opt/venv/lib/python3.10/site-packages/IPython/terminal/interactiveshell.py:888 in mainloop
self.interact()
File /opt/venv/lib/python3.10/site-packages/IPython/terminal/interactiveshell.py:881 in interact
self.run_cell(code, store_history=True)
File /opt/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3006 in run_cell
result = self._run_cell(
File /opt/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3061 in _run_cell
result = runner(coro)
File /opt/venv/lib/python3.10/site-packages/IPython/core/async_helpers.py:129 in _pseudo_sync_runner
coro.send(None)
File /opt/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3266 in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File /opt/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3445 in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File /opt/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3505 in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
Cell In[112], line 1
jaxpr = jax.make_jaxpr(func)(3)
Cell In[111], line 3 in func
return o[0,:,0]
File /workspace/modules/jax/jax/_src/numpy/array_methods.py:728 in op
return getattr(self.aval, f"_{name}")(self, *args)
File /workspace/modules/jax/jax/_src/numpy/array_methods.py:341 in _getitem
return lax_numpy._rewriting_take(self, item)
File /workspace/modules/jax/jax/_src/numpy/lax_numpy.py:4323 in _rewriting_take
return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
File /workspace/modules/jax/jax/_src/numpy/lax_numpy.py:4350 in _gather
y = lax.gather(
JaxStackTraceBeforeTransformation: AttributeError: DynamicJaxprTracer has no attribute type
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
Cell In[114], line 1
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)
File /workspace/modules/catalyst/frontend/catalyst/utils/jax_extras.py:299 in jaxpr_to_mlir
module, context = custom_lower_jaxpr_to_module(
File /workspace/modules/catalyst/frontend/catalyst/utils/jax_extras.py:367 in custom_lower_jaxpr_to_module
lower_jaxpr_to_fun(
File /workspace/modules/jax/jax/_src/interpreters/mlir.py:1216 in lower_jaxpr_to_fun
out_vals, tokens_out = jaxpr_subcomp(
File /workspace/modules/jax/jax/_src/interpreters/mlir.py:1433 in jaxpr_subcomp
ans = rule(rule_ctx, *rule_inputs, **eqn.params)
File /workspace/modules/jax/jax/_src/lax/slicing.py:1827 in _gather_lower
slice_sizes = mlir.eval_dynamic_shape_as_tensor(ctx, slice_sizes)
File /workspace/modules/jax/jax/_src/interpreters/mlir.py:672 in eval_dynamic_shape_as_tensor
return shape_tensor(eval_dynamic_shape(ctx, shape))
File /workspace/modules/jax/jax/_src/interpreters/mlir.py:96 in shape_tensor
ds = map(lower_dim, sizes)
File /workspace/modules/jax/jax/_src/interpreters/mlir.py:93 in lower_dim
if d.type != i32_type:
File /workspace/modules/jax/jax/_src/core.py:746 in __getattr__
raise AttributeError(
AttributeError: DynamicJaxprTracer has no attribute type
Traceback (most recent call last):
Cell In[115], line 1
inject_functions(mlir, ctx)
File /workspace/modules/catalyst/frontend/catalyst/utils/gen_mlir.py:58 in inject_functions
module.body.operations[0].attributes["llvm.emit_c_interface"] = ir.UnitAttr.get(context=ctx)
AttributeError: 'NoneType' object has no attribute 'body'
None
Traceback (most recent call last):
File /workspace/modules/catalyst/frontend/catalyst/compiler.py:379 in run_from_ir
compiler_output = run_compiler_driver(
RuntimeError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
None
^
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
None
^
Failed to parse module as LLVM source
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
Cell In[117], line 1
qjit(str(mlir))(3)
File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:666 in __call__
function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:641 in _ensure_real_arguments_and_formal_parameters_are_compatible
function = self.compile()
File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:574 in compile
shared_object, llvm_ir, inferred_func_data = self.compiler.run_from_ir(
File /workspace/modules/catalyst/frontend/catalyst/compiler.py:389 in run_from_ir
raise CompileError(*e.args) from e
CompileError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
None
^
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
None
^
Failed to parse module as LLVM source
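Note: same failure mode as the 2-D o[0] case above: the sliced middle dimension leaks a tracer into the gather's slice_sizes (slice_sizes=(1, Traced<...>, 1)), and _gather_lower fails in eval_dynamic_shape_as_tensor.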
# 1-D array of dynamic size sz, slice with traced stop o[0:idx]
func, mlir, jaxpr = None, None, None
def func(sz:int, idx:int):
o = jnp.ones(sz, jnp.float32)
return o[0:idx]
jaxpr = jax.make_jaxpr(func)(3,0)
print(jaxpr)
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)
inject_functions(mlir, ctx)
print(mlir)
qjit(str(mlir))(3,0)
Traceback (most recent call last):
Cell In[102], line 1
jaxpr = jax.make_jaxpr(func)(3,0)
File /workspace/modules/jax/jax/_src/traceback_util.py:177 in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File /workspace/modules/jax/jax/_src/api.py:2462 in make_jaxpr_f
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(
File /workspace/modules/jax/jax/_src/profiler.py:340 in wrapper
return func(*args, **kwargs)
File /workspace/modules/jax/jax/_src/interpreters/partial_eval.py:2239 in trace_to_jaxpr_dynamic2
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
File /workspace/modules/jax/jax/_src/interpreters/partial_eval.py:2254 in trace_to_subjaxpr_dynamic2
ans = fun.call_wrapped(*in_tracers_)
File /workspace/modules/jax/jax/_src/linear_util.py:191 in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
Cell In[101], line 3 in func
return o[0:idx]
File /workspace/modules/jax/jax/_src/numpy/array_methods.py:728 in op
return getattr(self.aval, f"_{name}")(self, *args)
File /workspace/modules/jax/jax/_src/numpy/array_methods.py:341 in _getitem
return lax_numpy._rewriting_take(self, item)
File /workspace/modules/jax/jax/_src/numpy/lax_numpy.py:4323 in _rewriting_take
return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
File /workspace/modules/jax/jax/_src/numpy/lax_numpy.py:4332 in _gather
indexer = _index_to_gather(shape(arr), idx) # shared with _scatter_update
File /workspace/modules/jax/jax/_src/numpy/lax_numpy.py:4584 in _index_to_gather
raise IndexError(msg)
IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(None, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
None
Traceback (most recent call last):
Cell In[104], line 1
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)
AttributeError: 'NoneType' object has no attribute 'out_avals'
Traceback (most recent call last):
Cell In[105], line 1
inject_functions(mlir, ctx)
File /workspace/modules/catalyst/frontend/catalyst/utils/gen_mlir.py:58 in inject_functions
module.body.operations[0].attributes["llvm.emit_c_interface"] = ir.UnitAttr.get(context=ctx)
AttributeError: 'NoneType' object has no attribute 'body'
None
Traceback (most recent call last):
File /workspace/modules/catalyst/frontend/catalyst/compiler.py:379 in run_from_ir
compiler_output = run_compiler_driver(
RuntimeError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
None
^
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
None
^
Failed to parse module as LLVM source
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
Cell In[107], line 1
qjit(str(mlir))(3,0)
File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:666 in __call__
function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:641 in _ensure_real_arguments_and_formal_parameters_are_compatible
function = self.compile()
File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:574 in compile
shared_object, llvm_ir, inferred_func_data = self.compiler.run_from_ir(
File /workspace/modules/catalyst/frontend/catalyst/compiler.py:389 in run_from_ir
raise CompileError(*e.args) from e
CompileError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
None
^
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
None
^
Failed to parse module as LLVM source
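Note: unlike the o[0:2] case, the slice bound here is itself a traced value, so the result size depends on idx; lax.dynamic_slice does not help either, since it requires static slice sizes, which is exactly what the IndexError message above says.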