Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created October 5, 2022 16:42
Show Gist options
  • Save pashu123/a25e36eba87981ec327659762db91f93 to your computer and use it in GitHub Desktop.
Save pashu123/a25e36eba87981ec327659762db91f93 to your computer and use it in GitHub Desktop.
from iree import runtime as ireert
from iree.compiler import tf as tfc
from iree.compiler import compile_str
import sys
from absl import app
import numpy as np
import os
import tempfile
import tensorflow as tf
from tensorflow import keras
from keras_cv.models.generative.stable_diffusion.diffusion_model import DiffusionModel
import time
# Local path to the pretrained keras_cv diffusion-model (UNet) weights.
# keras.utils.get_file downloads the file (or reuses a cached copy) and checks
# it against file_hash before returning the path.
diffusion_model_weights_fpath = keras.utils.get_file(
    origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_diffusion_model.h5",
    file_hash="8799ff9763de13d7f30a683d653018e114ed24a6a819667da4f5ee10f9e805fe",
)

# Static batch dimension baked into the exported function signature.
BATCH_SIZE = 1

# Input signature for UnetModule.predict: one float32 TensorSpec per argument
# (x, y, z), each with a leading BATCH_SIZE dimension followed by the trailing
# dims listed below. NOTE(review): these trailing dims presumably have to agree
# with the DiffusionModel(512, 512, 77) built in UnetModule (64 == 512 // 8,
# 77 == text length) — confirm against keras_cv before changing either side.
unet_input = [
    tf.TensorSpec(shape=[BATCH_SIZE, *trailing_dims], dtype=tf.float32)
    for trailing_dims in ((64, 64, 4), (320,), (77, 768))
]
class UnetModule(tf.Module):
    """tf.Module wrapper exposing the keras_cv DiffusionModel for IREE import.

    The wrapped Keras model is constructed as ``DiffusionModel(512, 512, 77)``,
    loads its weights from ``diffusion_model_weights_fpath``, and is exported
    through a single ``predict`` entry point whose input signature is
    ``unet_input``.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): these constructor arguments must stay consistent with
        # the shapes in ``unet_input`` (64x64 spatial input presumably equals
        # 512 // 8; 77 is the text length) — confirm against keras_cv.
        self.m = DiffusionModel(512, 512, 77)
        self.m.load_weights(diffusion_model_weights_fpath)

    @tf.function(input_signature=unet_input)
    def predict(self, x, y, z):
        """Run one forward pass of the wrapped diffusion model.

        Args:
            x: float32 tensor of shape [BATCH_SIZE, 64, 64, 4]
                (presumably the latent image — confirm).
            y: float32 tensor of shape [BATCH_SIZE, 320]
                (presumably the timestep embedding — confirm).
            z: float32 tensor of shape [BATCH_SIZE, 77, 768]
                (presumably the text-encoder context — confirm).

        Returns:
            The output of calling the Keras model on ``[x, y, z]``.
        """
        # Call the Keras model directly rather than overwriting
        # ``self.m.predict`` with a lambda: that shadowed Keras' real
        # ``Model.predict`` on the instance and formed a
        # module -> model -> lambda -> module reference cycle.
        # ``training=False`` pins inference mode explicitly so the exported
        # graph does not depend on Keras' global learning phase.
        return self.m([x, y, z], training=False)
# Run IREE's TensorFlow frontend on a freshly built UnetModule, exporting only
# its ``predict`` function. ``import_only=True`` requests the import step only,
# not a full compile.
# NOTE(review): the result is presumably the imported MLIR rather than a
# compiled artifact, and it is not used or written anywhere in the visible
# code — confirm whether it should be saved to disk.
compiler_module = tfc.compile_module(
    UnetModule(),
    exported_names=["predict"],
    import_only=True,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment