Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created March 24, 2023 12:14
Show Gist options
  • Save pashu123/e6bcdf9fc1d0daa0b635d55a589a773b to your computer and use it in GitHub Desktop.
Save pashu123/e6bcdf9fc1d0daa0b635d55a589a773b to your computer and use it in GitHub Desktop.
import collections
import os

import torch
from diffusers import StableDiffusionPipeline
import torch_mlir
from shark.shark_importer import import_with_fx
# Example positional inputs used to trace each Stable Diffusion sub-model.
# Each value is a tuple of positional arguments for that sub-model.
# Only the "unet" entry is consumed by this script.
# Random draws happen in this order: clip, vae, unet latents, unet embedding.
model_input = dict(
    # Token ids; randint with low=1, high=2 can only produce the value 1.
    clip=(torch.randint(low=1, high=2, size=(1, 77)),),
    vae=(torch.randn((1, 4, 128, 128)),),
    unet=(
        torch.randn((2, 4, 96, 96)),  # latents
        torch.tensor([1], dtype=torch.float32),  # timestep
        torch.randn((2, 77, 1024)),  # embedding
    ),
)
def compile_via_shark(model, inputs, device="vulkan", extra_args=None):
    """Import *model* through torch.fx and compile it with SHARK.

    Args:
        model: ``torch.nn.Module`` to compile. ``model.cpu()`` is called on it
            before import. ``Module.cpu()`` moves parameters and buffers in
            place, so the caller's module is left on the CPU afterwards.
        inputs: tuple of example tensors passed to ``import_with_fx`` for
            tracing.
        device: device name handed to ``SharkInference``. Defaults to
            ``"vulkan"``, the value this function previously hard-coded.
        extra_args: optional iterable of extra compiler flags forwarded to
            ``SharkInference.compile``. ``None`` (the default) means no extra
            flags. Backend-specific flags (for example the
            ``--iree-spirv-index-bits=64`` option tried earlier in this
            script) can be passed here.

    Returns:
        The compiled ``SharkInference`` module.
    """
    # Only element [0] of import_with_fx's result is used. Presumably this is
    # the imported MLIR module; confirm against shark.shark_importer.
    bytecode = import_with_fx(model.cpu(), inputs)

    # Imported lazily, after the (slow) FX import, as in the original script.
    from shark.shark_inference import SharkInference

    shark_module = SharkInference(
        mlir_module=bytecode[0],
        device=device,
        mlir_dialect="tm_tensor",
    )
    compile_flags = [] if extra_args is None else list(extra_args)
    shark_module.compile(extra_args=compile_flags)
    return shark_module
class UNetWrapper(torch.nn.Module):
    """Stand-in for a diffusers UNet that runs a compiled SHARK module.

    Input tensors are copied to host numpy arrays and passed to the SHARK
    module's ``"forward"`` function. The result is converted back into a torch
    tensor.

    ``in_channels``, ``device`` and ``config`` are plain placeholder
    attributes. The code installing this wrapper is expected to copy them
    from the UNet being replaced.
    """

    # Result type for forward(). It is a tuple subclass, so it supports both
    # ``out.sample`` (how callers read a UNet's return_dict=True output) and
    # ``out[0]`` (how they read the return_dict=False tuple). A bare tensor
    # supports neither: ``.sample`` raises AttributeError and ``[0]`` would
    # slice off the batch dimension.
    Output = collections.namedtuple("UNetWrapperOutput", ["sample"])

    def __init__(self, shark_unet):
        super().__init__()
        self.wrapped_unet = shark_unet
        self.in_channels = None
        self.device = None
        self.config = None

    def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
        """Run the compiled UNet on host copies of the inputs.

        Args:
            sample: latent tensor.
            timestep: tensor or Python number. It is cast to float and
                flattened to 1-D, so a scalar timestep becomes shape ``(1,)``.
            encoder_hidden_states: text-embedding tensor.
            **kwargs: accepted for call compatibility (e.g. ``return_dict``,
                ``cross_attention_kwargs``) and ignored. The SHARK module is
                always called with exactly these three arrays.

        Returns:
            ``UNetWrapper.Output`` whose ``sample`` field (also ``[0]``) holds
            the SHARK result. The result is placed on the same device and
            dtype as the input *sample*.
        """
        sample_np = sample.detach().cpu().numpy()
        timestep_np = (
            torch.as_tensor(timestep).float().detach().cpu().reshape(-1).numpy()
        )
        encoder_hidden_states_np = encoder_hidden_states.detach().cpu().numpy()
        inputs = [sample_np, timestep_np, encoder_hidden_states_np]
        result_np = self.wrapped_unet("forward", inputs)
        # Follow the caller's latents rather than hard-coding CUDA, so the
        # result lands wherever the rest of the pipeline runs.
        noise_pred = torch.from_numpy(result_np).to(
            device=sample.device, dtype=sample.dtype
        )
        return self.Output(sample=noise_pred)
class UnetCustom(torch.nn.Module):
    """Adapter giving a diffusers-style UNet a plain positional forward.

    ``forward(latent, timestep, text_embedding)`` calls the wrapped UNet with
    ``return_dict=False`` and returns only element 0 of its result. This
    adapter is what the script hands to ``compile_via_shark`` for FX import.
    """

    def __init__(self, pipe_unet):
        super().__init__()
        self.unet = pipe_unet
        # Placeholders; the script copies these over from the wrapped UNet.
        self.in_channels = self.device = self.config = None

    def forward(self, latent, timestep, text_embedding):
        """Invoke the wrapped UNet's forward and return its first output."""
        outputs = self.unet.forward(
            latent, timestep, text_embedding, return_dict=False
        )
        noise_pred = outputs[0]
        return noise_pred
def _install_unet(pipeline, replacement):
    """Replace ``pipeline.unet`` with *replacement*.

    The current UNet's ``in_channels``, ``device`` and ``config`` are copied
    onto *replacement* first. Presumably the pipeline reads these attributes
    from its UNet; verify against the diffusers version in use.
    """
    current = pipeline.unet
    replacement.in_channels = current.in_channels
    replacement.device = current.device
    replacement.config = current.config
    del pipeline.unet
    pipeline.unet = replacement


pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float32
)
pipe = pipe.to("cuda")

# Step 1: wrap the UNet so it exposes a plain positional forward for FX import.
unet_graph = UnetCustom(pipe.unet)
_install_unet(pipe, unet_graph)

# Step 2: compile with SHARK, tracing on the example inputs in
# model_input["unet"]. NOTE(review): the compiled module is presumably
# specialized to those shapes (batch 2 with guidance, 96x96 latents, i.e. the
# default 768x768 output). Other sizes or batch counts would likely fail; confirm.
shark_unet = compile_via_shark(pipe.unet, model_input["unet"])

# Step 3: route the pipeline's UNet calls through the compiled module.
unet_graph = UNetWrapper(shark_unet)
_install_unet(pipe, unet_graph)

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment