Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save josephrocca/b1d4170d37202268f0e7316a199339eb to your computer and use it in GitHub Desktop.
Save josephrocca/b1d4170d37202268f0e7316a199339eb to your computer and use it in GitHub Desktop.
Generate Saved Model for DALL-E Mini BART Model
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/josephrocca/b1d4170d37202268f0e7316a199339eb/generate-saved-model-for-dall-e-mini-bart-model.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"# This notebook is based on work done by @kuprel: https://github.com/kuprel/min-dalle"
],
"metadata": {
"id": "snN3kT1PWRp3"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!pip install \"jax[cpu]==0.3.14\"  # quoted so the shell never treats [cpu] as a glob; pinned because, as of writing, Colab ships JAX v0.3.8, whose jax2tf has a bug"
],
"metadata": {
"id": "th9znbh3dBKN"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ix_xt4X1_6F4",
"cellView": "code"
},
"outputs": [],
"source": [
"# Shallow-clone min-dalle at tag 0.1.1; its min_dalle/models package is imported later in this notebook.\n",
"!git clone --depth 1 --branch 0.1.1 https://github.com/kuprel/min-dalle\n",
"!pip install torch flax==0.4.2 wandb\n",
"!wandb login --anonymously\n",
"# Download the 'mini-1:v0' DALL-E Mini artifact into pretrained/dalle_bart_mini.\n",
"# NOTE(review): only the 'mini' weights are fetched here; setting model_name = 'mega' later needs the matching artifact too.\n",
"!wandb artifact get --root=/content/min-dalle/pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0"
]
},
{
"cell_type": "code",
"source": [
"# Work from inside the package directory so `from models...` imports and the\n",
"# relative '../pretrained/...' model path used later both resolve.\n",
"%cd /content/min-dalle/min_dalle"
],
"metadata": {
"id": "EuEPj3zBIkkm",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "27aceb07-4401-4c93-f363-9049f15aaacb"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content/min-dalle/min_dalle\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from flax import traverse_util, serialization\n",
"from typing import Dict, Tuple, List\n",
"import jax\n",
"from jax import numpy as jnp\n",
"import numpy\n",
"import os\n",
"import json\n",
"from PIL import Image\n",
"import torch"
],
"metadata": {
"id": "ry9H5pKSQUy9"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def load_dalle_bart_flax_params(path: str) -> dict:\n",
"    \"\"\"Load DALL-E BART Flax weights and rename them to match min-dalle's modules.\n",
"\n",
"    Reads `flax_model.msgpack` from `path` and restructures the parameter tree:\n",
"      * renames the positional sub-module names (LayerNorm_0, Dense_1, ...) in\n",
"        the encoder/decoder layer stacks and their GLU blocks to descriptive names,\n",
"      * merges each codec's `layers` sub-tree into the codec itself,\n",
"      * hoists the `model` sub-tree to the top level,\n",
"      * moves `lm_head` under `decoder`.\n",
"\n",
"    Args:\n",
"        path: checkpoint directory containing `flax_model.msgpack`.\n",
"\n",
"    Returns:\n",
"        A nested dict of parameters (not a flat name -> array mapping) with\n",
"        `encoder` and `decoder` at the top level.\n",
"    \"\"\"\n",
"    with open(os.path.join(path, \"flax_model.msgpack\"), \"rb\") as f:\n",
"        params = serialization.msgpack_restore(f.read())\n",
"\n",
"    # Checkpoint sub-module name -> name expected by the min-dalle Flax modules.\n",
"    shared_layer_renames = {\n",
"        'LayerNorm_0': 'pre_self_attn_layer_norm',\n",
"        'LayerNorm_1': 'self_attn_layer_norm',\n",
"        'FlaxBartAttention_0': 'self_attn',\n",
"    }\n",
"    # Only the decoder layers carry LayerNorm_2/3 and FlaxBartAttention_1 (encoder attention).\n",
"    decoder_layer_renames = {\n",
"        'LayerNorm_2': 'pre_encoder_attn_layer_norm',\n",
"        'LayerNorm_3': 'encoder_attn_layer_norm',\n",
"        'FlaxBartAttention_1': 'encoder_attn',\n",
"    }\n",
"    glu_renames = {\n",
"        'LayerNorm_0': 'ln0',\n",
"        'LayerNorm_1': 'ln1',\n",
"        'Dense_0': 'fc0',\n",
"        'Dense_1': 'fc1',\n",
"        'Dense_2': 'fc2',\n",
"    }\n",
"\n",
"    for codec in ['encoder', 'decoder']:\n",
"        k = 'FlaxBart{}Layers'.format(codec.title())\n",
"        P = params['model'][codec]['layers'][k]\n",
"        renames = dict(shared_layer_renames)\n",
"        if codec == 'decoder':\n",
"            renames.update(decoder_layer_renames)\n",
"        for old_name, new_name in renames.items():\n",
"            P[new_name] = P.pop(old_name)\n",
"        glu = P.pop('GLU_0')\n",
"        for old_name, new_name in glu_renames.items():\n",
"            glu[new_name] = glu.pop(old_name)\n",
"        P['glu'] = glu\n",
"\n",
"    # Flatten: params['model'][codec]['layers'][...] -> params['model'][codec][...]\n",
"    for codec in ['encoder', 'decoder']:\n",
"        layers_params = params['model'][codec].pop('layers')\n",
"        params['model'][codec] = {**params['model'][codec], **layers_params}\n",
"\n",
"    # Hoist params['model'][...] to params[...].\n",
"    model_params = params.pop('model')\n",
"    params = {**params, **model_params}\n",
"\n",
"    params['decoder']['lm_head'] = params.pop('lm_head')\n",
"\n",
"    return params"
],
"metadata": {
"id": "0rFl-rxiNpvn"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from typing import Tuple, List\n",
"def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]:\n",
"    \"\"\"Load the config, BPE vocab and BPE merges of a DALL-E BART checkpoint.\n",
"\n",
"    Args:\n",
"        path: checkpoint directory; must contain config.json, flax_model.msgpack,\n",
"            vocab.json and merges.txt.\n",
"\n",
"    Returns:\n",
"        (config, vocab, merges): the parsed config.json, the parsed vocab.json,\n",
"        and the lines of merges.txt minus its first line (presumably a\n",
"        '#version' header -- confirm).\n",
"\n",
"    Raises:\n",
"        FileNotFoundError: if any of the expected files is missing.\n",
"    \"\"\"\n",
"    print(\"parsing metadata from {}\".format(path))\n",
"    # Explicit check instead of `assert`, which `python -O` strips out.\n",
"    for filename in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']:\n",
"        file_path = os.path.join(path, filename)\n",
"        if not os.path.exists(file_path):\n",
"            raise FileNotFoundError(file_path)\n",
"    # Explicit UTF-8 so the BPE vocab/merges decode identically regardless of locale.\n",
"    with open(os.path.join(path, 'config.json'), 'r', encoding='utf-8') as f:\n",
"        config = json.load(f)\n",
"    with open(os.path.join(path, 'vocab.json'), encoding='utf-8') as f:\n",
"        vocab = json.load(f)\n",
"    with open(os.path.join(path, 'merges.txt'), encoding='utf-8') as f:\n",
"        # splitlines() rather than splitting on newlines and slicing [1:-1]:\n",
"        # keeps the final merge even when the file has no trailing newline.\n",
"        merges = f.read().splitlines()[1:]\n",
"    return config, vocab, merges\n",
"\n",
"model_name = 'mini' # or 'mega' (needs the matching W&B artifact downloaded first)\n",
"model_path = '../pretrained/dalle_bart_{}'.format(model_name)\n",
"config, vocab, merges = load_dalle_bart_metadata(model_path)\n",
"params_dalle_bart = load_dalle_bart_flax_params(model_path)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qSR7lFawNk-o",
"outputId": "049a9937-2aa9-441f-dc97-6f04a23d6a9e"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"parsing metadata from ../pretrained/dalle_bart_mini\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from models.dalle_bart_encoder_flax import DalleBartEncoderFlax\n",
"from models.dalle_bart_decoder_flax import DalleBartDecoderFlax\n",
"\n",
"# The encoder is bound to its weights here. The decoder is left unbound; its\n",
"# weights, params_dalle_bart['decoder'], are passed explicitly when sampling.\n",
"# Index instead of .pop() so params_dalle_bart is not mutated and re-running\n",
"# this cell does not fail with KeyError: 'encoder'.\n",
"dalle_bart_encoder_flax_model = DalleBartEncoderFlax(\n",
"    attention_head_count = config['encoder_attention_heads'],\n",
"    embed_count = config['d_model'],\n",
"    glu_embed_count = config['encoder_ffn_dim'],\n",
"    text_token_count = config['max_text_length'],\n",
"    text_vocab_count = config['encoder_vocab_size'],\n",
"    layer_count = config['encoder_layers']\n",
").bind({'params': params_dalle_bart['encoder']})\n",
"\n",
"dalle_bart_decoder_flax_model = DalleBartDecoderFlax(\n",
"    image_token_count = config['image_length'],\n",
"    text_token_count = config['max_text_length'],\n",
"    image_vocab_count = config['image_vocab_size'],\n",
"    attention_head_count = config['decoder_attention_heads'],\n",
"    embed_count = config['d_model'],\n",
"    glu_embed_count = config['decoder_ffn_dim'],\n",
"    layer_count = config['decoder_layers'],\n",
"    start_token = config['decoder_start_token_id']\n",
")"
],
"metadata": {
"id": "V0LCW-ViDiSg"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def generate_image(text_tokens, seed):\n",
"    \"\"\"Run the BART encoder on `text_tokens` and sample image tokens from the decoder.\n",
"\n",
"    Despite the name, the return value is whatever\n",
"    DalleBartDecoderFlax.sample_image_tokens produces (image tokens), not pixels.\n",
"\n",
"    Args:\n",
"        text_tokens: token-id array passed to both the encoder and the decoder.\n",
"        seed: array whose first element seeds jax.random.PRNGKey.\n",
"    \"\"\"\n",
"    rng_key = jax.random.PRNGKey(seed[0])\n",
"    decoder_params = params_dalle_bart['decoder']\n",
"    encoder_state = dalle_bart_encoder_flax_model(text_tokens)\n",
"    return dalle_bart_decoder_flax_model.sample_image_tokens(\n",
"        text_tokens, encoder_state, rng_key, decoder_params\n",
"    )"
],
"metadata": {
"id": "xFGb3Gs1Th1N"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# jit-wrapped generate_image; this is the function the export cell converts with jax2tf.\n",
"generate_image_jitted = jax.jit(generate_image)"
],
"metadata": {
"id": "hxLZepOUT-qi"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Below tokens are for: \"game concept art with a lush green village surrounded by a dry orange canyon with a hill in the background and a blue sky, digital art\"\n",
"# Image output for seed=2: https://i.imgur.com/9fBoJg8.png\n",
"# Row 0 is the prompt; row 1 is [0, 2, 1, 1, ...] (NOTE(review): presumably an empty\n",
"# prompt, with 0/2/1 = BOS/EOS/padding ids -- confirm against the tokenizer).\n",
"# int32 to match the tf.TensorSpec dtypes declared in the export cell.\n",
"text_tokens = numpy.array([[0,880,3319,241,208,58,21843,899,2595,30419,185,58,3441,2566,7308,208,58,2349,91,99,1396,128,58,789,1955,11,1189,241,2,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],[0,2,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]], dtype=numpy.int32)\n",
"seed = numpy.array([2], dtype=numpy.int32)\n",
"\n",
"# Uncomment to run the jitted JAX model directly (e.g. to compare against the exported model).\n",
"# image_tokens = generate_image_jitted(text_tokens, seed)"
],
"metadata": {
"id": "CEUUtqNNF7L5"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from jax.experimental import jax2tf\n",
"import tensorflow as tf"
],
"metadata": {
"id": "kVF1_1Gqnwwi"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# The input signature fixes the shapes/dtypes generate_image is traced with:\n",
"# TEXT_BATCH_SIZE rows of config['max_text_length'] text tokens, plus a 1-element seed.\n",
"SAVED_MODEL_DIR = '/content/dalle-mini-tfsavedmodel'\n",
"TEXT_BATCH_SIZE = 2  # matches the two rows of the sample text_tokens\n",
"\n",
"my_model = tf.Module()\n",
"# enable_xla=False makes jax2tf emit only standard TF ops (no XLA-specific ops).\n",
"# NOTE(review): presumably so the SavedModel can be converted for non-XLA runtimes -- confirm.\n",
"generate_image_jitted_tf = jax2tf.convert(generate_image_jitted, enable_xla=False)\n",
"my_model.f = tf.function(generate_image_jitted_tf, autograph=False, input_signature=[\n",
"    tf.TensorSpec(shape=[TEXT_BATCH_SIZE, config['max_text_length']], dtype=tf.int32, name=\"text_tokens\"),\n",
"    tf.TensorSpec(shape=[1], dtype=tf.int32, name=\"seed\"),\n",
"])\n",
"tf.saved_model.save(\n",
"    my_model,\n",
"    SAVED_MODEL_DIR,\n",
"    options=tf.saved_model.SaveOptions(experimental_custom_gradients=False),\n",
")"
],
"metadata": {
"id": "-74vChXbStPY",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "73106da9-c3bc-48df-8aed-33f21bf3ba92"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"WARNING:tensorflow:@custom_gradient grad_fn has 'variables' in signature, but no ResourceVariables were used on the forward pass.\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"WARNING:tensorflow:@custom_gradient grad_fn has 'variables' in signature, but no ResourceVariables were used on the forward pass.\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"INFO:tensorflow:Assets written to: /content/dalle-mini-tfsavedmodel/assets\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"INFO:tensorflow:Assets written to: /content/dalle-mini-tfsavedmodel/assets\n"
]
}
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Generate Saved Model for DALL-E Mini BART Model",
"provenance": [],
"machine_shape": "hm",
"include_colab_link": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment