Skip to content

Instantly share code, notes, and snippets.

@ShadowPower
Created November 13, 2022 13:26
Show Gist options
  • Save ShadowPower/1632b77626f863c860130ec4cddf20d5 to your computer and use it in GitHub Desktop.
Save ShadowPower/1632b77626f863c860130ec4cddf20d5 to your computer and use it in GitHub Desktop.
stable diffusion onnx exporter for bes-dev
# Use https://github.com/harishanand95/diffusers to export
from pathlib import Path
import torch
import torch.nn as nn
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler, DiffusionPipeline
from transformers import CLIPTextModel
class VQEncoder(nn.Module):
    """Expose the pipeline's VAE encode step as a standalone module.

    The module keeps a reference to ``stable_diffusion.vae`` and, when called,
    returns the ``parameters`` attribute of whatever ``vae.encode`` produces.
    """

    def __init__(self, stable_diffusion: DiffusionPipeline):
        super().__init__()
        # Only the VAE sub-model of the pipeline is needed here.
        self.vae = stable_diffusion.vae

    def forward(self, x):
        """Run ``x`` through ``vae.encode`` and return the result's ``parameters``."""
        # NOTE(review): presumably the posterior's raw distribution parameters —
        # confirm against the diffusers version in use.
        return self.vae.encode(x).parameters
class VQDecoder(nn.Module):
    """Expose the pipeline's VAE decode step as a standalone module.

    The incoming tensor is multiplied by a fixed constant before it is handed
    to ``vae.decode``.
    """

    def __init__(self, stable_diffusion: DiffusionPipeline):
        super().__init__()
        # Only the VAE sub-model of the pipeline is needed here.
        self.vae = stable_diffusion.vae

    def forward(self, h):
        """Scale ``h`` by the fixed factor and return ``vae.decode`` of the result."""
        # NOTE(review): the constant is approximately 1 / 0.18215, presumably the
        # inverse of the Stable Diffusion latent scaling factor — confirm.
        rescaled = h * 5.489980697631836
        return self.vae.decode(rescaled)
# CLIP text encoder loaded from the local v1-4 checkpoint directory.
text_encoder = CLIPTextModel.from_pretrained(
    "weights/stable-diffusion-v1-4/text_encoder",
    return_dict=False,
)

# LMS discrete scheduler using a scaled-linear beta schedule.
lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")

# Full pipeline loaded from the local checkpoint with the scheduler above;
# its `unet` and `vae` sub-models are what get exported below.
pipe = StableDiffusionPipeline.from_pretrained(
    "weights/stable-diffusion-v1-4",
    scheduler=lms,
    use_auth_token=True,
)
def _trace_and_export(module, example_inputs, check_inputs, onnx_path, input_names):
    """Trace ``module`` with TorchScript and export the traced graph to ONNX.

    Args:
        module: Module (or callable) to trace.
        example_inputs: Tuple of tensors used for tracing and for the export call.
        check_inputs: Tuple of tensors passed to ``torch.jit.trace`` as its
            single ``check_inputs`` entry.
        onnx_path: Destination file of the exported model.
        input_names: Names given to the ONNX graph inputs, in order.
    """
    traced_model = torch.jit.trace(module, example_inputs, check_inputs=[check_inputs], strict=True)
    torch.onnx.export(traced_model, example_inputs, str(onnx_path),
                      input_names=input_names, opset_version=16)


def convert_to_onnx(unet, post_quant_conv, decoder, text_encoder, height=512, width=512):
    """Export the UNet, the VAE decoder and the VAE encoder to ONNX under ``onnx/``.

    Files written (opset 16):
        * ``onnx/unet.onnx``
        * ``onnx/vae_decoder/vae_decoder.onnx``
        * ``onnx/vae_encoder/vae_encoder.onnx``

    The VAE wrappers are built from the module-level ``pipe`` object, not from
    the arguments of this function.

    Args:
        unet: UNet module to trace and export.
        post_quant_conv: Accepted for interface compatibility; not used.
        decoder: Accepted for interface compatibility; not used.
        text_encoder: Accepted for interface compatibility; not used.
        height: Image height in pixels; must be divisible by 8.
        width: Image width in pixels; must be divisible by 8.

    Raises:
        ValueError: If ``height`` or ``width`` is not divisible by 8.
    """
    # Validate first so that an invalid size does not leave directories behind.
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
    latent_h, latent_w = height // 8, width // 8

    out_dir = Path("onnx/")
    out_dir.mkdir(parents=True, exist_ok=True)

    # UNet export: dummy latents, a float64 scalar timestep and dummy text
    # embeddings. The same tuple is used for tracing and for the trace check.
    unet_inputs = (
        torch.rand(2, 4, latent_h, latent_w),
        torch.tensor(980, dtype=torch.float64),
        torch.rand(2, 77, 768),
    )
    _trace_and_export(
        unet, unet_inputs, unet_inputs, out_dir / "unet.onnx",
        ["latent_model_input", "t", "encoder_hidden_states"],
    )

    # The VAE wrappers need the whole pipeline, which is only available as the
    # module-level `pipe`.
    vae_encoder = VQEncoder(pipe)
    vae_decoder = VQDecoder(pipe)

    # VAE decoder export: uses the latent size derived from height/width
    # (previously this was hard-coded to 64x64, ignoring the arguments).
    decoder_dir = out_dir / "vae_decoder"
    decoder_dir.mkdir(parents=True, exist_ok=True)
    _trace_and_export(
        vae_decoder,
        (torch.rand(1, 4, latent_h, latent_w),),
        (torch.rand(1, 4, latent_h, latent_w),),
        decoder_dir / "vae_decoder.onnx",
        ["latents"],
    )

    # VAE encoder export: input is a 3-channel image at the full resolution
    # (latent size * 8, which equals height/width after the check above).
    encoder_dir = out_dir / "vae_encoder"
    encoder_dir.mkdir(parents=True, exist_ok=True)
    _trace_and_export(
        vae_encoder,
        (torch.rand(1, 3, height, width),),
        (torch.rand(1, 3, height, width),),
        encoder_dir / "vae_encoder.onnx",
        ["init_image"],
    )
# Run the export for the loaded pipeline at 512x512.
convert_to_onnx(
    unet=pipe.unet,
    post_quant_conv=pipe.vae.post_quant_conv,
    decoder=pipe.vae.decoder,
    text_encoder=text_encoder,
    height=512,
    width=512,
)
@ClashSAN
Copy link

@ShadowPower Could I have some advice? I tried converting SD 1.5 using harishanand95/diffusers, but it fails at line 36.

log: (Click to expand:)
Traceback (most recent call last):
File "/home/CS/Documents/test/onnx2ir/sd_onnx.py", line 36, in <module>
pipe = StableDiffusionPipeline.from_pretrained("weights/stable-diffusion-v1-5", scheduler=lms, use_auth_token=True)
File "/home/CS/Documents/test/onnx2ir/venv/lib/python3.9/site-packages/diffusers/pipeline_utils.py", line 240, in from_pretrained
load_method = getattr(class_obj, load_method_name)
TypeError: getattr(): attribute name must be string

Today, the UNet of current ONNX models is stored as weights.pb, which cannot be converted to IR with the openvino-dev toolkit, but it seems the VAE encoder and decoder can.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment