{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [],
"machine_shape": "hm",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"gpuClass": "standard"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/josephrocca/67b310ffca154fe906e548e794a29a65/stable_diffusion_jax-to-onnx.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"!grep MemTotal /proc/meminfo # Total RAM"
],
"metadata": {
"id": "LlbW06dPyRrC",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "7dabf0ce-cc85-469d-bd22-4d03d9e60223"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"MemTotal: 26690640 kB\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!pip install transformers==4.23.1 huggingface_hub==0.10.0 ftfy==6.1.1 flax==0.6.1 git+https://github.com/huggingface/[email protected] git+https://github.com/onnx/[email protected]\n",
"!pip install --upgrade jax jaxlib"
],
"metadata": {
"id": "0YHLndloz1U_"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from huggingface_hub.hf_api import HfFolder\n",
"HfFolder.save_token('h'+'f'+'_'+'AUxlCqSud'+'NTSgaWmE'+'jrUgRytG'+'JiBTLoYSD') # Don't worry! This key can be safely made public. It's just a read-only key for an \"empty\"/dummy Hugging Face account (temp email) that was SPECIFICALLY created to make it easier to access the Stable Diffusion model in Colab (less copy-pasting my token during many runtime resets). The `+` concatenation is just so it doesn't trigger any Github API key detection alarms, or whatever."
],
"metadata": {
"id": "1PMLFHaMyuBR"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from diffusers import FlaxStableDiffusionPipeline\n",
"import tensorflow as tf\n",
"import jax\n",
"from jax.experimental import jax2tf\n",
"from jax import numpy as jnp\n",
"import numpy as np\n",
"import tf2onnx"
],
"metadata": {
"id": "UFMtdmPeyxpi"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\", revision=\"bf16\", dtype=jnp.bfloat16, safety_checker=None)"
],
"metadata": {
"id": "PPtraQX34Az7"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"text_encoder_params = params[\"text_encoder\"]"
],
"metadata": {
"id": "DAS1S2N82bEl"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def text_tokens_to_embeddings(text_encoder_params, prompt_ids):\n",
" text_embeddings = pipeline.text_encoder(prompt_ids, params=text_encoder_params)[0]\n",
" return text_embeddings"
],
"metadata": {
"id": "m0nGZcx_ZAnU"
},
"execution_count": 7,
"outputs": []
},
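{
"cell_type": "markdown",
"metadata": {},
"source": [
"Quick sanity check (an added sketch, not part of the original notebook): build `prompt_ids` with the pipeline's tokenization helper and run the pure-JAX function once before converting it. This assumes `pipeline.prepare_inputs` pads/truncates prompts to the 77-token context length used in the `TensorSpec`s below."
]
},
{
"cell_type": "code",
"source": [
"# Smoke test: tokenize one prompt and run the JAX encoder on it.\n",
"prompt_ids = pipeline.prepare_inputs(\"a photo of an astronaut riding a horse\")  # shape (1, 77)\n",
"text_embeddings = text_tokens_to_embeddings(text_encoder_params, prompt_ids)\n",
"print(prompt_ids.shape, text_embeddings.shape)  # expect (1, 77) and (1, 77, 768) for SD v1-4"
],
"metadata": {},
"execution_count": null,
"outputs": []
},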
{
"cell_type": "code",
"source": [
"text_encoder_params_vars = tf.nest.map_structure(tf.Variable, text_encoder_params)\n",
"\n",
"text_tokens_to_embeddings_tf = lambda prompt_ids: jax2tf.convert(text_tokens_to_embeddings, with_gradient=False)(text_encoder_params_vars, prompt_ids)\n",
"\n",
"my_model = tf.Module()\n",
"my_model._variables = tf.nest.flatten(text_encoder_params_vars) # <-- Tell the model saver what are the variables.\n",
"my_model.f = tf.function(text_tokens_to_embeddings_tf, autograph=False, jit_compile=True, input_signature=[\n",
" tf.TensorSpec([1, 77], tf.int64, name=\"prompt_ids\"),\n",
"])\n",
"\n",
"tf2onnx.convert.from_function(my_model.f, input_signature=[\n",
" tf.TensorSpec([1, 77], tf.int64, name=\"prompt_ids\"),\n",
"], opset=16, output_path=\"text_tokens_to_embeddings.onnx\")"
],
"metadata": {
"id": "7RXsV8AduEjM",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "fdc60923-b5e8-4e4c-c94a-7c55b03e1ecb"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tf2onnx/tf_loader.py:715: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use `tf.compat.v1.graph_util.extract_sub_graph`\n"
]
}
]
}
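,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional validation (an added sketch, not in the original notebook): load the exported file with `onnxruntime` (assumes a prior `!pip install onnxruntime`) and compare it against the JAX function on the same token IDs. Because the weights are bfloat16, expect only loose numerical agreement, and note that runtime support for bfloat16 ops can vary."
]
},
{
"cell_type": "code",
"source": [
"import onnxruntime as ort\n",
"\n",
"sess = ort.InferenceSession(\"text_tokens_to_embeddings.onnx\")\n",
"input_name = sess.get_inputs()[0].name  # query the name rather than hard-coding it; tf2onnx may rename inputs\n",
"\n",
"ids = np.asarray(pipeline.prepare_inputs(\"a photo of an astronaut riding a horse\"), dtype=np.int64)\n",
"onnx_out = np.asarray(sess.run(None, {input_name: ids})[0], dtype=np.float32)\n",
"jax_out = np.asarray(text_tokens_to_embeddings(text_encoder_params, ids).astype(jnp.float32))\n",
"\n",
"print(\"max abs diff:\", np.abs(onnx_out - jax_out).max())"
],
"metadata": {},
"execution_count": null,
"outputs": []
}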
]
}