Created November 10, 2023 10:14
-
-
Save sergei-mironov/aeebdf7cdeaf7c600ee22b1e0229e621 to your computer and use it in GitHub Desktop.
jax-0.4.19-dynshape.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index d168d22ab..798d07a6d 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -1461,7 +1461,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
ans, "lowering function returned a bad output", eqn) | |
assert len(ans) == len(eqn.outvars), (ans, eqn) | |
map(write, eqn.outvars, out_nodes) | |
- core.clean_up_dead_vars(eqn, env, last_used) | |
+ # core.clean_up_dead_vars(eqn, env, last_used) | |
return map(read, jaxpr.outvars), tokens | |
# See docstring for lower_multi_platform. | |
diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py
index 2e3ac7005..82e86080f 100644
--- a/jax/_src/lax/slicing.py
+++ b/jax/_src/lax/slicing.py
@@ -1534,15 +1534,15 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers,
f"output_slice_sizes={offset_dims}, collapsed_slice_dims=" | |
f"{collapsed_slice_dims}.") | |
- for i in range(len(slice_sizes)): | |
- slice_size = slice_sizes[i] | |
- corresponding_input_size = operand.shape[i] | |
- | |
- if not (slice_size >= 0 and | |
- corresponding_input_size >= slice_size): | |
- raise TypeError(f"Slice size at index {i} in gather op is out of range, " | |
- f"must be within [0, {corresponding_input_size} + 1), " | |
- f"got {slice_size}.") | |
+ # for i in range(len(slice_sizes)): | |
+ # slice_size = slice_sizes[i] | |
+ # corresponding_input_size = operand.shape[i] | |
+ | |
+ # if not (slice_size >= 0 and | |
+ # corresponding_input_size >= slice_size): | |
+ # raise TypeError(f"Slice size at index {i} in gather op is out of range, " | |
+ # f"must be within [0, {corresponding_input_size} + 1), " | |
+ # f"got {slice_size}.") | |
for i in range(len(collapsed_slice_dims)): | |
bound = slice_sizes[collapsed_slice_dims[i]] |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.