The FA3 attention processor comes from https://gist.github.com/sayakpaul/ff715f979793d4d44beb68e5e08ee067. Both variants compile only the repeated transformer blocks of the model (regional compilation), either ahead of time (AoT) or just in time (JiT).

Results from an H100:

latency=36.606 seconds. (AoT regional compilation)
latency=36.555 seconds. (JiT regional compilation)
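To reproduce, assuming the script below is saved as benchmark.py (the filename is my choice; the flags come from the script's argparse setup):

python benchmark.py --aot_regional  # ahead-of-time regional compilation
python benchmark.py --jit_regional  # just-in-time regional compilation

Passing neither flag runs the uncompiled FA3 baseline.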
import argparse
import time
from functools import partial

import spaces
import torch
from torch.utils import benchmark

from diffusers import DiffusionPipeline
from diffusers.utils import load_image

from fa3_processor import QwenDoubleStreamAttnProcessorFA3

# Mark the image and text sequence lengths of a transformer block as dynamic
# so a single exported artifact serves different resolutions and prompt lengths.
TRANSFORMER_IMAGE_SEQ_LENGTH_DIM = torch.export.Dim("image_seq_length")
TRANSFORMER_TEXT_SEQ_LENGTH_DIM = torch.export.Dim("text_seq_length")

TRANSFORMER_DYNAMIC_SHAPES = {
    "hidden_states": {1: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM},
    "encoder_hidden_states": {1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM},
    "encoder_hidden_states_mask": {1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM},
    "image_rotary_emb": (
        {0: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM},
        {0: TRANSFORMER_TEXT_SEQ_LENGTH_DIM},
    ),
}


def benchmark_fn(func_to_benchmark):
    t0 = benchmark.Timer(
        stmt="func_to_benchmark()",
        globals={"func_to_benchmark": func_to_benchmark},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"


@torch.no_grad()
def load_pipeline(args):
    pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16)
    pipe.to("cuda")
    pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
    pipe.set_progress_bar_config(disable=True)

    if args.jit_regional:
        # JiT path: let torch.compile handle each repeated transformer block.
        pipe.transformer.compile_repeated_blocks(fullgraph=True)
    elif args.aot_regional:
        from torch.utils._pytree import tree_map

        assert pipe.transformer._repeated_blocks
        start_time = time.time()
        pipe_kwargs = get_pipe_kwargs(num_inference_steps=2)

        # Capture the inputs of a single transformer block with a short run.
        # TODO: Find a better way to derive `transformer_blocks`.
        with spaces.aoti_capture(pipe.transformer.transformer_blocks[0]) as call:
            pipe(**pipe_kwargs)
        print(f"{call.kwargs.keys()=}")

        # Compile one block ahead of time with dynamic sequence-length dims.
        dynamic_shapes = tree_map(lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
        path = torch._inductor.aoti_compile_and_package(
            torch.export.export(
                # TODO: Find a better way to derive `transformer_blocks`.
                pipe.transformer.transformer_blocks[0],
                args=call.args,
                kwargs=call.kwargs,
                dynamic_shapes=dynamic_shapes,
            ),
            inductor_configs={
                # Compile the artifact w/o saving the weights in it.
                "aot_inductor.package_constants_in_so": False,
            },
        )

        # Load the shared artifact once per block and attach each block's own weights.
        # TODO: Find a better way to derive `transformer_blocks`.
        for block in pipe.transformer.transformer_blocks:
            compiled_transformer_layer = torch._inductor.aoti_load_package(path)
            compiled_transformer_layer.load_constants(
                block.state_dict(), check_full_update=True, user_managed=True
            )
            block.forward = compiled_transformer_layer

        end_time = time.time()
        print(f"AoT regional compilation took: {(end_time - start_time):.3f} seconds.")
    return pipe


def get_pipe_kwargs(true_cfg_scale=4.0, num_inference_steps=50):
    image = load_image(
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
    ).convert("RGB")
    prompt = (
        "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors"
    )
    pipe_kwargs = {
        "image": image,
        "prompt": prompt,
        "negative_prompt": " ",
        "true_cfg_scale": true_cfg_scale,
        "num_inference_steps": num_inference_steps,
        "generator": torch.manual_seed(0),
    }
    return pipe_kwargs


def run_inference(pipe, kwargs):
    _ = pipe(**kwargs)


def main(args):
    pipe = load_pipeline(args)

    # Warmup runs so compilation is not included in the measured latency.
    for _ in range(3):
        pipe(**get_pipe_kwargs(num_inference_steps=5))

    inference_func = partial(run_inference, pipe, get_pipe_kwargs())
    latency = float(benchmark_fn(inference_func))
    print(f"{latency=} seconds.")

    image = pipe(**get_pipe_kwargs()).images[0]
    fileprefix = f"aot@{args.aot_regional}"
    image.save(f"{fileprefix}.png")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--aot_regional", action="store_true")
    parser.add_argument("--jit_regional", action="store_true")
    args = parser.parse_args()
    main(args)