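"""Export the Stable Diffusion 2.1 UNet through SHARK's FX importer.

Loads stabilityai/stable-diffusion-2-1-base in fp16, wraps its UNet so it can
be traced with positional arguments, imports it to Torch-MLIR bytecode via
import_with_fx, and dumps the result to unet.mlir (exiting early; the SHARK
Vulkan compile path further down is left in for reference).
"""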
import os
import sys

import torch
import torch_mlir
from torch.backends.cuda import sdp_kernel, SDPBackend
from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
from shark.shark_importer import import_with_fx

# Opt in to the flash-attention SDPA kernel.
torch.backends.cuda.enable_flash_sdp(True)
# Helpful arguments mapper for torch.backends.cuda.sdp_kernel.
backend_map = {
    SDPBackend.MATH: {
        "enable_math": True, "enable_flash": False, "enable_mem_efficient": False
    },
    SDPBackend.FLASH_ATTENTION: {
        "enable_math": False, "enable_flash": True, "enable_mem_efficient": False
    },
    SDPBackend.EFFICIENT_ATTENTION: {
        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True
    },
}
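# Usage sketch, assuming PyTorch 2.x, where sdp_kernel is a context manager
# that restricts scaled_dot_product_attention to a single backend:
#
#   with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
#       out = torch.nn.functional.scaled_dot_product_attention(q, k, v)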
# Dummy inputs used to trace/import each submodule (fp16 on CUDA).
model_input = {
    "clip": (torch.randint(1, 2, (1, 77)).cuda().half(),),
    "vae": (torch.randn(1, 4, 128, 128).cuda().half(),),
    "unet": (
        torch.randn(2, 4, 96, 96).cuda().half(),  # latents
        torch.tensor([1]).float().cuda(),  # timestep
        torch.randn(2, 77, 1024).cuda().half(),  # text embedding
    ),
}
def compile_via_shark(model, inputs):
    is_f16 = True
    input_mask = [True, True, True]
    # with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    bytecode = import_with_fx(model, inputs)
    # Dump the imported Torch-MLIR bytecode for inspection.
    with open("unet.mlir", "wb") as mlir_file:
        mlir_file.write(bytecode[0])
    # Early exit: the script stops after dumping the MLIR. Remove this line to
    # compile and run the module through SHARK's Vulkan backend below.
    sys.exit()
    from shark.shark_inference import SharkInference
    shark_module = SharkInference(
        mlir_module=bytecode[0], device="vulkan", mlir_dialect="tm_tensor",
    )
    # extra_args = ['--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))', '--iree-spirv-index-bits=64']
    shark_module.compile(extra_args=[])
    return shark_module
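# Usage sketch for the compiled module, assuming SHARK's convention of calling
# the module with the exported function name and numpy inputs (mirroring the
# commented-out numpy conversions in UNetWrapper.forward below):
#
#   np_inputs = [x.detach().cpu().numpy() for x in model_input["unet"]]
#   unet_out = shark_module("forward", np_inputs)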
class UNetWrapper(torch.nn.Module):
    """Drop-in replacement for pipe.unet that calls the SHARK-compiled module."""

    def __init__(self, shark_unet):
        super().__init__()
        self.wrapped_unet = shark_unet
        self.in_channels = None
        self.device = None
        self.config = None

    def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
        # sample_np = sample.detach().cpu().numpy()
        # timestep_np = timestep.half().detach().cpu().reshape(-1).numpy()
        # encoder_hidden_states_np = encoder_hidden_states.detach().cpu().numpy()
        # inputs = [sample_np, timestep_np, encoder_hidden_states_np]
        sample = self.wrapped_unet(sample, timestep, encoder_hidden_states)
        # rest of the pipeline is always in float16
        return sample
class UnetCustom(torch.nn.Module):
    """Wraps the diffusers UNet with a positional-args-only forward for FX tracing."""

    def __init__(self, pipe_unet):
        super().__init__()
        self.unet = pipe_unet
        self.in_channels = None
        self.device = None
        self.config = None

    def forward(self, latent, timestep, text_embedding):
        unet_out = self.unet.forward(
            latent,
            timestep,
            text_embedding,
            return_dict=False,
        )[0]
        return unet_out
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
)
pipe.unet.set_attn_processor(AttnProcessor2_0())
pipe = pipe.to("cuda")

# Swap in the tracing-friendly wrapper, preserving the attributes the
# pipeline reads off its unet.
unet_graph = UnetCustom(pipe.unet)
unet_graph.in_channels = pipe.unet.in_channels
unet_graph.device = pipe.unet.device
unet_graph.config = pipe.unet.config
del pipe.unet
pipe.unet = unet_graph

# Import (and, with the early exit removed, compile) the UNet through SHARK,
# then swap the compiled module back into the pipeline.
shark_unet = compile_via_shark(pipe.unet, model_input["unet"])
# shark_unet = shark_unet.cuda()
unet_graph = UNetWrapper(shark_unet)
unet_graph.in_channels = pipe.unet.in_channels
unet_graph.device = pipe.unet.device
unet_graph.config = pipe.unet.config
del pipe.unet
pipe.unet = unet_graph
# prompt = "a photo of an astronaut riding a horse on mars"
# image = pipe(prompt).images[0]
# image.save(f"astronaut_rides_horse.png")