Skip to content

Instantly share code, notes, and snippets.

@josephrocca
Last active October 17, 2022 17:52
Show Gist options
  • Save josephrocca/bbfff037f21a81661e92ff2b46494c04 to your computer and use it in GitHub Desktop.
Save josephrocca/bbfff037f21a81661e92ff2b46494c04 to your computer and use it in GitHub Desktop.
stable_diffusion_jax-to-tflite.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [],
"machine_shape": "hm",
"name": "stable_diffusion_jax-to-tflite.ipynb",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"gpuClass": "standard"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/josephrocca/bbfff037f21a81661e92ff2b46494c04/stable_diffusion_jax-to-tflite.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"!grep MemTotal /proc/meminfo # Total RAM"
],
"metadata": {
"id": "LlbW06dPyRrC",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "30d82d4b-501a-4f5a-fbab-491c603fd627"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"MemTotal: 26690632 kB\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!pip install transformers==4.23.1 huggingface_hub==0.10.0 ftfy==6.1.1 flax==0.6.1 git+https://github.com/huggingface/[email protected] git+https://github.com/onnx/[email protected]\n",
"!pip install --upgrade jax jaxlib"
],
"metadata": {
"id": "0YHLndloz1U_"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import os\n",
"\n",
"from huggingface_hub.hf_api import HfFolder\n",
"\n",
"# Don't worry! This fallback key can be safely made public. It's just a read-only key for an\n",
"# \"empty\"/dummy Hugging Face account (temp email) that was SPECIFICALLY created to make it\n",
"# easier to access the Stable Diffusion model in Colab (less copy-pasting a token during\n",
"# many runtime resets). The `+` concatenation is just so it doesn't trigger any GitHub\n",
"# API key detection alarms.\n",
"PUBLIC_DUMMY_TOKEN = 'h'+'f'+'_'+'AUxlCqSud'+'NTSgaWmE'+'jrUgRytG'+'JiBTLoYSD'\n",
"\n",
"# Prefer a token supplied through the HUGGING_FACE_HUB_TOKEN environment variable (useful if\n",
"# the shared dummy token stops working); otherwise fall back to the public dummy token so\n",
"# that Restart & Run All still needs no manual input. An empty variable counts as unset.\n",
"HfFolder.save_token(os.environ.get('HUGGING_FACE_HUB_TOKEN') or PUBLIC_DUMMY_TOKEN)"
],
"metadata": {
"id": "1PMLFHaMyuBR"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from diffusers import FlaxStableDiffusionPipeline\n",
"import tensorflow as tf\n",
"import jax\n",
"from jax.experimental import jax2tf\n",
"from jax import numpy as jnp\n",
"import numpy as np\n",
"import tf2onnx"
],
"metadata": {
"id": "UFMtdmPeyxpi"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Load the Flax Stable Diffusion pipeline object and its parameters from the `bf16` revision\n",
"# of the CompVis/stable-diffusion-v1-4 repo. No safety checker is passed (safety_checker=None).\n",
"pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(\n",
"  \"CompVis/stable-diffusion-v1-4\",\n",
"  revision=\"bf16\",\n",
"  dtype=jnp.bfloat16,\n",
"  safety_checker=None,\n",
")"
],
"metadata": {
"id": "PPtraQX34Az7"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Keep only the text encoder's entry of the loaded parameters; the cells below use just this part.\n",
"text_encoder_params = params['text_encoder']"
],
"metadata": {
"id": "DAS1S2N82bEl"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def text_tokens_to_embeddings(text_encoder_params, prompt_ids):\n",
"  \"\"\"Run the pipeline's text encoder on `prompt_ids` and return element 0 of its output.\n",
"\n",
"  Args:\n",
"    text_encoder_params: forwarded to the text encoder as its `params=` argument.\n",
"    prompt_ids: passed to the text encoder as its positional input (presumably the\n",
"      tokenized prompt; confirm against the tokenizer used upstream).\n",
"\n",
"  Returns:\n",
"    The first element of whatever `pipeline.text_encoder(...)` returns (the text\n",
"    embeddings, going by the function's name).\n",
"\n",
"  Note:\n",
"    Reads the notebook-level `pipeline` variable, so the cell that loads the pipeline\n",
"    must have been run first.\n",
"  \"\"\"\n",
"  encoder_outputs = pipeline.text_encoder(prompt_ids, params=text_encoder_params)\n",
"  return encoder_outputs[0]"
],
"metadata": {
"id": "m0nGZcx_ZAnU"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def prompt_ids_to_embeddings(prompt_ids):\n",
"  \"\"\"Serving function for conversion: the text encoder with its weights closed over.\n",
"\n",
"  Only `prompt_ids` is left as an argument, so it is the converted model's single input.\n",
"  \"\"\"\n",
"  return text_tokens_to_embeddings(text_encoder_params, prompt_ids)\n",
"\n",
"# Example input handed to the converter for the `prompt_ids` argument.\n",
"example_prompt_ids = np.ones([1,77], dtype='int64')\n",
"\n",
"# `experimental_from_jax` takes, per serving function, a list of (name, tensor) pairs.\n",
"# The encoder parameters are not a single tensor input (presumably a nested tree of arrays),\n",
"# so they are bound into the serving function instead of being listed as a converter input.\n",
"# This mirrors the `functools.partial(predict, params)` pattern in the TensorFlow Lite JAX guide.\n",
"converter = tf.lite.TFLiteConverter.experimental_from_jax(\n",
"  [prompt_ids_to_embeddings],\n",
"  [[('prompt_ids', example_prompt_ids)]],\n",
")\n",
"tflite_model = converter.convert()\n",
"\n",
"# Write the serialized model returned by convert() to disk.\n",
"with open('text_tokens_to_embeddings.tflite', 'wb') as f:\n",
"  f.write(tflite_model)"
],
"metadata": {
"id": "ovsxiY1fAM0m"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment