Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created October 5, 2022 14:02
Show Gist options
  • Save pashu123/8637fabcbaf3b2d814c328509384200f to your computer and use it in GitHub Desktop.
Save pashu123/8637fabcbaf3b2d814c328509384200f to your computer and use it in GitHub Desktop.
from stable_diffusion_tf.stable_diffusion import get_models
from iree import runtime as ireert
from iree.compiler import tf as tfc
from iree.compiler import compile_str
import sys
from absl import app
import numpy as np
import os
import tempfile
import tensorflow as tf
import time
# Fixed batch dimension shared by every input spec below.
BATCH_SIZE = 1

# Build the stable_diffusion_tf models for a 512x512 image size and keep only
# the element at index 1, which this script treats as the UNet.
# NOTE(review): index 1 is assumed to be the diffusion UNet returned by
# get_models — confirm against the installed stable_diffusion_tf version.
unet_model = get_models(512, 512)[1]

# Static input signature used to trace UnetModule.predict: one float32
# TensorSpec per positional argument (x, y, z), each prefixed by BATCH_SIZE.
unet_input = [
    tf.TensorSpec(shape=[BATCH_SIZE, *trailing_dims], dtype=tf.float32)
    for trailing_dims in (
        (64, 64, 4),  # x: presumably the latent, NHWC — TODO confirm
        (320,),       # y: presumably the timestep embedding — TODO confirm
        (77, 768),    # z: presumably the text-encoder context — TODO confirm
    )
]
class UnetModule(tf.Module):
    """Wrap the global ``unet_model`` so it can be exported as ``predict``.

    ``predict`` is a ``tf.function`` traced with the fixed ``unet_input``
    signature, so the exported function has static input shapes.
    """

    def __init__(self):
        super().__init__()
        # Hold a reference to the shared model without modifying it. Assigning
        # ``self.m.predict`` here would shadow Keras's own ``Model.predict`` on
        # the global ``unet_model`` object for every other user of it.
        self.m = unet_model

    @tf.function(input_signature=unet_input)
    def predict(self, x, y, z):
        """Run one forward pass of the wrapped model.

        Args:
            x, y, z: Tensors matching the three ``unet_input`` specs, in order.

        Returns:
            Whatever the wrapped model's ``call`` returns for these inputs.
        """
        # Keras ``Model.call`` has the signature (inputs, training=None,
        # mask=None). Passing x, y and z positionally would bind y to
        # ``training`` and z to ``mask``, so all three inputs go in one list.
        # NOTE(review): assumes unet_model is a 3-input functional Keras model
        # taking [x, y, z] in this order — TODO confirm for the installed
        # stable_diffusion_tf version.
        return self.m.call([x, y, z])
# Hand the wrapped UNet to IREE's TensorFlow frontend, exporting only the
# traced ``predict`` function. NOTE(review): with import_only=True the result
# is expected to be the imported MLIR rather than a compiled VM module (per
# the IREE TF compiler docs) — confirm for the pinned IREE version. Nothing
# after this line in the gist uses the returned value.
compiler_module = tfc.compile_module(
    UnetModule(),
    exported_names=["predict"],
    import_only=True,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment