Last active
July 2, 2022 12:00
-
-
Save josephrocca/b1d4170d37202268f0e7316a199339eb to your computer and use it in GitHub Desktop.
Generate Saved Model for DALL-E Mini BART Model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/josephrocca/b1d4170d37202268f0e7316a199339eb/generate-saved-model-for-dall-e-mini-bart-model.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# This notebook is based on work done by @kuprel: https://github.com/kuprel/min-dalle" | |
], | |
"metadata": { | |
"id": "snN3kT1PWRp3" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Pin JAX before it is first imported: as of writing, Colab ships JAX v0.3.8, whose jax2tf has a bug.\n", | |
"# The requirement is quoted so the shell cannot interpret `[cpu]` as a glob pattern.\n", | |
"%pip install \"jax[cpu]==0.3.14\"" | |
], | |
"metadata": { | |
"id": "th9znbh3dBKN" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "ix_xt4X1_6F4", | |
"cellView": "code" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Fetch min-dalle at tag 0.1.1; its `models` package provides the Flax encoder/decoder imported below.\n", | |
"!git clone --depth 1 --branch 0.1.1 https://github.com/kuprel/min-dalle\n", | |
"# flax is pinned to 0.4.2; torch and wandb are installed unpinned.\n", | |
"!pip install torch flax==0.4.2 wandb\n", | |
"# Anonymous W&B session; presumably sufficient for downloading this public artifact.\n", | |
"!wandb login --anonymously\n", | |
"# Download the DALL-E mini weights/metadata into pretrained/dalle_bart_mini, the directory the loader cells read.\n", | |
"!wandb artifact get --root=/content/min-dalle/pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Both the `models.*` imports and the relative `../pretrained/...` path used below resolve from this directory.\n", | |
"%cd /content/min-dalle/min_dalle" | |
], | |
"metadata": { | |
"id": "EuEPj3zBIkkm", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "27aceb07-4401-4c93-f363-9049f15aaacb" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"/content/min-dalle/min_dalle\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from flax import traverse_util, serialization\n", | |
"from typing import Dict, Tuple, List\n", | |
"import jax\n", | |
"from jax import numpy as jnp\n", | |
"import numpy\n", | |
"import os\n", | |
"import json\n", | |
"from PIL import Image\n", | |
"import torch" | |
], | |
"metadata": { | |
"id": "ry9H5pKSQUy9" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def load_dalle_bart_flax_params(path: str) -> dict:\n", | |
"    \"\"\"Load the DALL-E mini BART Flax checkpoint and rename it for min-dalle.\n", | |
"\n", | |
"    Reads `flax_model.msgpack` from `path` and restructures the restored tree:\n", | |
"\n", | |
"    * inside each `FlaxBart{Encoder,Decoder}Layers` subtree, the Flax\n", | |
"      auto-generated submodule names (`LayerNorm_0`, `FlaxBartAttention_0`,\n", | |
"      `GLU_0`, ...) are renamed to the names used by min-dalle's modules;\n", | |
"    * each codec's `layers` subtree is merged into its parent;\n", | |
"    * the top-level `model` wrapper is flattened away and `lm_head` is moved\n", | |
"      under `decoder`.\n", | |
"\n", | |
"    Args:\n", | |
"        path: directory containing `flax_model.msgpack`.\n", | |
"\n", | |
"    Returns:\n", | |
"        A nested parameter dict with `encoder` and `decoder` entries (plus any\n", | |
"        other top-level keys in the checkpoint). It is a nested tree, not a\n", | |
"        flat mapping of names to arrays.\n", | |
"    \"\"\"\n", | |
"    with open(os.path.join(path, \"flax_model.msgpack\"), \"rb\") as f:\n", | |
"        params = serialization.msgpack_restore(f.read())\n", | |
"\n", | |
"    for codec in ['encoder', 'decoder']:\n", | |
"        k = 'FlaxBart{}Layers'.format(codec.title())\n", | |
"        P: dict = params['model'][codec]['layers'][k]\n", | |
"        P['pre_self_attn_layer_norm'] = P.pop('LayerNorm_0')\n", | |
"        P['self_attn_layer_norm'] = P.pop('LayerNorm_1')\n", | |
"        P['self_attn'] = P.pop('FlaxBartAttention_0')\n", | |
"        if codec == 'decoder':\n", | |
"            # The decoder has two extra layer norms and a second attention block.\n", | |
"            P['pre_encoder_attn_layer_norm'] = P.pop('LayerNorm_2')\n", | |
"            P['encoder_attn_layer_norm'] = P.pop('LayerNorm_3')\n", | |
"            P['encoder_attn'] = P.pop('FlaxBartAttention_1')\n", | |
"        glu: dict = P.pop('GLU_0')\n", | |
"        glu['ln0'] = glu.pop('LayerNorm_0')\n", | |
"        glu['ln1'] = glu.pop('LayerNorm_1')\n", | |
"        glu['fc0'] = glu.pop('Dense_0')\n", | |
"        glu['fc1'] = glu.pop('Dense_1')\n", | |
"        glu['fc2'] = glu.pop('Dense_2')\n", | |
"        P['glu'] = glu\n", | |
"\n", | |
"    # Hoist each codec's `layers` subtree up one level.\n", | |
"    for codec in ['encoder', 'decoder']:\n", | |
"        layers_params = params['model'][codec].pop('layers')\n", | |
"        params['model'][codec] = {\n", | |
"            **params['model'][codec],\n", | |
"            **layers_params\n", | |
"        }\n", | |
"\n", | |
"    # Drop the `model` wrapper so `encoder`/`decoder` become top-level keys.\n", | |
"    model_params = params.pop('model')\n", | |
"    params = {**params, **model_params}\n", | |
"\n", | |
"    params['decoder']['lm_head'] = params.pop('lm_head')\n", | |
"\n", | |
"    return params" | |
], | |
"metadata": { | |
"id": "0rFl-rxiNpvn" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from typing import Tuple, List\n", | |
"def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]:\n", | |
"    \"\"\"Load the config, vocabulary and BPE merges stored next to the weights.\n", | |
"\n", | |
"    Args:\n", | |
"        path: directory expected to contain `config.json`, `flax_model.msgpack`,\n", | |
"            `vocab.json` and `merges.txt`.\n", | |
"\n", | |
"    Returns:\n", | |
"        `(config, vocab, merges)`: the parsed `config.json`, the parsed\n", | |
"        `vocab.json`, and the lines of `merges.txt` after its first (header)\n", | |
"        line, without the empty entry a trailing newline would produce.\n", | |
"\n", | |
"    Raises:\n", | |
"        FileNotFoundError: if any of the expected files is missing.\n", | |
"    \"\"\"\n", | |
"    print(\"parsing metadata from {}\".format(path))\n", | |
"    for name in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']:\n", | |
"        file_path = os.path.join(path, name)\n", | |
"        # Explicit check instead of `assert`: asserts are stripped under `python -O`\n", | |
"        # and would not report which file is missing.\n", | |
"        if not os.path.exists(file_path):\n", | |
"            raise FileNotFoundError(file_path)\n", | |
"\n", | |
"    # Decode as UTF-8 explicitly; the default text encoding is locale-dependent\n", | |
"    # and the vocabulary/merges files may contain non-ASCII tokens.\n", | |
"    with open(os.path.join(path, 'config.json'), 'r', encoding='utf-8') as f:\n", | |
"        config = json.load(f)\n", | |
"    with open(os.path.join(path, 'vocab.json'), encoding='utf-8') as f:\n", | |
"        vocab = json.load(f)\n", | |
"    with open(os.path.join(path, 'merges.txt'), encoding='utf-8') as f:\n", | |
"        # Skip the first (header) line. Only a trailing empty entry is dropped:\n", | |
"        # slicing `[1:-1]` would lose the last merge rule whenever the file\n", | |
"        # does not end with a newline.\n", | |
"        merges = f.read().split(\"\\n\")[1:]\n", | |
"    if merges and merges[-1] == '':\n", | |
"        merges.pop()\n", | |
"    return config, vocab, merges\n", | |
"\n", | |
"# Checkpoint to export. Setup only downloads the 'mini' artifact, so 'mega'\n", | |
"# would additionally need its own download into ../pretrained/dalle_bart_mega.\n", | |
"model_name = 'mini'\n", | |
"model_path = f'../pretrained/dalle_bart_{model_name}'\n", | |
"config, vocab, merges = load_dalle_bart_metadata(model_path)\n", | |
"params_dalle_bart = load_dalle_bart_flax_params(model_path)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "qSR7lFawNk-o", | |
"outputId": "049a9937-2aa9-441f-dc97-6f04a23d6a9e" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"parsing metadata from ../pretrained/dalle_bart_mini\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from models.dalle_bart_encoder_flax import DalleBartEncoderFlax\n", | |
"from models.dalle_bart_decoder_flax import DalleBartDecoderFlax\n", | |
"\n", | |
"# Both modules are configured from the checkpoint's config.json.\n", | |
"# The encoder is bound to its weights here. The weights are read by indexing\n", | |
"# rather than `pop`, so re-running this cell does not raise KeyError and\n", | |
"# params_dalle_bart is left intact.\n", | |
"dalle_bart_encoder_flax_model = DalleBartEncoderFlax(\n", | |
"    attention_head_count=config['encoder_attention_heads'],\n", | |
"    embed_count=config['d_model'],\n", | |
"    glu_embed_count=config['encoder_ffn_dim'],\n", | |
"    text_token_count=config['max_text_length'],\n", | |
"    text_vocab_count=config['encoder_vocab_size'],\n", | |
"    layer_count=config['encoder_layers']\n", | |
").bind({'params': params_dalle_bart['encoder']})\n", | |
"\n", | |
"# The decoder is left unbound: generate_image passes params_dalle_bart['decoder']\n", | |
"# to sample_image_tokens explicitly.\n", | |
"dalle_bart_decoder_flax_model = DalleBartDecoderFlax(\n", | |
"    image_token_count=config['image_length'],\n", | |
"    text_token_count=config['max_text_length'],\n", | |
"    image_vocab_count=config['image_vocab_size'],\n", | |
"    attention_head_count=config['decoder_attention_heads'],\n", | |
"    embed_count=config['d_model'],\n", | |
"    glu_embed_count=config['decoder_ffn_dim'],\n", | |
"    layer_count=config['decoder_layers'],\n", | |
"    start_token=config['decoder_start_token_id']\n", | |
")" | |
], | |
"metadata": { | |
"id": "V0LCW-ViDiSg" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def generate_image(text_tokens, seed):\n", | |
"    \"\"\"Encode `text_tokens` and sample image tokens from the decoder.\n", | |
"\n", | |
"    Args:\n", | |
"        text_tokens: text token ids; the SavedModel export traces this as an\n", | |
"            int32 tensor of shape [2, 64].\n", | |
"        seed: array whose first element seeds `jax.random.PRNGKey`.\n", | |
"\n", | |
"    Returns:\n", | |
"        Whatever the decoder's `sample_image_tokens` returns.\n", | |
"\n", | |
"    The encoder module and decoder weights come from notebook globals rather\n", | |
"    than arguments, so under `jax.jit`/`jax2tf` they are captured at trace time.\n", | |
"    \"\"\"\n", | |
"    encoder_state = dalle_bart_encoder_flax_model(text_tokens)\n", | |
"    prng_key = jax.random.PRNGKey(seed[0])\n", | |
"    decoder_params = params_dalle_bart['decoder']\n", | |
"    return dalle_bart_decoder_flax_model.sample_image_tokens(\n", | |
"        text_tokens,\n", | |
"        encoder_state,\n", | |
"        prng_key,\n", | |
"        decoder_params,\n", | |
"    )" | |
], | |
"metadata": { | |
"id": "xFGb3Gs1Th1N" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# JIT-wrapped pipeline; this is the function jax2tf converts in the SavedModel export cell.\n", | |
"generate_image_jitted = jax.jit(generate_image)" | |
], | |
"metadata": { | |
"id": "hxLZepOUT-qi" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Below tokens are for: \"game concept art with a lush green village surrounded by a dry orange canyon with a hill in the background and a blue sky, digital art\"\n", | |
"# Image output for seed=2: https://i.imgur.com/9fBoJg8.png\n", | |
"# Both arrays are int32 to match the dtypes declared in the SavedModel input_signature\n", | |
"# (numpy would otherwise default to int64, which that signature rejects).\n", | |
"text_tokens = numpy.array([[0,880,3319,241,208,58,21843,899,2595,30419,185,58,3441,2566,7308,208,58,2349,91,99,1396,128,58,789,1955,11,1189,241,2,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],[0,2,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]], dtype=numpy.int32)\n", | |
"seed = numpy.array([2], dtype=numpy.int32)\n", | |
"\n", | |
"# Optional smoke test of the JAX pipeline before export (uncomment to run; it\n", | |
"# compiles and samples, so it may take a while on CPU).\n", | |
"# image_tokens = generate_image_jitted(text_tokens, seed)" | |
], | |
"metadata": { | |
"id": "CEUUtqNNF7L5" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from jax.experimental import jax2tf\n", | |
"import tensorflow as tf" | |
], | |
"metadata": { | |
"id": "kVF1_1Gqnwwi" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Convert the jitted JAX function into a TF function. enable_xla=False asks jax2tf\n", | |
"# to avoid XLA-specific TF ops (presumably so non-XLA consumers of the SavedModel\n", | |
"# can load it; confirm for your target runtime).\n", | |
"generate_image_jitted_tf = jax2tf.convert(generate_image_jitted, enable_xla=False)\n", | |
"\n", | |
"# Attach it to a tf.Module with a fixed input signature so it can be saved.\n", | |
"my_model = tf.Module()\n", | |
"my_model.f = tf.function(\n", | |
"    generate_image_jitted_tf,\n", | |
"    input_signature=[\n", | |
"        tf.TensorSpec(shape=[2, 64], dtype=tf.int32, name=\"text_tokens\"),\n", | |
"        tf.TensorSpec(shape=[1], dtype=tf.int32, name=\"seed\"),\n", | |
"    ],\n", | |
"    autograph=False,\n", | |
")\n", | |
"\n", | |
"# Custom gradient functions are not saved; only the forward function is exported.\n", | |
"tf.saved_model.save(\n", | |
"    my_model,\n", | |
"    '/content/dalle-mini-tfsavedmodel',\n", | |
"    options=tf.saved_model.SaveOptions(experimental_custom_gradients=False),\n", | |
")" | |
], | |
"metadata": { | |
"id": "-74vChXbStPY", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "73106da9-c3bc-48df-8aed-33f21bf3ba92" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"WARNING:tensorflow:@custom_gradient grad_fn has 'variables' in signature, but no ResourceVariables were used on the forward pass.\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"WARNING:tensorflow:@custom_gradient grad_fn has 'variables' in signature, but no ResourceVariables were used on the forward pass.\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"INFO:tensorflow:Assets written to: /content/dalle-mini-tfsavedmodel/assets\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"INFO:tensorflow:Assets written to: /content/dalle-mini-tfsavedmodel/assets\n" | |
] | |
} | |
] | |
} | |
], | |
"metadata": { | |
"colab": { | |
"collapsed_sections": [], | |
"name": "Generate Saved Model for DALL-E Mini BART Model", | |
"provenance": [], | |
"machine_shape": "hm", | |
"include_colab_link": true | |
}, | |
"gpuClass": "standard", | |
"kernelspec": { | |
"display_name": "Python 3", | |
"name": "python3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment