@ariG23498 · Last active September 18, 2025
AoT with CausalLMs
import torch
import spaces
from transformers import pipeline

# init pipeline
MODEL_ID = "HuggingFaceTB/SmolLM3-3B"
pipe = pipeline(
    task="text-generation",
    model=MODEL_ID,
    use_cache=False,  # Note: important to bypass the `DynamicCache` not being registered issue
)

# build inputs
messages = [
    {"role": "user", "content": "Who are you?"},
]

# taken from: https://huggingface.co/blog/zerogpu-aoti#5-wrapping-it-all-together
def compile_transformer():
    with spaces.aoti_capture(pipe.model) as call:
        pipe(messages)

    exported = torch.export.export(
        pipe.model,
        args=call.args,
        kwargs=call.kwargs,
    )
    return spaces.aoti_compile(exported)

compiled_transformer = compile_transformer()
spaces.aoti_apply(compiled_transformer, pipe.model)
# Traceback (most recent call last):
# File "/home/aritra/git-repos/distributed/aot_script.py", line 24, in <module>
# compiled_transformer = compile_transformer()
# ^^^^^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/aot_script.py", line 22, in compile_transformer
# return spaces.aoti_compile(exported)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/spaces/__init__.py", line 33, in _aoti_compile
# return aoti_compile(*args, **kwargs)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/spaces/zero/torch/aoti.py", line 68, in aoti_compile
# artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs) # pyright: ignore [reportArgumentType]
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_inductor/__init__.py", line 310, in aot_compile
# return compile_fx_aot(
# ^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1941, in compile_fx_aot
# compiled_artifacts = compile_fx(
# ^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2413, in compile_fx
# return compile_fx(
# ^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2474, in compile_fx
# return compile_fx(
# ^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2619, in compile_fx
# gm, graph_signature = aot_export_module(
# ^^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1445, in aot_export_module
# fx_g, metadata, in_spec, out_spec = _aot_export_function(
# ^^^^^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1695, in _aot_export_function
# aot_state = create_aot_state(
# ^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 567, in create_aot_state
# fw_metadata = run_functionalized_fw_and_collect_metadata(
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 207, in inner
# flat_f_outs = f(*flat_f_args)
# ^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 187, in flat_fn
# tree_out = fn(*args, **kwargs)
# ^^^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1350, in functional_call
# out = PropagateUnbackedSymInts(mod).run(
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 174, in run
# self.env[node] = self.run_node(node)
# ^^^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7875, in run_node
# result = super().run_node(n)
# ^^^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 256, in run_node
# return getattr(self, n.op)(n.target, args, kwargs)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 336, in call_function
# return target(*args, **kwargs)
# ^^^^^^^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_ops.py", line 841, in __call__
# return self._op(*args, **kwargs)
# ^^^^^^^^^^^^^^^^^^^^^^^^^
# RuntimeError: Batching rule not implemented for aten::_assert_tensor_metadata. The fallback path does not support operations with no returns.
# While executing %_assert_tensor_metadata_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%le,), kwargs = {dtype: torch.bool, device: cuda:0, layout: torch.strided})
# Original traceback:
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/transformers/models/smollm3/modeling_smollm3.py", line 479, in forward
# outputs: BaseModelOutputWithPast = self.model(
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/transformers/models/smollm3/modeling_smollm3.py", line 402, in forward
# "full_attention": create_causal_mask(**mask_kwargs),
# Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
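The original traceback points at `create_causal_mask` in `modeling_smollm3.py`, which (in recent transformers) builds the attention mask with `torch.vmap`; `aten::_assert_tensor_metadata` returns nothing and has no batching rule, so the vmap fallback rejects it. Below is a minimal sketch of that combination outside transformers; the direct call to `torch.ops.aten._assert_tensor_metadata.default` and the expectation that it raises the same RuntimeError under `torch.vmap` are my assumptions, not part of the gist.

import torch

def assert_then_compare(x):
    # no-return op with no registered batching rule; under vmap the fallback
    # path cannot handle ops that return nothing (assumed to reproduce the error)
    torch.ops.aten._assert_tensor_metadata.default(x, dtype=x.dtype)
    return x.le(0)

x = torch.randn(2, 4)
try:
    torch.vmap(assert_then_compare)(x)
except RuntimeError as e:
    print(e)  # expected: "Batching rule not implemented for aten::_assert_tensor_metadata ..."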
# Package Information
uv pip show transformers
Name: transformers
Version: 4.56.1
Location: /home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, safetensors, tokenizers, tqdm
Required-by:
uv pip show torch
Name: torch
Version: 2.10.0.dev20250916+cu128
Location: /home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages
Requires: filelock, fsspec, jinja2, networkx, nvidia-cublas-cu12, nvidia-cuda-cupti-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-runtime-cu12, nvidia-cudnn-cu12, nvidia-cufft-cu12, nvidia-cufile-cu12, nvidia-curand-cu12, nvidia-cusolver-cu12, nvidia-cusparse-cu12, nvidia-cusparselt-cu12, nvidia-nccl-cu12, nvidia-nvjitlink-cu12, nvidia-nvshmem-cu12, nvidia-nvtx-cu12, pytorch-triton, setuptools, sympy, typing-extensions
Required-by: accelerate, torchvision
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# minimal repro without the `spaces` helpers: a plain torch.export of the base
# model hits the same error when the exported module is called
model_id = "HuggingFaceTB/SmolLM3-3B-Base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="cuda:0", use_cache=False
).eval()

inputs = tokenizer("How are you? ", return_tensors="pt").to("cuda")
example_inputs = (inputs["input_ids"], inputs["attention_mask"])

# use inference mode to be sure of no backward graphs
with torch.inference_mode():
    ep = torch.export.export(model, example_inputs)

# calling the exported module's forward raises the error below
outputs = ep.module()(inputs["input_ids"], inputs["attention_mask"])
# Traceback (most recent call last):
# File "/home/aritra/git-repos/distributed/check_export.py", line 29, in <module>
# outputs = ep.module()(inputs["input_ids"], inputs["attention_mask"])
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 837, in call_wrapped
# return self._wrapped_call(self, *args, **kwargs)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 413, in __call__
# raise e
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 400, in __call__
# return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
# return self._call_impl(*args, **kwargs)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1881, in _call_impl
# return inner()
# ^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1829, in inner
# result = forward_call(*args, **kwargs)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "<eval_with_key>.27", line 360, in forward
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_ops.py", line 841, in __call__
# return self._op(*args, **kwargs)
# ^^^^^^^^^^^^^^^^^^^^^^^^^
# RuntimeError: Batching rule not implemented for aten::_assert_tensor_metadata. The fallback path does not support operations with no returns.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# init the model and tokenizer
model_id = "HuggingFaceTB/SmolLM3-3B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="cuda:0"
).eval()

# build the inputs
messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

# inputs required for exporting
example_inputs = (inputs["input_ids"], inputs["attention_mask"])

# use inference mode to be sure of no backward graphs
with torch.inference_mode():
    ep = torch.export.export(model, example_inputs)

# use the exported model's forward call
outputs = ep.module()(inputs["input_ids"], inputs["attention_mask"])
# Traceback (most recent call last):
# File "/home/aritra/git-repos/distributed/check_export.py", line 27, in <module>
# outputs = ep.module()(inputs["input_ids"], inputs["attention_mask"])
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 837, in call_wrapped
# return self._wrapped_call(self, *args, **kwargs)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 413, in __call__
# raise e
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 400, in __call__
# return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
# return self._call_impl(*args, **kwargs)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1881, in _call_impl
# return inner()
# ^^^^^^^
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1829, in inner
# result = forward_call(*args, **kwargs)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "<eval_with_key>.27", line 360, in forward
# File "/home/aritra/git-repos/distributed/.venv/lib/python3.12/site-packages/torch/_ops.py", line 841, in __call__
# return self._op(*args, **kwargs)
# ^^^^^^^^^^^^^^^^^^^^^^^^^
# RuntimeError: Batching rule not implemented for aten::_assert_tensor_metadata. The fallback path does not support operations with no returns.
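Since the export itself succeeds and only the call on `ep.module()` fails, a quick (hypothetical) inspection snippet like the one below can confirm that the no-return assert op lands in the exported graph. It reuses `ep` and the imports from the script above and only walks the top-level graph; nodes captured inside higher-order-op subgraphs would need a recursive walk.

# assumes `ep` from the export above is available
assert_nodes = [
    n for n in ep.graph.nodes
    if n.op == "call_function"
    and n.target is torch.ops.aten._assert_tensor_metadata.default
]
print(f"{len(assert_nodes)} aten::_assert_tensor_metadata nodes in the top-level exported graph")
for n in assert_nodes:
    print(n.format_node())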