Created
August 3, 2022 19:15
-
-
Save josephrocca/7dff488a71b55266c8d603709ecffc98 to your computer and use it in GitHub Desktop.
Minimal dalle-mini inference and tflite conversion (produces zero-byte tflite model)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/josephrocca/7dff488a71b55266c8d603709ecffc98/minimal-dalle-mini-inference-and-tflite-conversion-produces-zero-byte-tflite-model.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# This notebook is based on work done by @kuprel: https://github.com/kuprel/min-dalle" | |
], | |
"metadata": { | |
"id": "snN3kT1PWRp3" | |
}, | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install jax[cpu]==0.3.14 # Pinned because, as of writing, Colab ships JAX v0.3.8, which has a bug in jax2tf" | |
], | |
"metadata": { | |
"id": "th9znbh3dBKN", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "5b7ab180-12a5-4df4-a7c5-2f89f68edd9b" | |
}, | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"Requirement already satisfied: jax[cpu]==0.3.14 in /usr/local/lib/python3.7/dist-packages (0.3.14)\n", | |
"Requirement already satisfied: scipy>=1.5 in /usr/local/lib/python3.7/dist-packages (from jax[cpu]==0.3.14) (1.7.3)\n", | |
"Requirement already satisfied: etils[epath] in /usr/local/lib/python3.7/dist-packages (from jax[cpu]==0.3.14) (0.6.0)\n", | |
"Requirement already satisfied: numpy>=1.19 in /usr/local/lib/python3.7/dist-packages (from jax[cpu]==0.3.14) (1.21.6)\n", | |
"Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax[cpu]==0.3.14) (3.3.0)\n", | |
"Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax[cpu]==0.3.14) (1.2.0)\n", | |
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from jax[cpu]==0.3.14) (4.1.1)\n", | |
"Requirement already satisfied: jaxlib==0.3.14 in /usr/local/lib/python3.7/dist-packages (from jax[cpu]==0.3.14) (0.3.14+cuda11.cudnn805)\n", | |
"Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib==0.3.14->jax[cpu]==0.3.14) (2.0)\n", | |
"Requirement already satisfied: importlib_resources in /usr/local/lib/python3.7/dist-packages (from etils[epath]->jax[cpu]==0.3.14) (5.9.0)\n", | |
"Requirement already satisfied: zipp in /usr/local/lib/python3.7/dist-packages (from etils[epath]->jax[cpu]==0.3.14) (3.8.1)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"id": "ix_xt4X1_6F4", | |
"cellView": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "6fa298d0-baa2-4cb3-d382-107664982832" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Cloning into 'min-dalle'...\n", | |
"remote: Enumerating objects: 28, done.\u001b[K\n", | |
"remote: Counting objects: 100% (28/28), done.\u001b[K\n", | |
"remote: Compressing objects: 100% (25/25), done.\u001b[K\n", | |
"remote: Total 28 (delta 2), reused 13 (delta 1), pack-reused 0\u001b[K\n", | |
"Unpacking objects: 100% (28/28), done.\n", | |
"Note: checking out '1e18ba0ffa0788a987db6a439471e27e6f8e91ac'.\n", | |
"\n", | |
"You are in 'detached HEAD' state. You can look around, make experimental\n", | |
"changes and commit them, and you can discard any commits you make in this\n", | |
"state without impacting any branches by performing another checkout.\n", | |
"\n", | |
"If you want to create a new branch to retain commits you create, you may\n", | |
"do so (now or later) by using -b with the checkout command again. Example:\n", | |
"\n", | |
" git checkout -b <new-branch-name>\n", | |
"\n", | |
" % Total % Received % Xferd Average Speed Time Time Time Current\n", | |
" Dload Upload Total Spent Left Speed\n", | |
"100 234 100 234 0 0 609 0 --:--:-- --:--:-- --:--:-- 609\n", | |
"100 290M 100 290M 0 0 10.8M 0 0:00:26 0:00:26 --:--:-- 11.5M\n", | |
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (1.12.0+cu113)\n", | |
"Collecting flax==0.4.2\n", | |
" Downloading flax-0.4.2-py3-none-any.whl (186 kB)\n", | |
"\u001b[K |████████████████████████████████| 186 kB 13.0 MB/s \n", | |
"\u001b[?25hCollecting wandb\n", | |
" Downloading wandb-0.13.0-py2.py3-none-any.whl (1.8 MB)\n", | |
"\u001b[K |████████████████████████████████| 1.8 MB 69.5 MB/s \n", | |
"\u001b[?25hRequirement already satisfied: typing-extensions>=4.1.1 in /usr/local/lib/python3.7/dist-packages (from flax==0.4.2) (4.1.1)\n", | |
"Requirement already satisfied: numpy>=1.12 in /usr/local/lib/python3.7/dist-packages (from flax==0.4.2) (1.21.6)\n", | |
"Collecting optax\n", | |
" Downloading optax-0.1.3-py3-none-any.whl (145 kB)\n", | |
"\u001b[K |████████████████████████████████| 145 kB 71.9 MB/s \n", | |
"\u001b[?25hRequirement already satisfied: msgpack in /usr/local/lib/python3.7/dist-packages (from flax==0.4.2) (1.0.4)\n", | |
"Requirement already satisfied: jax>=0.3 in /usr/local/lib/python3.7/dist-packages (from flax==0.4.2) (0.3.14)\n", | |
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from flax==0.4.2) (3.2.2)\n", | |
"Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax>=0.3->flax==0.4.2) (1.2.0)\n", | |
"Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.3->flax==0.4.2) (3.3.0)\n", | |
"Requirement already satisfied: etils[epath] in /usr/local/lib/python3.7/dist-packages (from jax>=0.3->flax==0.4.2) (0.6.0)\n", | |
"Requirement already satisfied: scipy>=1.5 in /usr/local/lib/python3.7/dist-packages (from jax>=0.3->flax==0.4.2) (1.7.3)\n", | |
"Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from wandb) (57.4.0)\n", | |
"Requirement already satisfied: protobuf<4.0dev,>=3.12.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (3.17.3)\n", | |
"Requirement already satisfied: PyYAML in /usr/local/lib/python3.7/dist-packages (from wandb) (3.13)\n", | |
"Collecting shortuuid>=0.5.0\n", | |
" Downloading shortuuid-1.0.9-py3-none-any.whl (9.4 kB)\n", | |
"Collecting GitPython>=1.0.0\n", | |
" Downloading GitPython-3.1.27-py3-none-any.whl (181 kB)\n", | |
"\u001b[K |████████████████████████████████| 181 kB 68.9 MB/s \n", | |
"\u001b[?25hRequirement already satisfied: promise<3,>=2.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (2.3)\n", | |
"Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (5.4.8)\n", | |
"Collecting docker-pycreds>=0.4.0\n", | |
" Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n", | |
"Requirement already satisfied: requests<3,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (2.23.0)\n", | |
"Collecting pathtools\n", | |
" Downloading pathtools-0.1.2.tar.gz (11 kB)\n", | |
"Collecting sentry-sdk>=1.0.0\n", | |
" Downloading sentry_sdk-1.9.0-py2.py3-none-any.whl (156 kB)\n", | |
"\u001b[K |████████████████████████████████| 156 kB 68.2 MB/s \n", | |
"\u001b[?25hRequirement already satisfied: six>=1.13.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (1.15.0)\n", | |
"Requirement already satisfied: Click!=8.0.0,>=7.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (7.1.2)\n", | |
"Collecting setproctitle\n", | |
" Downloading setproctitle-1.3.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)\n", | |
"Collecting gitdb<5,>=4.0.1\n", | |
" Downloading gitdb-4.0.9-py3-none-any.whl (63 kB)\n", | |
"\u001b[K |████████████████████████████████| 63 kB 1.7 MB/s \n", | |
"\u001b[?25hCollecting smmap<6,>=3.0.1\n", | |
" Downloading smmap-5.0.0-py3-none-any.whl (24 kB)\n", | |
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (2.10)\n", | |
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (1.24.3)\n", | |
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (3.0.4)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (2022.6.15)\n", | |
"Requirement already satisfied: importlib_resources in /usr/local/lib/python3.7/dist-packages (from etils[epath]->jax>=0.3->flax==0.4.2) (5.9.0)\n", | |
"Requirement already satisfied: zipp in /usr/local/lib/python3.7/dist-packages (from etils[epath]->jax>=0.3->flax==0.4.2) (3.8.1)\n", | |
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.4.2) (1.4.4)\n", | |
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.4.2) (3.0.9)\n", | |
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.4.2) (0.11.0)\n", | |
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.4.2) (2.8.2)\n", | |
"Collecting chex>=0.0.4\n", | |
" Downloading chex-0.1.3-py3-none-any.whl (72 kB)\n", | |
"\u001b[K |████████████████████████████████| 72 kB 629 kB/s \n", | |
"\u001b[?25hRequirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.7/dist-packages (from optax->flax==0.4.2) (0.3.14+cuda11.cudnn805)\n", | |
"Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->flax==0.4.2) (0.1.7)\n", | |
"Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->flax==0.4.2) (0.12.0)\n", | |
"Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax->flax==0.4.2) (2.0)\n", | |
"Building wheels for collected packages: pathtools\n", | |
" Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for pathtools: filename=pathtools-0.1.2-py3-none-any.whl size=8806 sha256=7e5a9310543b75c6ad2a08c0fc29859dfe120c58637cb374021f2eeaf85ffb6c\n", | |
" Stored in directory: /root/.cache/pip/wheels/3e/31/09/fa59cef12cdcfecc627b3d24273699f390e71828921b2cbba2\n", | |
"Successfully built pathtools\n", | |
"Installing collected packages: smmap, gitdb, chex, shortuuid, setproctitle, sentry-sdk, pathtools, optax, GitPython, docker-pycreds, wandb, flax\n", | |
"Successfully installed GitPython-3.1.27 chex-0.1.3 docker-pycreds-0.4.0 flax-0.4.2 gitdb-4.0.9 optax-0.1.3 pathtools-0.1.2 sentry-sdk-1.9.0 setproctitle-1.3.0 shortuuid-1.0.9 smmap-5.0.0 wandb-0.13.0\n", | |
"\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n", | |
"\u001b[34m\u001b[1mwandb\u001b[0m: Downloading dataset artifact dalle-mini/dalle-mini/mini-1:v0\n", | |
"\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact mini-1:v0, 1673.43MB. 7 files... Done. 0:0:22.7\n", | |
"\u001b[34m\u001b[1mwandb\u001b[0m: Artifact downloaded to /content/min-dalle/pretrained/dalle_bart_mini\n" | |
] | |
} | |
], | |
"source": [ | |
"!git clone --depth 1 --branch 0.1.1 https://github.com/kuprel/min-dalle\n", | |
"!mkdir -p /content/min-dalle/pretrained/vqgan/\n", | |
"!curl https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/resolve/main/flax_model.msgpack -L --output /content/min-dalle/pretrained/vqgan/flax_model.msgpack\n", | |
"!pip install torch flax==0.4.2 wandb\n", | |
"!wandb login --anonymously\n", | |
"!wandb artifact get --root=/content/min-dalle/pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"%cd /content/min-dalle/min_dalle" | |
], | |
"metadata": { | |
"id": "EuEPj3zBIkkm", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "82b2ccc6-c466-4c9b-8509-92e3128549a3" | |
}, | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"/content/min-dalle/min_dalle\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from flax import traverse_util, serialization\n", | |
"from typing import Dict, Tuple, List\n", | |
"import jax\n", | |
"from jax import numpy as jnp\n", | |
"import numpy\n", | |
"import os\n", | |
"import json\n", | |
"from PIL import Image\n", | |
"import torch" | |
], | |
"metadata": { | |
"id": "ry9H5pKSQUy9" | |
}, | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def load_dalle_bart_flax_params(path: str) -> Dict[str, numpy.ndarray]:\n", | |
" with open(os.path.join(path, \"flax_model.msgpack\"), \"rb\") as f:\n", | |
" params = serialization.msgpack_restore(f.read())\n", | |
"\n", | |
" for codec in ['encoder', 'decoder']:\n", | |
" k = 'FlaxBart{}Layers'.format(codec.title())\n", | |
" P: dict = params['model'][codec]['layers'][k]\n", | |
" P['pre_self_attn_layer_norm'] = P.pop('LayerNorm_0')\n", | |
" P['self_attn_layer_norm'] = P.pop('LayerNorm_1')\n", | |
" P['self_attn'] = P.pop('FlaxBartAttention_0')\n", | |
" if codec == 'decoder':\n", | |
" P['pre_encoder_attn_layer_norm'] = P.pop('LayerNorm_2')\n", | |
" P['encoder_attn_layer_norm'] = P.pop('LayerNorm_3')\n", | |
" P['encoder_attn'] = P.pop('FlaxBartAttention_1')\n", | |
" P['glu']: dict = P.pop('GLU_0')\n", | |
" P['glu']['ln0'] = P['glu'].pop('LayerNorm_0')\n", | |
" P['glu']['ln1'] = P['glu'].pop('LayerNorm_1')\n", | |
" P['glu']['fc0'] = P['glu'].pop('Dense_0')\n", | |
" P['glu']['fc1'] = P['glu'].pop('Dense_1')\n", | |
" P['glu']['fc2'] = P['glu'].pop('Dense_2')\n", | |
"\n", | |
" for codec in ['encoder', 'decoder']:\n", | |
" layers_params = params['model'][codec].pop('layers')\n", | |
" params['model'][codec] = {\n", | |
" **params['model'][codec], \n", | |
" **layers_params\n", | |
" }\n", | |
" \n", | |
" model_params = params.pop('model')\n", | |
" params = {**params, **model_params}\n", | |
"\n", | |
" params['decoder']['lm_head'] = params.pop('lm_head')\n", | |
"\n", | |
" return params" | |
], | |
"metadata": { | |
"id": "0rFl-rxiNpvn" | |
}, | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from typing import Tuple, List\n", | |
"def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]:\n", | |
" print(\"parsing metadata from {}\".format(path))\n", | |
" for f in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']:\n", | |
" assert(os.path.exists(os.path.join(path, f)))\n", | |
" with open(path + '/config.json', 'r') as f: \n", | |
" config = json.load(f)\n", | |
" with open(path + '/vocab.json') as f:\n", | |
" vocab = json.load(f)\n", | |
" with open(path + '/merges.txt') as f:\n", | |
" merges = f.read().split(\"\\n\")[1:-1]\n", | |
" return config, vocab, merges\n", | |
"\n", | |
"model_name = 'mini' # or 'mega'\n", | |
"model_path = '../pretrained/dalle_bart_{}'.format(model_name)\n", | |
"config, vocab, merges = load_dalle_bart_metadata(model_path)\n", | |
"params_dalle_bart = load_dalle_bart_flax_params(model_path)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "qSR7lFawNk-o", | |
"outputId": "577fcb48-2c72-44ac-cbc7-d4945e984446" | |
}, | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"parsing metadata from ../pretrained/dalle_bart_mini\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install tokenizers" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Asn9rrNtg8IG", | |
"outputId": "1d79a959-5df5-41e9-c6fb-49005a2ed61e" | |
}, | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"Collecting tokenizers\n", | |
" Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)\n", | |
"\u001b[K |████████████████████████████████| 6.6 MB 12.9 MB/s \n", | |
"\u001b[?25hInstalling collected packages: tokenizers\n", | |
"Successfully installed tokenizers-0.12.1\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from models.dalle_bart_encoder_flax import DalleBartEncoderFlax\n", | |
"from models.dalle_bart_decoder_flax import DalleBartDecoderFlax\n", | |
"\n", | |
"dalle_bart_encoder_flax_model = DalleBartEncoderFlax(\n", | |
" attention_head_count = config['encoder_attention_heads'],\n", | |
" embed_count = config['d_model'],\n", | |
" glu_embed_count = config['encoder_ffn_dim'],\n", | |
" text_token_count = config['max_text_length'],\n", | |
" text_vocab_count = config['encoder_vocab_size'],\n", | |
" layer_count = config['encoder_layers']\n", | |
").bind({'params': params_dalle_bart.pop('encoder')})\n", | |
"\n", | |
"dalle_bart_decoder_flax_model = DalleBartDecoderFlax(\n", | |
" image_token_count = config['image_length'],\n", | |
" text_token_count = config['max_text_length'],\n", | |
" image_vocab_count = config['image_vocab_size'],\n", | |
" attention_head_count = config['decoder_attention_heads'],\n", | |
" embed_count = config['d_model'],\n", | |
" glu_embed_count = config['decoder_ffn_dim'],\n", | |
" layer_count = config['decoder_layers'],\n", | |
" start_token = config['decoder_start_token_id']\n", | |
")" | |
], | |
"metadata": { | |
"id": "V0LCW-ViDiSg" | |
}, | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:\n", | |
" with open(os.path.join(path, 'flax_model.msgpack'), \"rb\") as f:\n", | |
" params: Dict[str, numpy.ndarray] = serialization.msgpack_restore(f.read())\n", | |
"\n", | |
" P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(params, sep='.')\n", | |
"\n", | |
" for i in list(P.keys()):\n", | |
" j = i\n", | |
" if 'up' in i or 'down' in i:\n", | |
" j = i.replace('_', '.')\n", | |
" j = j.replace('proj.out', 'proj_out')\n", | |
" j = j.replace('nin.short', 'nin_short')\n", | |
" if 'bias' in i:\n", | |
" P[j] = P.pop(i)\n", | |
" elif 'scale' in i:\n", | |
" j = j.replace('scale', 'weight')\n", | |
" P[j] = P.pop(i)\n", | |
" elif 'kernel' in i:\n", | |
" j = j.replace('kernel', 'weight')\n", | |
" P[j] = P.pop(i).transpose(3, 2, 0, 1)\n", | |
"\n", | |
" for i in P:\n", | |
" P[i] = torch.tensor(P[i])\n", | |
"\n", | |
" P['embedding.weight'] = P.pop('quantize.embedding.embedding')\n", | |
"\n", | |
" for i in list(P):\n", | |
" if i.split('.')[0] in ['encoder', 'quant_conv']:\n", | |
" P.pop(i)\n", | |
" \n", | |
" return P" | |
], | |
"metadata": { | |
"id": "vMpgNCGOK-N4" | |
}, | |
"execution_count": 10, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from torch import LongTensor, FloatTensor\n", | |
"from models.vqgan_detokenizer import VQGanDetokenizer\n", | |
"\n", | |
"def detokenize_torch(image_tokens: LongTensor, is_torch: bool) -> numpy.ndarray:\n", | |
" print(\"detokenizing image\")\n", | |
" model_path = '../pretrained/vqgan'\n", | |
" params = load_vqgan_torch_params(model_path)\n", | |
" detokenizer = VQGanDetokenizer()\n", | |
" detokenizer.load_state_dict(params)\n", | |
" if torch.cuda.is_available() and is_torch: detokenizer = detokenizer.cuda()\n", | |
" image = detokenizer.forward(image_tokens).to(torch.uint8)\n", | |
" del detokenizer, params\n", | |
" return image.to('cpu').detach().numpy()" | |
], | |
"metadata": { | |
"id": "pHrSOkC8KqwU" | |
}, | |
"execution_count": 11, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from math import inf\n", | |
"from typing import List, Tuple\n", | |
"\n", | |
"\n", | |
"class TextTokenizer:\n", | |
" def __init__(self, vocab: dict, merges: List[str]):\n", | |
" self.token_from_subword = vocab\n", | |
" pairs = [tuple(pair.split()) for pair in merges]\n", | |
" self.rank_from_pair = dict(zip(pairs, range(len(pairs))))\n", | |
"\n", | |
" def __call__(self, text: str) -> List[int]:\n", | |
" sep_token = self.token_from_subword['</s>']\n", | |
" cls_token = self.token_from_subword['<s>']\n", | |
" unk_token = self.token_from_subword['<unk>']\n", | |
" text = text.lower().encode(\"ascii\", errors=\"ignore\").decode()\n", | |
" tokens = [\n", | |
" self.token_from_subword.get(subword, unk_token)\n", | |
" for word in text.split(\" \") if len(word) > 0\n", | |
" for subword in self.get_byte_pair_encoding(word)\n", | |
" ]\n", | |
" return [cls_token] + tokens + [sep_token]\n", | |
"\n", | |
" def get_byte_pair_encoding(self, word: str) -> List[str]:\n", | |
" def get_pair_rank(pair: Tuple[str, str]) -> int:\n", | |
" return self.rank_from_pair.get(pair, inf)\n", | |
"\n", | |
" subwords = [chr(ord(\" \") + 256)] + list(word)\n", | |
" while len(subwords) > 1:\n", | |
" pairs = list(zip(subwords[:-1], subwords[1:]))\n", | |
" pair_to_merge = min(pairs, key=get_pair_rank)\n", | |
" if pair_to_merge not in self.rank_from_pair: break\n", | |
" i = pairs.index(pair_to_merge)\n", | |
" subwords = (\n", | |
" (subwords[:i] if i > 0 else []) + \n", | |
" [subwords[i] + subwords[i + 1]] + \n", | |
" (subwords[i + 2:] if i + 2 < len(subwords) else [])\n", | |
" )\n", | |
"\n", | |
" # print(subwords)\n", | |
" return subwords" | |
], | |
"metadata": { | |
"id": "KAUjwzWYuKru" | |
}, | |
"execution_count": 12, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def tokenize_text(\n", | |
" text: str, \n", | |
" config: dict,\n", | |
" vocab: dict,\n", | |
" merges: List[str]\n", | |
") -> numpy.ndarray:\n", | |
" tokens = TextTokenizer(vocab, merges)(text)\n", | |
" text_tokens = numpy.ones((2, config['max_text_length']), dtype=numpy.int32)\n", | |
" text_tokens[0, :len(tokens)] = tokens\n", | |
" text_tokens[1, :2] = [tokens[0], tokens[-1]]\n", | |
" return text_tokens" | |
], | |
"metadata": { | |
"id": "28vowjQ1B0aw" | |
}, | |
"execution_count": 13, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def generate_image(text_tokens, seed):\n", | |
" encoder_state = dalle_bart_encoder_flax_model(text_tokens)\n", | |
" image_tokens = dalle_bart_decoder_flax_model.sample_image_tokens(\n", | |
" text_tokens,\n", | |
" encoder_state,\n", | |
" jax.random.PRNGKey(seed[0]),\n", | |
" params_dalle_bart['decoder'],\n", | |
" )\n", | |
" return image_tokens" | |
], | |
"metadata": { | |
"id": "xFGb3Gs1Th1N" | |
}, | |
"execution_count": 14, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"generate_image_jitted = jax.jit(generate_image)" | |
], | |
"metadata": { | |
"id": "hxLZepOUT-qi" | |
}, | |
"execution_count": 15, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# text: \"game concept art with a lush green village surrounded by a dry orange canyon with a hill in the background and a blue sky, digital art\"\n", | |
"# text tokens: [[0,880,3319,241,208,58,21843,899,2595,30419,185,58,3441,2566,7308,208,58,2349,91,99,1396,128,58,789,1955,11,1189,241,2,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],[0,2,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]]\n", | |
"# Image output for seed=2: https://i.imgur.com/9fBoJg8.png\n", | |
"\n", | |
"text = \"game concept art with a lush green village surrounded by a dry orange canyon with a hill in the background and a blue sky, digital art\"\n", | |
"seed = jnp.array([2])\n", | |
"\n", | |
"text_tokens = tokenize_text(text, config, vocab, merges)\n", | |
"print(\"text tokens shape:\", text_tokens.shape)\n", | |
"\n", | |
"image_tokens = generate_image_jitted(text_tokens, seed)\n", | |
"\n", | |
"image_tokens_numpy = numpy.array(image_tokens)\n", | |
"# print(\"image tokens\", list(image_tokens))\n", | |
"\n", | |
"image = detokenize_torch(torch.tensor(image_tokens_numpy), is_torch=False)\n", | |
"display(Image.fromarray(image))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 327 | |
}, | |
"id": "CEUUtqNNF7L5", | |
"outputId": "df6093bd-908b-46a5-d042-dbc09f178af6" | |
}, | |
"execution_count": 16, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"text tokens shape: (2, 64)\n", | |
"detokenizing image\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<PIL.Image.Image image mode=RGB size=256x256 at 0x7F67569E2510>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# The cells below perform the JAX-to-TFLite conversion" | |
], | |
"metadata": { | |
"id": "h5GG1Qw1SqDT" | |
}, | |
"execution_count": 17, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from jax.experimental import jax2tf\n", | |
"import tensorflow as tf" | |
], | |
"metadata": { | |
"id": "kVF1_1Gqnwwi" | |
}, | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"my_model = tf.Module()\n", | |
"generate_image_jitted_tf = jax2tf.convert(generate_image_jitted, enable_xla=False)\n", | |
"my_model.f = tf.function(generate_image_jitted_tf, autograph=False, input_signature=[\n", | |
" tf.TensorSpec(shape=[2, 64], dtype=tf.int32, name=\"text_tokens\"),\n", | |
" tf.TensorSpec(shape=[1], dtype=tf.int32, name=\"seed\"),\n", | |
"])\n", | |
"tf.saved_model.save(my_model, '/content/dalle-mini-tfsavedmodel', options=tf.saved_model.SaveOptions(experimental_custom_gradients=False))\n", | |
"# restored_model = tf.saved_model.load('/content/dalle-mini-tfsavedmodel')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "-74vChXbStPY", | |
"outputId": "5ad31b85-7c4d-4bde-b1fd-17e90f4941b4" | |
}, | |
"execution_count": 19, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"WARNING:tensorflow:@custom_gradient grad_fn has 'variables' in signature, but no ResourceVariables were used on the forward pass.\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# The notebook may crash when running the next cell due to memory usage. If so, just run the cell again after the runtime reloads.\n", | |
"# The SavedModel generated above is written to disk and persists between crashes, so there is no need to re-run any of the cells above." | |
], | |
"metadata": { | |
"id": "zccB03McIPH5" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# NOTE: This conversion does not currently work due to a bug - the resulting .tflite model is zero bytes. See https://github.com/tensorflow/tensorflow/issues/56629\n", | |
"import tensorflow as tf\n", | |
"\n", | |
"converter = tf.lite.TFLiteConverter.from_saved_model(\"/content/dalle-mini-tfsavedmodel\")\n", | |
"converter.target_spec.supported_ops = [\n", | |
" tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.\n", | |
" tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.\n", | |
"]\n", | |
"tflite_model = converter.convert()\n", | |
"with open('/content/dalle-mini.tflite', 'wb') as f:\n", | |
" f.write(tflite_model)" | |
], | |
"metadata": { | |
"id": "cD4UtFEXe_4h", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "7f4b34e8-5068-4a6c-c089-4133407278cc" | |
}, | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"WARNING:absl:Importing a function (__inference_converted_fun_6851) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.\n", | |
"WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# This alternative approach doesn't work either, due to a bug (\"Jax transforms and Flax models cannot be mixed\"); see https://github.com/tensorflow/tensorflow/issues/56660\n", | |
"\n", | |
"# converter = tf.lite.TFLiteConverter.experimental_from_jax([generate_image_jitted], [[\n", | |
"# ('text_tokens', jnp.zeros((2, 64))),\n", | |
"# ('seed', jnp.zeros((1,))),\n", | |
"# ]])\n", | |
"# tflite_model = converter.convert()\n", | |
"# with open('/content/dalle-mini.tflite', 'wb') as f:\n", | |
"# f.write(tflite_model)" | |
], | |
"metadata": { | |
"id": "yfH6UYUTtcQh" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
], | |
"metadata": { | |
"colab": { | |
"collapsed_sections": [], | |
"name": "Minimal dalle-mini inference and tflite conversion (produces zero-byte tflite model)", | |
"provenance": [], | |
"machine_shape": "hm", | |
"include_colab_link": true | |
}, | |
"gpuClass": "standard", | |
"kernelspec": { | |
"display_name": "Python 3", | |
"name": "python3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment