Created
November 13, 2022 13:26
-
-
Save ShadowPower/1632b77626f863c860130ec4cddf20d5 to your computer and use it in GitHub Desktop.
Stable Diffusion ONNX exporter for bes-dev
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Use https://github.com/harishanand95/diffusers to export | |
from pathlib import Path | |
import torch | |
import torch.nn as nn | |
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler, DiffusionPipeline | |
from transformers import CLIPTextModel | |
class VQEncoder(nn.Module):
    """Expose only the encode step of the pipeline's VAE as a module.

    Wrapping the call in ``forward`` gives ``torch.jit.trace`` a single
    callable that takes one tensor in.
    """

    def __init__(self, stable_diffusion: DiffusionPipeline):
        """Keep a reference to the ``vae`` attribute of *stable_diffusion*."""
        super().__init__()
        self.vae = stable_diffusion.vae

    def forward(self, x):
        """Encode *x* with the VAE and return the ``parameters`` of the result."""
        # NOTE(review): relies on the object returned by ``vae.encode`` exposing
        # a ``parameters`` attribute -- confirm against the diffusers build in use.
        return self.vae.encode(x).parameters
class VQDecoder(nn.Module):
    """Expose only the decode step of the pipeline's VAE as a module.

    The incoming latents are multiplied by a fixed factor before decoding.
    """

    def __init__(self, stable_diffusion: DiffusionPipeline):
        """Keep a reference to the ``vae`` attribute of *stable_diffusion*."""
        super().__init__()
        self.vae = stable_diffusion.vae

    def forward(self, h):
        """Rescale the latents *h* and pass them through ``self.vae.decode``."""
        # NOTE(review): the factor is numerically ~1 / 0.18215; presumably it
        # undoes the Stable Diffusion latent scaling -- confirm.
        rescaled = h * 5.489980697631836
        return self.vae.decode(rescaled)
# Text encoder, loaded separately from the pipeline with return_dict=False.
text_encoder = CLIPTextModel.from_pretrained(
    "weights/stable-diffusion-v1-4/text_encoder",
    return_dict=False,
)

# Scheduler instance handed to the pipeline below.
lms = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
)

# Full pipeline loaded from the local weights directory; its ``unet`` and
# ``vae`` are the parts that get exported further down.
pipe = StableDiffusionPipeline.from_pretrained(
    "weights/stable-diffusion-v1-4",
    scheduler=lms,
    use_auth_token=True,
)
def convert_to_onnx(unet, post_quant_conv, decoder, text_encoder, height=512, width=512):
    """Trace the UNet and the VAE decoder/encoder and export each to ONNX.

    Files written (all with opset 16):

    * ``onnx/unet.onnx``
    * ``onnx/vae_decoder/vae_decoder.onnx``
    * ``onnx/vae_encoder/vae_encoder.onnx``

    Args:
        unet: Module traced with a (latents, timestep, text-embedding) example.
        post_quant_conv: Accepted for call compatibility; not used here.
        decoder: Accepted for call compatibility; not used here.
        text_encoder: Accepted for call compatibility; not used here (this
            function does not export a text encoder).
        height: Image height in pixels; must be divisible by 8.
        width: Image width in pixels; must be divisible by 8.

    Raises:
        ValueError: If ``height`` or ``width`` is not divisible by 8.

    Note:
        The VAE wrappers are built from the module-level ``pipe`` object, not
        from the arguments of this function.
    """
    # Validate before touching the filesystem so a bad call leaves no output
    # directories behind.
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    out_dir = Path("onnx/")
    out_dir.mkdir(parents=True, exist_ok=True)

    # Latent resolution is 1/8 of the image resolution. The same values are
    # used for the UNet and both VAE exports so that all three models agree
    # for any requested size (previously the VAE part was pinned to 64x64).
    latent_h, latent_w = height // 8, width // 8

    # --- UNet ---------------------------------------------------------------
    unet_example = (
        torch.rand(2, 4, latent_h, latent_w),
        torch.tensor(980, dtype=torch.float64),
        torch.rand(2, 77, 768),
    )
    traced_unet = torch.jit.trace(unet, unet_example, check_inputs=[unet_example], strict=True)
    torch.onnx.export(
        traced_unet,
        unet_example,
        "onnx/unet.onnx",
        input_names=["latent_model_input", "t", "encoder_hidden_states"],
        opset_version=16,
    )

    # --- VAE decoder --------------------------------------------------------
    vae_decoder = VQDecoder(pipe)
    (out_dir / "vae_decoder").mkdir(parents=True, exist_ok=True)
    # One random tensor drives the trace, a second independent one checks it.
    decoder_example = (torch.rand(1, 4, latent_h, latent_w),)
    decoder_check = (torch.rand(1, 4, latent_h, latent_w),)
    traced_decoder = torch.jit.trace(vae_decoder, decoder_example, check_inputs=[decoder_check])
    torch.onnx.export(
        traced_decoder,
        decoder_example,
        "onnx/vae_decoder/vae_decoder.onnx",
        input_names=["latents"],
        opset_version=16,
    )

    # --- VAE encoder --------------------------------------------------------
    vae_encoder = VQEncoder(pipe)
    (out_dir / "vae_encoder").mkdir(parents=True, exist_ok=True)
    # The encoder takes a 3-channel tensor at full image resolution
    # (latent size * 8, which equals height x width after the check above).
    encoder_example = (torch.rand(1, 3, latent_h * 8, latent_w * 8),)
    encoder_check = (torch.rand(1, 3, latent_h * 8, latent_w * 8),)
    traced_encoder = torch.jit.trace(vae_encoder, encoder_example, check_inputs=[encoder_check])
    torch.onnx.export(
        traced_encoder,
        encoder_example,
        "onnx/vae_encoder/vae_encoder.onnx",
        input_names=["init_image"],
        opset_version=16,
    )
# Run the export for the loaded pipeline at 512x512.
convert_to_onnx(
    pipe.unet,
    pipe.vae.post_quant_conv,
    pipe.vae.decoder,
    text_encoder,
    height=512,
    width=512,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@ShadowPower could I have some advice? I tried converting SD 1.5 using harishanand95/diffusers, but it fails at line 36.
log: (Click to expand:)
Currently, the UNet in the available ONNX models is stored as weights.pb, and it cannot be converted to IR with the openvino-dev toolkit, but it seems the VAE encoder and decoder can.