import argparse

import torch
import torch.utils.benchmark as benchmark
import torch_tensorrt  # imported for its side effect: registers the "torch_tensorrt" backend with torch.compile

from diffusers import DiffusionPipeline

CKPT = "stabilityai/stable-diffusion-xl-base-1.0"
PROMPT = "a majestic castle in the clouds"

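
# Load the SDXL pipeline in fp16 and optionally compile its UNet, either with
# the default inductor backend or with the TensorRT backend from torch-tensorrt.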
def load_pipeline(run_compile=False, with_tensorrt=False):
    pipe = DiffusionPipeline.from_pretrained(
        CKPT, torch_dtype=torch.float16, use_safetensors=True
    )
    pipe = pipe.to("cuda")

    if run_compile and not with_tensorrt:
        print("Run torch compile")
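        # "reduce-overhead" captures the UNet in CUDA graphs to cut kernel launch
        # overhead; fullgraph=True fails loudly if the UNet cannot be traced as a
        # single graph instead of silently falling back to eager execution.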
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    elif run_compile and with_tensorrt:
        print("Run torch compile with TensorRT backend")
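        # truncate_long_and_double casts 64-bit ints/floats down to 32 bits, since
        # TensorRT does not support them; dynamic=False specializes the compiled
        # graph to the input shapes seen on the first call.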
        pipe.unet = torch.compile(
            pipe.unet,
            backend="torch_tensorrt",
            options={
                "truncate_long_and_double": True,
                "precision": torch.float16,
            },
            dynamic=False,
        )

    pipe.set_progress_bar_config(disable=True)
    return pipe

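
# Generate `batch_size` images from the fixed prompt, discarding the output;
# only the latency is of interest here.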
def run_inference(pipe, batch_size=1):
    _ = pipe(PROMPT, num_images_per_prompt=batch_size)

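
# Time f(*args, **kwargs) with torch.utils.benchmark. blocked_autorange() runs the
# statement in a loop and returns timing statistics; we report the mean in seconds.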
def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean  # mean latency per run, in seconds


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--run_compile", action="store_true")
    parser.add_argument("--with_tensorrt", action="store_true")
    args = parser.parse_args()

    pipeline = load_pipeline(
        run_compile=args.run_compile, with_tensorrt=args.with_tensorrt
    )
    latency = benchmark_fn(run_inference, pipeline, args.batch_size)
    print(
        f"With compilation: {args.run_compile}, TensorRT: {args.with_tensorrt}, "
        f"batch size: {args.batch_size}: {latency:.3f} seconds"
    )
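
# Example invocations (assuming the script is saved as benchmark_sdxl.py):
#   python benchmark_sdxl.py                                 # eager fp16 baseline
#   python benchmark_sdxl.py --run_compile                   # torch.compile (inductor)
#   python benchmark_sdxl.py --run_compile --with_tensorrt   # torch.compile (TensorRT)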