The FA3 attention processor comes from https://gist.github.com/sayakpaul/ff715f979793d4d44beb68e5e08ee067.
Results from an H100:
latency=36.606 seconds. (AoT regional compilation)
latency=36.555 seconds. (JiT regional compilation)
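
To reproduce, assuming the script below is saved as benchmark_fa3_qwen.py (an illustrative name) and the FA3 processor from the gist above is available as fa3_processor.py, run one of:

python benchmark_fa3_qwen.py --aot_regional   # AoT regional compilation
python benchmark_fa3_qwen.py --jit_regional   # JiT regional compilation
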
from diffusers import DiffusionPipeline
from torch.utils import benchmark
from diffusers.utils import load_image
import spaces
import time
import torch
import argparse
from functools import partial

from fa3_processor import QwenDoubleStreamAttnProcessorFA3
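
# Sequence lengths vary with prompt and resolution, so the image- and
# text-token dims are declared dynamic; this lets a single exported
# block artifact serve all shapes without re-export.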
TRANSFORMER_IMAGE_SEQ_LENGTH_DIM = torch.export.Dim('image_seq_length')
TRANSFORMER_TEXT_SEQ_LENGTH_DIM = torch.export.Dim('text_seq_length')
TRANSFORMER_DYNAMIC_SHAPES = {
    'hidden_states': {
        1: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
    },
    'encoder_hidden_states': {
        1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
    },
    'encoder_hidden_states_mask': {
        1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
    },
    'image_rotary_emb': ({
        0: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
    }, {
        0: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
    }),
}
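

# Returns the mean latency in seconds, measured via
# torch.utils.benchmark's blocked_autorange.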
def benchmark_fn(func_to_benchmark):
    t0 = benchmark.Timer(
        stmt="func_to_benchmark()",
        globals={"func_to_benchmark": func_to_benchmark},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"


@torch.no_grad()
def load_pipeline(args):
    pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16)
    pipe.to("cuda")
    pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
    pipe.set_progress_bar_config(disable=True)
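
    # Two regional-compilation modes: JiT compiles each repeated
    # transformer block with torch.compile on first call; AoT exports a
    # single block, compiles it once with AOTInductor, and reuses the
    # artifact for every block.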
    if args.jit_regional:
        pipe.transformer.compile_repeated_blocks(fullgraph=True)
    elif args.aot_regional:
        from torch.utils._pytree import tree_map

        assert pipe.transformer._repeated_blocks
        start_time = time.time()
        pipe_kwargs = get_pipe_kwargs(num_inference_steps=2)
        # TODO: Find a better way to derive `transformer_blocks`.
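        # `spaces.aoti_capture` (from the huggingface `spaces` package) runs
        # the pipeline once and records the exact args/kwargs the first
        # transformer block receives; a cheap 2-step run is enough, and the
        # captured inputs are replayed as example inputs for torch.export.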
        with spaces.aoti_capture(pipe.transformer.transformer_blocks[0]) as call:
            pipe(**pipe_kwargs)
        print(f"{call.kwargs.keys()=}")

        # Compile
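        # Start from an all-static spec (None for every captured kwarg),
        # then overlay the sequence-length dims that should stay dynamic.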
        dynamic_shapes = tree_map(lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
        path = torch._inductor.aoti_compile_and_package(
            torch.export.export(
                # TODO: Find a better way to derive `transformer_blocks`.
                pipe.transformer.transformer_blocks[0],
                args=call.args,
                kwargs=call.kwargs,
                dynamic_shapes=dynamic_shapes,
            ),
            inductor_configs={
                # Package the compiled artifact w/o baking the weights into it.
                "aot_inductor.package_constants_in_so": False,
            }
        )

        # Load
        # TODO: Find a better way to derive `transformer_blocks`.
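        # The same artifact is loaded once per block; since the weights were
        # not packaged into it, each block's own state_dict is bound as
        # user-managed constants (shared with the eager module, not copied).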
        for block in pipe.transformer.transformer_blocks:
            compiled_transformer_layer = torch._inductor.aoti_load_package(path)
            compiled_transformer_layer.load_constants(
                block.state_dict(), check_full_update=True, user_managed=True
            )
            block.forward = compiled_transformer_layer
        end_time = time.time()
        print(f"AoT regional compilation took: {(end_time-start_time):.3f} seconds.")

    return pipe


def get_pipe_kwargs(true_cfg_scale=4.0, num_inference_steps=50):
    image = load_image(
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
    ).convert("RGB")
    prompt = (
        "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors"
    )
    pipe_kwargs = {
        "image": image,
        "prompt": prompt,
        "negative_prompt": " ",
        "true_cfg_scale": true_cfg_scale,
        "num_inference_steps": num_inference_steps,
        "generator": torch.manual_seed(0),
    }
    return pipe_kwargs


def run_inference(pipe, kwargs):
    _ = pipe(**kwargs)


def main(args):
    pipe = load_pipeline(args)
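    # Warmup runs trigger compilation/autotuning so that work is not
    # counted in the measured latency.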
    for _ in range(3):
        pipe(**get_pipe_kwargs(num_inference_steps=5))

    inference_func = partial(run_inference, pipe, get_pipe_kwargs())
    latency = float(benchmark_fn(inference_func))
    print(f"{latency=} seconds.")

    image = pipe(**get_pipe_kwargs()).images[0]
    fileprefix = f"aot@{args.aot_regional}"
    image.save(f"{fileprefix}.png")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--aot_regional", action="store_true")
    parser.add_argument("--jit_regional", action="store_true")
    args = parser.parse_args()
    main(args)