Skip to content

Instantly share code, notes, and snippets.

@sayakpaul
Last active September 5, 2025 02:22
Show Gist options
  • Save sayakpaul/48676d257b539e79d50eafedabdc7f95 to your computer and use it in GitHub Desktop.
Save sayakpaul/48676d257b539e79d50eafedabdc7f95 to your computer and use it in GitHub Desktop.
Regional ahead-of-time (AoT) compilation
from diffusers import DiffusionPipeline
from torch.utils import benchmark
from diffusers.utils import load_image
import spaces
import time
import torch
import argparse
from functools import partial
from fa3_processor import QwenDoubleStreamAttnProcessorFA3
TRANSFORMER_IMAGE_SEQ_LENGTH_DIM = torch.export.Dim('image_seq_length')
TRANSFORMER_TEXT_SEQ_LENGTH_DIM = torch.export.Dim('text_seq_length')
TRANSFORMER_DYNAMIC_SHAPES = {
'hidden_states': {
1: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
},
'encoder_hidden_states': {
1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
},
'encoder_hidden_states_mask': {
1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
},
'image_rotary_emb': ({
0: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
}, {
0: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
}),
}
def benchmark_fn(func_to_benchmark):
    """Time a zero-argument callable with ``torch.utils.benchmark``.

    Args:
        func_to_benchmark: callable invoked with no arguments on each run.

    Returns:
        The mean time per call in seconds (as reported by
        ``blocked_autorange``), formatted as a string with 3 decimals.
    """
    timer = benchmark.Timer(
        stmt="func_to_benchmark()",
        globals={"func_to_benchmark": func_to_benchmark},
        # Match the process's current intra-op thread count.
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.blocked_autorange()
    return "{:.3f}".format(measurement.mean)
def _aot_compile_repeated_blocks(pipe):
    """AoT-compile the first transformer block and install it on every block.

    Steps:
      1. Capture: run a short (2-step) pipeline pass under
         ``spaces.aoti_capture`` to record the args/kwargs that block 0 is
         called with.
      2. Compile: ``torch.export`` block 0 with every input static except the
         axes in ``TRANSFORMER_DYNAMIC_SHAPES``, then package it with
         AOTInductor *without* embedding the weights.
      3. Load: for each block, load the package and bind that block's own
         ``state_dict`` as the (user-managed) constants, then replace the
         block's ``forward`` with the compiled module.

    Intended to be called under ``torch.no_grad()`` (``load_pipeline`` is
    decorated with it). NOTE(review): this assumes every transformer block has
    the same structure as block 0; ``check_full_update=True`` is expected to
    reject a mismatched state_dict - confirm.
    """
    from torch.utils._pytree import tree_map

    # TODO: Find a better way to derive `transformer_blocks`.
    blocks = pipe.transformer.transformer_blocks

    # 1. Capture the first block's call signature.
    capture_kwargs = get_pipe_kwargs(num_inference_steps=2)
    with spaces.aoti_capture(blocks[0]) as call:
        pipe(**capture_kwargs)
    print(f"{call.kwargs.keys()=}")

    # 2. Compile. tree_map(None) marks every captured input as fully static;
    # the merge then re-enables the dynamic sequence axes.
    dynamic_shapes = tree_map(lambda t: None, call.kwargs)
    dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
    exported_block = torch.export.export(
        blocks[0],
        args=call.args,
        kwargs=call.kwargs,
        dynamic_shapes=dynamic_shapes,
    )
    package_path = torch._inductor.aoti_compile_and_package(
        exported_block,
        inductor_configs={
            # compile artifact w/o saving weights in the artifact
            "aot_inductor.package_constants_in_so": False,
        },
    )

    # 3. Load one compiled instance per block. A fresh loader per block
    # presumably keeps each block's weights bound to its own instance.
    for block in blocks:
        compiled_block = torch._inductor.aoti_load_package(package_path)
        compiled_block.load_constants(
            block.state_dict(), check_full_update=True, user_managed=True
        )
        block.forward = compiled_block


@torch.no_grad()
def load_pipeline(args):
    """Load Qwen-Image-Edit on CUDA and optionally apply regional compilation.

    Args:
        args: parsed CLI namespace; reads ``jit_regional`` and
            ``aot_regional``. JIT takes precedence when both are set.

    Returns:
        The configured ``DiffusionPipeline``.

    Raises:
        RuntimeError: AoT mode was requested but the transformer declares no
            repeated blocks (``_repeated_blocks`` is empty/unset).
    """
    pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16)
    pipe.to("cuda")
    pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
    pipe.set_progress_bar_config(disable=True)

    if args.jit_regional:
        pipe.transformer.compile_repeated_blocks(fullgraph=True)
    elif args.aot_regional:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not pipe.transformer._repeated_blocks:
            raise RuntimeError(
                "AoT regional compilation requires `pipe.transformer._repeated_blocks` to be set."
            )
        # perf_counter is monotonic and intended for measuring elapsed time.
        start_time = time.perf_counter()
        _aot_compile_repeated_blocks(pipe)
        elapsed = time.perf_counter() - start_time
        print(f"AoT regional compilation took: {elapsed:.3f} seconds.")
    return pipe
def get_pipe_kwargs(true_cfg_scale=4.0, num_inference_steps=50):
    """Build keyword arguments for one Qwen-Image-Edit pipeline call.

    The input image is downloaded (via ``load_image``) and converted to RGB on
    every call.

    Args:
        true_cfg_scale: value forwarded as the pipeline's ``true_cfg_scale``.
        num_inference_steps: number of denoising steps to request.

    Returns:
        dict with keys ``image``, ``prompt``, ``negative_prompt``,
        ``true_cfg_scale``, ``num_inference_steps`` and ``generator``.
    """
    source_url = (
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
    )
    edit_prompt = (
        "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors"
    )
    return dict(
        image=load_image(source_url).convert("RGB"),
        prompt=edit_prompt,
        negative_prompt=" ",
        true_cfg_scale=true_cfg_scale,
        num_inference_steps=num_inference_steps,
        # Reseeds torch's global RNG and passes back the default generator.
        generator=torch.manual_seed(0),
    )
def run_inference(pipe, kwargs):
    """Call ``pipe`` with ``kwargs`` unpacked and discard the result (benchmark body)."""
    pipe(**kwargs)
def main(args):
    """Load the pipeline, warm it up, benchmark it, and save one sample image.

    Args:
        args: parsed CLI namespace with ``aot_regional`` / ``jit_regional``.

    Side effects:
        Prints the mean latency and writes ``<mode>@<flag>.png`` to the
        current directory.
    """
    pipe = load_pipeline(args)

    # Warm-up runs; presumably so lazy JIT compilation / CUDA warm-up costs
    # are kept out of the timed section.
    for _ in range(3):
        pipe(**get_pipe_kwargs(num_inference_steps=5))

    # Build the kwargs once so the image download is not part of the timing.
    inference_func = partial(run_inference, pipe, get_pipe_kwargs())
    latency = float(benchmark_fn(inference_func))
    print(f"{latency=} seconds.")

    image = pipe(**get_pipe_kwargs()).images[0]
    # Name the output after the mode load_pipeline actually used (JIT wins
    # when both flags are set). Previously JIT runs were saved as
    # "aot@False.png" and overwrote the eager-mode output.
    if args.jit_regional:
        fileprefix = f"jit@{args.jit_regional}"
    else:
        fileprefix = f"aot@{args.aot_regional}"
    image.save(f"{fileprefix}.png")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--aot_regional", action="store_true")
parser.add_argument("--jit_regional", action="store_true")
args = parser.parse_args()
main(args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment