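# Compile the Stable Diffusion 2.1 UNet through torch-mlir / SHARK for Vulkan,
# then patch the compiled module back into the diffusers pipeline.
# (Gist by @pashu123, created April 7, 2023.)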
import os
import sys

import torch
import torch.fx as fx
import torch_mlir
from diffusers import StableDiffusionPipeline
from shark.shark_importer import import_with_fx
model_input = {
    # Token ids for the CLIP text encoder (77 = CLIP's max sequence length).
    "clip": (torch.randint(1, 2, (1, 77)),),
    # Latents for the VAE decode path.
    "vae": (torch.randn(1, 4, 128, 128),),
    "unet": (
        torch.randn(2, 4, 96, 96),  # latents (batch of 2 for classifier-free guidance)
        torch.tensor([1]).float(),  # timestep
        torch.randn(2, 77, 1024),  # text embedding (1024-dim for SD 2.1)
    ),
}
def compile_via_shark(model, inputs):
    """Import `model` via FX into MLIR bytecode and compile it with SHARK for Vulkan."""
    # Alternative: lower to linalg-on-tensors bytecode directly via torch_mlir.
    # import io
    # bytecode_stream = io.BytesIO()
    # linalg_ir = torch_mlir.compile(model, inputs, output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)
    # linalg_ir.operation.write_bytecode(bytecode_stream)
    # bytecode = bytecode_stream.getvalue()
    is_f16 = True
    input_mask = [True, True, True]  # cast all three UNet inputs to float16
    bytecode = import_with_fx(model, inputs, is_f16=is_f16, f16_input_mask=input_mask)
    # Debugging helpers:
    # print(bytecode.graph)
    # print(bytecode(*inputs))
    # fx_g = fx.symbolic_trace(model)
    # print(fx_g.graph)
    # Dump the imported MLIR to disk for inspection:
    # with open(os.path.join("xyz.mlir"), "wb") as mlir_file:
    #     mlir_file.write(bytecode[0])
    from shark.shark_inference import SharkInference

    shark_module = SharkInference(
        mlir_module=bytecode[0], device="vulkan", mlir_dialect="tm_tensor",
    )
    # extra_args = ['--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))', '--iree-spirv-index-bits=64']
    shark_module.compile(extra_args=[])
    return shark_module
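

# The model_input dict above also carries sample inputs for "clip" and "vae",
# which are unused below. A hypothetical sketch of compiling those submodels
# along the same path (untested here -- thin wrappers analogous to UnetCustom
# may be needed for their forward signatures, and f16_input_mask is hardcoded
# to three entries above, so it would need to match each model's input count;
# the integer CLIP token ids in particular should not be cast to float16):
# shark_clip = compile_via_shark(pipe.text_encoder, model_input["clip"])
# shark_vae = compile_via_shark(pipe.vae, model_input["vae"])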
class UNetWrapper(torch.nn.Module):
    """Drop-in replacement for pipe.unet that forwards to the compiled SHARK module."""

    def __init__(self, shark_unet):
        super().__init__()
        self.wrapped_unet = shark_unet
        self.in_channels = None
        self.device = None
        self.config = None

    def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
        # Alternative path: hand the SHARK module numpy arrays instead of tensors.
        # sample_np = sample.detach().cpu().numpy()
        # timestep_np = timestep.half().detach().cpu().reshape(-1).numpy()
        # encoder_hidden_states_np = encoder_hidden_states.detach().cpu().numpy()
        # inputs = [sample_np, timestep_np, encoder_hidden_states_np]
        sample = self.wrapped_unet(sample, timestep, encoder_hidden_states)
        # The rest of the pipeline always runs in float16.
        return sample
class UnetCustom(torch.nn.Module):
    """Wraps the diffusers UNet so FX tracing sees a plain tensor-in/tensor-out forward."""

    def __init__(self, pipe_unet):
        super().__init__()
        self.unet = pipe_unet
        self.in_channels = None
        self.device = None
        self.config = None

    def forward(self, latent, timestep, text_embedding):
        unet_out = self.unet.forward(
            latent,
            timestep,
            text_embedding,
            return_dict=False,
        )[0]
        return unet_out
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
# pipe = pipe.to("cuda")
# pipe.enable_attention_slicing()
# Swap in the plain-tensor wrapper so the FX import sees a simple forward signature.
unet_graph = UnetCustom(pipe.unet)
unet_graph.in_channels = pipe.unet.in_channels
unet_graph.device = pipe.unet.device
unet_graph.config = pipe.unet.config
del pipe.unet
pipe.unet = unet_graph

shark_unet = compile_via_shark(pipe.unet, model_input["unet"])
# shark_unet = shark_unet.cuda()

# Patch the compiled SHARK module back into the pipeline in place of the eager UNet.
unet_graph = UNetWrapper(shark_unet)
unet_graph.in_channels = pipe.unet.in_channels
unet_graph.device = pipe.unet.device
unet_graph.config = pipe.unet.config
del pipe.unet
pipe.unet = unet_graph
# Uncomment to run the patched pipeline end to end:
# prompt = "a photo of an astronaut riding a horse on mars"
# image = pipe(prompt).images[0]
# image.save("astronaut_rides_horse.png")