Run stable diffusion model on PyTorch XPU
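# Usage sketch (illustrative; every flag below is defined in the argparse
# section of this script, and the file name `main.py` is assumed from the
# logger name used here):
"""Benchmark Stable Diffusion text-to-image / image-to-image inference on xpu, cuda, or cpu.

    python main.py --device xpu --precision fp16 \
        -m stabilityai/stable-diffusion-2-1 --num_inference_steps 50

    python main.py --device xpu --inductor --evaluate_method clip \
        -m stabilityai/stable-diffusion-xl-base-1.0 --height 1024 --width 1024
"""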
import argparse
import contextlib
import csv
import logging
import os
import sys
import time
from datetime import datetime
from functools import partial

import numpy as np
import pytorch_fid  # noqa: F401  -- imported but not called directly in this script
import requests
import torch
from datasets import load_dataset
from diffusers import (
    DPMSolverMultistepScheduler,
    StableDiffusion3Pipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
)
from PIL import Image
from torchmetrics.functional.multimodal import clip_score

LOG_LEVEL = os.environ.get('LOG_LEVEL', 'INFO').upper()
logger = logging.getLogger("stable_diffusion: main.py")
logging.basicConfig(level=LOG_LEVEL)
PIPELINE_MAPPING = {
    "CompVis/stable-diffusion-v1-4": StableDiffusionPipeline,
    "runwayml/stable-diffusion-v1-5": StableDiffusionPipeline,
    "stabilityai/stable-diffusion-2-1": StableDiffusionPipeline,
    "SimianLuo/LCM_Dreamshaper_v7": StableDiffusionPipeline,
    "stabilityai/stable-diffusion-xl-base-1.0": StableDiffusionXLPipeline,
    "stabilityai/stable-diffusion-3-medium-diffusers": StableDiffusion3Pipeline,
}
parser = argparse.ArgumentParser(description='PyTorch StableDiffusion Text-to-Image')
parser.add_argument('--prompt', default="nateraw/parti-prompts", type=str, help='prompt dataset')
parser.add_argument('--batch_size', default=1, type=int, help='batch size')
parser.add_argument('--idx_start', default=0, type=int, help='start index of generated images')
parser.add_argument('--precision', choices=["fp32", "fp16", "bf16"],
                    default="fp16", type=str, help='precision')
parser.add_argument('--jit', action='store_true', default=False, help='enable JIT')
parser.add_argument('--iteration', default=30, type=int, help='test iterations')
parser.add_argument('--warmup_iter', default=2, type=int, help='warmup iterations')
parser.add_argument('--device', default='xpu', type=str, help='cpu, cuda or xpu')
parser.add_argument('--save_image', action='store_true', default=False, help='save generated images')
parser.add_argument('--save_tensor', action='store_true', default=False, help='save generated tensors')
parser.add_argument('--accuracy', action='store_true', default=False, help='compare the result against a reference (e.g. cuda)')
parser.add_argument('-m', '--model_id',
                    choices=["CompVis/stable-diffusion-v1-4", "runwayml/stable-diffusion-v1-5",
                             "stabilityai/stable-diffusion-2-1", "SimianLuo/LCM_Dreamshaper_v7",
                             "stabilityai/stable-diffusion-xl-base-1.0",
                             "stabilityai/stable-diffusion-3-medium-diffusers"],
                    default='stabilityai/stable-diffusion-2-1', type=str, metavar='PATH',
                    help='model id or path to model structure/weights')
parser.add_argument('--ref_path', default='', type=str, metavar='PATH',
                    help='path to reference image (default: none)')
parser.add_argument('--save_path', default='./xpu_result', type=str, help='output image dir')
parser.add_argument('--num_inference_steps', default=50, type=int, help='number of denoising steps')
parser.add_argument("--disable_optimize_transformers", action="store_true")
parser.add_argument('--evaluate_method', choices=["clip", "fid"],
                    default="fid", type=str, help='evaluation method; clip and fid are supported')
parser.add_argument('--pipeline_mode', choices=["img2img", "text2img"],
                    default="text2img", type=str, help='pipeline mode: img2img or text2img')
parser.add_argument('--channels_last', action='store_true', default=False, help='enable channels_last')
parser.add_argument('--height', type=int, default=768, help='height of generated image')
parser.add_argument('--width', type=int, default=768, help='width of generated image')
parser.add_argument("--gscale", type=float, default=7.5, help="the guidance scale")
parser.add_argument("--ipex", action="store_true")
parser.add_argument("--profile", action="store_true")
parser.add_argument("--inductor", action="store_true")
parser.add_argument('--output-csv-path', default='output.csv', type=str,
                    help='path to output CSV file (default: output.csv)')
parser.add_argument('--use_pt2e', action='store_true', help='enable pt2e-quantized model in the pipeline')
parser.add_argument('--path_quantized_model', type=str, help='quantized model path for sd3/sdxl')
parser.add_argument('--path_tokenizer_3', default='./sd3_tokenizer_3', type=str,
                    help='tokenizer_3 path for sd3; if the path does not exist yet, the tokenizer is saved there')
args = parser.parse_args()
print(args)
def _diff_report(cur, ref, log_fn):
    # Shared helper: report the max difference and the fraction of elements
    # exceeding each threshold, then a pass/fail verdict.
    diff_value = torch.abs(cur - ref)
    log_fn("max difference = {}".format(torch.max(diff_value)))
    numel = cur.numel()
    ratios = []
    for threshold in (0.1, 0.01, 0.001):
        num = torch.sum((diff_value > threshold).contiguous().view(-1))
        ratio = num / numel
        ratios.append(ratio)
        log_fn("difference larger than {}, ratio = {}".format(threshold, ratio))
    if ratios[0] < 0.01 and ratios[1] < 0.08 and ratios[2] < 0.4:
        log_fn("accuracy pass")
    else:
        log_fn("accuracy fail")

def compare(xpu_res, ref_res):
    _diff_report(torch.tensor(xpu_res), torch.tensor(ref_res), logger.warning)

def compare_pil_images(ref_res, cur_res):
    # Cast to float so uint8 pixel subtraction cannot wrap around.
    xpu = torch.tensor(np.array(cur_res), dtype=torch.float)
    ref = torch.tensor(np.array(ref_res), dtype=torch.float)
    _diff_report(xpu, ref, print)
def main():
    def calculate_clip_score(images, prompts):
        # The "np" pipeline output is NHWC float in [0, 1]; clip_score_fn
        # expects NCHW uint8 tensors.
        images_int = (images * 255).astype("uint8")
        score = clip_score_fn(torch.from_numpy(images_int).permute(0, 3, 1, 2), prompts).detach()
        return round(float(score), 4)

    profiling = os.environ.get("PROFILE", "OFF").upper() in ["1", "Y", "ON", "YES", "TRUE"] or args.profile
# prompt = ["A painting of a squirrel eating a burger"] | |
prompts_dataset = load_dataset(args.prompt, split="train") | |
seed = 2024 | |
prompts_dataset = prompts_dataset.shuffle(seed=seed) | |
is_arc = False | |
if args.device == "xpu": | |
if args.ipex: | |
import intel_extension_for_pytorch as ipex | |
idx = torch.xpu.current_device() | |
is_arc = torch.xpu.get_device_name(idx)=='Intel(R) Arc(TM) Graphics' | |
generator = torch.Generator(device=args.device).manual_seed(seed) | |
elif args.device == "cuda": | |
generator = torch.Generator(device=args.device).manual_seed(seed) | |
else: | |
generator = torch.Generator(device=args.device) | |
    amp_enabled = False
    if args.precision == "fp32":
        datatype = torch.float
    elif args.precision == "fp16":
        datatype = torch.float16
    elif args.precision == "bf16":
        datatype = torch.bfloat16
    else:
        logger.error("unsupported datatype")
        sys.exit()
    if args.pipeline_mode == "img2img":
        if args.model_id.startswith("stabilityai/stable-diffusion-3"):
            raise NotImplementedError(
                f"Image-to-image pipeline is not yet supported for {args.model_id}."
            )
        else:
            pipe = StableDiffusionImg2ImgPipeline.from_pretrained(args.model_id, torch_dtype=datatype)
    else:
        pipe = PIPELINE_MAPPING[args.model_id].from_pretrained(args.model_id, torch_dtype=datatype)
    if args.model_id == "stabilityai/stable-diffusion-2-1":
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    if args.model_id == "stabilityai/stable-diffusion-3-medium-diffusers" and args.use_pt2e:
        # save tokenizer_3, then reload the pipeline without text_encoder_3
        if not os.path.exists(args.path_tokenizer_3):
            print("save tokenizer 3 to ", args.path_tokenizer_3)
            pipe.tokenizer_3.save_pretrained(args.path_tokenizer_3)
        pipe = StableDiffusion3Pipeline.from_pretrained(args.model_id, text_encoder_3=None, tokenizer_3=None, torch_dtype=datatype)
    pipe = pipe.to(args.device)
if args.model_id == "stabilityai/stable-diffusion-3-medium-diffusers" and args.use_pt2e: | |
import torch.ao.quantization.quantizer.xpu_inductor_quantizer as xpuiq | |
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq | |
from torch.ao.quantization.quantizer.xpu_inductor_quantizer import XPUInductorQuantizer | |
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer | |
from torch.export import export_for_training | |
from torch.ao.quantization.quantize_pt2e import ( | |
_convert_to_reference_decomposed_fx, | |
convert_pt2e, | |
prepare_pt2e, | |
prepare_qat_pt2e, | |
) | |
torch._inductor.config.freezing = True | |
from transformers import T5TokenizerFast | |
print("load quantized text encoder 3 in sd3") | |
loaded_quantized_ep_te3 = torch.export.load(args.path_quantized_model) | |
pipe.text_encoder_3 = loaded_quantized_ep_te3.module() | |
pipe.text_encoder_3 = torch.compile(pipe.text_encoder_3) | |
pipe.text_encoder_3.dtype = datatype | |
pipe.tokenizer_3 = T5TokenizerFast.from_pretrained(args.path_tokenizer_3) | |
print("load quantized model done") | |
if args.device == "xpu" and args.ipex: | |
# Use ipex to optimize | |
if hasattr(pipe, "transformer"): | |
pipe.transformer = torch.xpu.optimize(pipe.transformer.eval(), dtype=datatype, inplace=True) | |
elif hasattr(pipe, "unet"): | |
pipe.unet = torch.xpu.optimize(pipe.unet.eval(), dtype=datatype, inplace=True) | |
pipe.vae = torch.xpu.optimize(pipe.vae.eval(), dtype=datatype, inplace=True) | |
pipe.text_encoder = torch.xpu.optimize(pipe.text_encoder.eval(), dtype=datatype, inplace=True) | |
if hasattr(pipe, "text_encoder_2"): | |
pipe.text_encoder_2 = torch.xpu.optimize(pipe.text_encoder_2.eval(), dtype=datatype, inplace=True) | |
if not args.disable_optimize_transformers and args.precision == "fp16": | |
# optimize with ipex | |
if hasattr(pipe, "transformer"): | |
pipe.transformer = ipex.optimize_transformers(pipe.transformer.eval(), dtype=datatype, device=args.device, inplace=True) | |
elif hasattr(pipe, "unet"): | |
pipe.unet = ipex.optimize_transformers(pipe.unet.eval(), dtype=datatype, device=args.device, inplace=True) | |
print("---- Use ipex optimize_transformers fp16 model.") | |
else: | |
# optimize with ipex | |
print("---- Use ipex optimize model.") | |
    # channels_last is also enabled automatically on Arc graphics (is_arc).
    if args.channels_last or is_arc:
        if args.model_id == "stabilityai/stable-diffusion-3-medium-diffusers":
            pipe.transformer = pipe.transformer.to(memory_format=torch.channels_last)
        else:
            pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
        pipe.vae = pipe.vae.to(memory_format=torch.channels_last)
        pipe.text_encoder = pipe.text_encoder.to(memory_format=torch.channels_last)
        # Not every pipeline has a second text encoder (e.g. SD 1.x / 2.x).
        if hasattr(pipe, "text_encoder_2"):
            pipe.text_encoder_2 = pipe.text_encoder_2.to(memory_format=torch.channels_last)
    if args.evaluate_method == "clip":
        clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16")
        out_type = "np"
        clip_score_list = []
    else:
        out_type = "pil"
    if args.accuracy or args.save_tensor:
        out_type = "tensor"
    if args.pipeline_mode == "img2img":
        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        init_image = Image.open(requests.get(url, stream=True).raw)
        prompt = "two tigers"
    total_time = 0
    logger.debug("output type is: {}".format(out_type))
    # Stable Diffusion 3 denoises images with a Transformer backbone rather than a U-Net.
    if args.inductor:
        torch._inductor.config.freezing = True
        torch._inductor.config.force_disable_caches = True
        if args.model_id == "stabilityai/stable-diffusion-3-medium-diffusers":
            pipe.transformer = torch.compile(pipe.transformer)
            pipe.vae.decode = torch.compile(pipe.vae.decode)
        else:
            pipe.unet = torch.compile(pipe.unet)
            pipe.text_encoder = torch.compile(pipe.text_encoder)
            if hasattr(pipe, "text_encoder_2"):
                pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2)
            pipe.vae.decode = torch.compile(pipe.vae.decode)
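    # Note: torch.compile is lazy, so the first call pays the compilation cost;
    # the --warmup_iter iterations below run before any timing starts.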
    def run_xpu_gen():
        # `input` is the current prompt, set by the enclosing loops below.
        if args.pipeline_mode == "img2img":
            images = pipe(prompt=prompt, image=init_image, generator=generator, num_inference_steps=args.num_inference_steps, guidance_scale=args.gscale, height=args.height, width=args.width, output_type=out_type).images
        else:
            images = pipe(input, generator=generator, num_inference_steps=args.num_inference_steps, guidance_scale=args.gscale, height=args.height, width=args.width, output_type=out_type).images
        return images

    def write_to_csv(output_data, csv_file_path):
        file_exists = os.path.isfile(csv_file_path)
        with open(csv_file_path, mode='a', newline='') as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=output_data)
            if not file_exists:
                writer.writeheader()
            writer.writerow(output_data)
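    # write_to_csv appends one row per run; the header, written only when the
    # file is new, is: model_name,latency,throughput,data_type,woq,amp,inductor,ipex,run_date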
    with torch.no_grad():
        for step in range(args.warmup_iter):
            idx1 = args.idx_start + int(step * args.batch_size)
            idx2 = args.idx_start + int((step + 1) * args.batch_size)
            input = prompts_dataset[step]["Prompt"]
            print(f"input is : {prompt if args.pipeline_mode == 'img2img' else input}")
            if args.device == "xpu":
                if args.ipex:
                    with torch.xpu.amp.autocast(enabled=amp_enabled, dtype=datatype):
                        images = run_xpu_gen()
                else:
                    images = run_xpu_gen()
                torch.xpu.synchronize()
            elif args.device == "cuda":
                images = pipe(input, generator=generator, num_inference_steps=args.num_inference_steps, guidance_scale=args.gscale, height=args.height, width=args.width, output_type=out_type).images
                torch.cuda.synchronize()
            else:
                images = pipe(input, generator=generator, num_inference_steps=args.num_inference_steps, guidance_scale=args.gscale, height=args.height, width=args.width, output_type=out_type).images
        for iter in range(args.iteration):
            print("Iteration = {}".format(iter))
            step = 0
            idx1 = args.idx_start + int(step * args.batch_size)
            idx2 = args.idx_start + int((step + 1) * args.batch_size)
            logger.debug("idx1={}".format(idx1))
            logger.debug("idx2={}".format(idx2))
            input = prompts_dataset[iter]["Prompt"]
            logger.debug("input is : {}".format(str(prompt) if args.pipeline_mode == "img2img" else str(input)))
            if args.device == "xpu":
                with (
                    contextlib.nullcontext(None) if not profiling else
                    torch.profiler.profile(
                        activities=[torch.profiler.ProfilerActivity.CPU,
                                    torch.profiler.ProfilerActivity.XPU],
                        record_shapes=True,
                    )
                ) as prof:
                    try:
                        # Optional helper module; skip memory reporting if unavailable.
                        import memory_check
                        memory_check.display_mem("xpu:0")
                    except Exception:
                        pass
                    start_time = time.time()
                    if args.ipex:
                        with torch.xpu.amp.autocast(enabled=amp_enabled, dtype=datatype):
                            images = run_xpu_gen()
                    else:
                        images = run_xpu_gen()
                    torch.xpu.synchronize()
                    end_time = time.time()
                if profiling:
                    torch.save(prof.key_averages().table(sort_by="self_xpu_time_total"), 'stable_diffusion_inf_profile.pt')
                    # Cannot sort by id when using kineto
                    # torch.save(prof.table(sort_by="id", row_limit=-1), 'stable_diffusion_inf_profile_detailed.pt')
                    prof.export_chrome_trace('./stable_diffusion_inf_profile_trace.json')
            elif args.device == "cuda":
                start_time = time.time()
                images = pipe(input, generator=generator, num_inference_steps=args.num_inference_steps, guidance_scale=args.gscale, height=args.height, width=args.width, output_type=out_type).images
                torch.cuda.synchronize()
                end_time = time.time()
            else:
                start_time = time.time()
                images = pipe(input, generator=generator, num_inference_steps=args.num_inference_steps, guidance_scale=args.gscale, height=args.height, width=args.width, output_type=out_type).images
                end_time = time.time()
            iter_time = end_time - start_time
            total_time += iter_time
            if args.evaluate_method == "clip":
                sd_clip_score = calculate_clip_score(images, input)
                clip_score_list.append(sd_clip_score)
            else:
                if args.accuracy:
                    for i in range(args.batch_size):
                        name = "result_{}_{}.png".format(idx1 + i, iter) if args.save_image else "result_{}_{}.pt".format(idx1 + i, iter)
                        name = os.path.join(args.ref_path, name)
                        if args.save_image:
                            ref_image = Image.open(name)
                            compare_pil_images(ref_image, images[i])
                        else:
                            ref_pt = torch.load(name)
                            compare(images[i], ref_pt)
                os.makedirs(args.save_path, exist_ok=True)
                if args.save_tensor:
                    for i in range(args.batch_size):
                        file_name = "result_{}_{}.pt".format(idx1 + i, iter)
                        save_path = os.path.join(args.save_path, file_name)
                        torch.save(images[i], save_path)
                if args.save_image:
                    for i in range(args.batch_size):
                        file_name = "result_{}_{}.png".format(idx1 + i, iter)
                        save_path = os.path.join(args.save_path, file_name)
                        images[i].save(save_path)
    total_sample = args.iteration * args.batch_size
    latency = total_time / total_sample * 1000
    throughput = total_sample / total_time
    print("inference Latency: {} ms".format(latency))
    print("inference Throughput: {} samples/s".format(throughput))
    if args.evaluate_method == "clip":
        print(f"CLIP score: {np.mean(clip_score_list)}")
    run_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    output_data = {
        "model_name": args.model_id,
        "latency": latency,
        "throughput": throughput,
        "data_type": args.precision,
        "woq": args.disable_optimize_transformers,
        "amp": amp_enabled,
        "inductor": args.inductor,
        "ipex": args.ipex,
        "run_date": run_date
    }
    write_to_csv(output_data, args.output_csv_path)

if __name__ == '__main__':
    main()