AoT with CausalLMs: ahead-of-time compilation of SmolLM3 via torch.export fails with "Batching rule not implemented for aten::_assert_tensor_metadata". Minimal repros below.
import torch
import spaces
from transformers import pipeline

# init pipeline
MODEL_ID = "HuggingFaceTB/SmolLM3-3B"
pipe = pipeline(
    task="text-generation",
    model=MODEL_ID,
    use_cache=False,  # Note: important to bypass the `DynamicCache` not being registered issue
)

# build inputs
messages = [
    {"role": "user", "content": "Who are you?"},
]

# taken from: https://huggingface.co/blog/zerogpu-aoti#5-wrapping-it-all-together
def compile_transformer():
    # capture the args/kwargs the pipeline passes to the model's forward
    with spaces.aoti_capture(pipe.model) as call:
        pipe(messages)
    exported = torch.export.export(
        pipe.model,
        args=call.args,
        kwargs=call.kwargs,
    )
    return spaces.aoti_compile(exported)

compiled_transformer = compile_transformer()
spaces.aoti_apply(compiled_transformer, pipe.model)
# Traceback (most recent call last):
#   File "/home/aritra/git-repos/distributed/aot_script.py", line 24, in <module>
#     compiled_transformer = compile_transformer()
#                            ^^^^^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/aot_script.py", line 22, in compile_transformer
#     return spaces.aoti_compile(exported)
#            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/spaces/__init__.py", line 33, in _aoti_compile
#     return aoti_compile(*args, **kwargs)
#            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/spaces/zero/torch/aoti.py", line 68, in aoti_compile
#     artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs) # pyright: ignore [reportArgumentType]
#                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_inductor/__init__.py", line 310, in aot_compile
#     return compile_fx_aot(
#            ^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1941, in compile_fx_aot
#     compiled_artifacts = compile_fx(
#                          ^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2413, in compile_fx
#     return compile_fx(
#            ^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2474, in compile_fx
#     return compile_fx(
#            ^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2619, in compile_fx
#     gm, graph_signature = aot_export_module(
#                           ^^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1445, in aot_export_module
#     fx_g, metadata, in_spec, out_spec = _aot_export_function(
#                                         ^^^^^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1695, in _aot_export_function
#     aot_state = create_aot_state(
#                 ^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 567, in create_aot_state
#     fw_metadata = run_functionalized_fw_and_collect_metadata(
#                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 207, in inner
#     flat_f_outs = f(*flat_f_args)
#                   ^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 187, in flat_fn
#     tree_out = fn(*args, **kwargs)
#                ^^^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1350, in functional_call
#     out = PropagateUnbackedSymInts(mod).run(
#           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 174, in run
#     self.env[node] = self.run_node(node)
#                      ^^^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7875, in run_node
#     result = super().run_node(n)
#              ^^^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 256, in run_node
#     return getattr(self, n.op)(n.target, args, kwargs)
#            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 336, in call_function
#     return target(*args, **kwargs)
#            ^^^^^^^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_ops.py", line 841, in __call__
#     return self._op(*args, **kwargs)
#            ^^^^^^^^^^^^^^^^^^^^^^^^^
# RuntimeError: Batching rule not implemented for aten::_assert_tensor_metadata. The fallback path does not support operations with no returns.
# While executing %_assert_tensor_metadata_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%le,), kwargs = {dtype: torch.bool, device: cuda:0, layout: torch.strided})
# Original traceback:
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/transformers/models/smollm3/modeling_smollm3.py", line 479, in forward
#     outputs: BaseModelOutputWithPast = self.model(
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/transformers/models/smollm3/modeling_smollm3.py", line 402, in forward
#     "full_attention": create_causal_mask(**mask_kwargs),
# Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
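The failing op, aten::_assert_tensor_metadata, is an assertion node that torch.export inserts to guard tensor metadata (the kwargs in the error show it checks dtype, device, and layout); per the "Original traceback" it originates from create_causal_mask in modeling_smollm3.py. A minimal diagnostic sketch for listing which nodes in the exported program carry the op, assuming the exported program captured inside compile_transformer above is available as `exported`:

# locate the assertion nodes torch.export inserted; per the error message they
# have no users ([num_users=0]) and exist only as runtime metadata checks
assert_nodes = [
    node
    for node in exported.graph.nodes
    if node.op == "call_function"
    and node.target is torch.ops.aten._assert_tensor_metadata.default
]
print(f"found {len(assert_nodes)} _assert_tensor_metadata nodes")
for node in assert_nodes:
    print(node.format_node())  # shows the guarded tensor and the expected metadata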
# Package Information

uv pip show transformers
Name: transformers
Version: 4.56.1
Location: /home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, safetensors, tokenizers, tqdm
Required-by:

uv pip show torch
Name: torch
Version: 2.10.0.dev20250916+cu128
Location: /home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages
Requires: filelock, fsspec, jinja2, networkx, nvidia-cublas-cu12, nvidia-cuda-cupti-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-runtime-cu12, nvidia-cudnn-cu12, nvidia-cufft-cu12, nvidia-cufile-cu12, nvidia-curand-cu12, nvidia-cusolver-cu12, nvidia-cusparse-cu12, nvidia-cusparselt-cu12, nvidia-nccl-cu12, nvidia-nvjitlink-cu12, nvidia-nvshmem-cu12, nvidia-nvtx-cu12, pytorch-triton, setuptools, sympy, typing-extensions
Required-by: accelerate, torchvision
The same error reproduces without the spaces helpers, with plain torch.export on the base model:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# init the model and tokenizer (base model, no chat template)
model_id = "HuggingFaceTB/SmolLM3-3B-Base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="cuda:0", use_cache=False
).eval()

# build the inputs and the example args required for exporting
inputs = tokenizer("How are you? ", return_tensors="pt").to("cuda")
example_inputs = (inputs["input_ids"], inputs["attention_mask"])

# use inference mode to be sure of no backward graphs
with torch.inference_mode():
    ep = torch.export.export(model, example_inputs)
    # run the exported module's forward call
    outputs = ep.module()(inputs["input_ids"], inputs["attention_mask"])
# Traceback (most recent call last):
#   File "/home/aritra/git-repos/distributed/check_export.py", line 29, in <module>
#     outputs = ep.module()(inputs["input_ids"], inputs["attention_mask"])
#               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 837, in call_wrapped
#     return self._wrapped_call(self, *args, **kwargs)
#            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 413, in __call__
#     raise e
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 400, in __call__
#     return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
#            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
#     return self._call_impl(*args, **kwargs)
#            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1881, in _call_impl
#     return inner()
#            ^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1829, in inner
#     result = forward_call(*args, **kwargs)
#              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#   File "<eval_with_key>.27", line 360, in forward
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_ops.py", line 841, in __call__
#     return self._op(*args, **kwargs)
#            ^^^^^^^^^^^^^^^^^^^^^^^^^
# RuntimeError: Batching rule not implemented for aten::_assert_tensor_metadata. The fallback path does not support operations with no returns.
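Since the failing assert nodes have no users, one workaround sketch is to erase them from the unlifted graph module before calling it. This silently drops the metadata checks, so treat it as a debugging aid rather than a fix:

gm = ep.module()  # unlifted GraphModule of the ExportedProgram
for node in list(gm.graph.nodes):
    if (
        node.op == "call_function"
        and node.target is torch.ops.aten._assert_tensor_metadata.default
    ):
        gm.graph.erase_node(node)  # legal: the op returns nothing and has no users
gm.graph.lint()
gm.recompile()
outputs = gm(inputs["input_ids"], inputs["attention_mask"])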
And the same failure with the instruct model and chat-template inputs:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# init the model and tokenizer
model_id = "HuggingFaceTB/SmolLM3-3B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="cuda:0"
).eval()

# build the inputs
messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

# inputs required for exporting
example_inputs = (inputs["input_ids"], inputs["attention_mask"])

# use inference mode to be sure of no backward graphs
with torch.inference_mode():
    ep = torch.export.export(model, example_inputs)
    # use the exported model's forward call
    outputs = ep.module()(inputs["input_ids"], inputs["attention_mask"])
# Traceback (most recent call last):
#   File "/home/aritra/git-repos/distributed/check_export.py", line 27, in <module>
#     outputs = ep.module()(inputs["input_ids"], inputs["attention_mask"])
#               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 837, in call_wrapped
#     return self._wrapped_call(self, *args, **kwargs)
#            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 413, in __call__
#     raise e
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 400, in __call__
#     return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
#            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
#     return self._call_impl(*args, **kwargs)
#            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1881, in _call_impl
#     return inner()
#            ^^^^^^^
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1829, in inner
#     result = forward_call(*args, **kwargs)
#              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#   File "<eval_with_key>.27", line 360, in forward
#   File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_ops.py", line 841, in __call__
#     return self._op(*args, **kwargs)
#            ^^^^^^^^^^^^^^^^^^^^^^^^^
# RuntimeError: Batching rule not implemented for aten::_assert_tensor_metadata. The fallback path does not support operations with no returns.
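For contrast, a plain eager forward pass on the same inputs should run fine, which points at the assertion op in the exported graph rather than the model itself (a sanity check added here, not part of the original repro):

with torch.inference_mode():
    eager_out = model(inputs["input_ids"], inputs["attention_mask"])
print(eager_out.logits.shape)  # (1, seq_len, vocab_size)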