Created
July 2, 2022 15:55
-
-
Save josephrocca/c1e2228a525b9a6644de68edab823730 to your computer and use it in GitHub Desktop.
jax-flax-transformations-cannot-be-mixed.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/josephrocca/c1e2228a525b9a6644de68edab823730/jax-flax-transformations-cannot-be-mixed.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"id": "snN3kT1PWRp3" | |
}, | |
"outputs": [], | |
"source": [ | |
"# This notebook is based on work done by @kuprel: https://github.com/kuprel/min-dalle" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true, | |
"id": "th9znbh3dBKN", | |
"jupyter": { | |
"outputs_hidden": true | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# %pip installs into the running kernel's environment; the requirement is quoted so\n", | |
"# the shell cannot glob-expand the [cpu] extra.\n", | |
"%pip install \"jax[cpu]==0.3.14\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"cellView": "code", | |
"collapsed": true, | |
"id": "ix_xt4X1_6F4", | |
"jupyter": { | |
"outputs_hidden": true | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# Clone into the absolute path the rest of the notebook relies on, and skip the clone\n", | |
"# when the checkout already exists so this cell can be re-run.\n", | |
"!test -d /content/min-dalle || git clone --depth 1 --branch 0.1.1 https://github.com/kuprel/min-dalle /content/min-dalle\n", | |
"!mkdir -p /content/min-dalle/pretrained/vqgan/\n", | |
"%pip install torch flax==0.4.2 wandb\n", | |
"!wandb login --anonymously\n", | |
"!wandb artifact get --root=/content/min-dalle/pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"%cd /content/min-dalle/min_dalle" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "0IBOcqR7nOeT", | |
"outputId": "d704cb35-cdfe-4388-87c1-03901e61d7a1" | |
}, | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"/content/min-dalle/min_dalle\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"id": "ry9H5pKSQUy9" | |
}, | |
"outputs": [], | |
"source": [ | |
"from flax import traverse_util, serialization\n", | |
"from typing import Dict, Tuple, List\n", | |
"import jax\n", | |
"from jax import numpy as jnp\n", | |
"import numpy\n", | |
"import os\n", | |
"import json\n", | |
"from PIL import Image\n", | |
"import torch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"id": "0rFl-rxiNpvn" | |
}, | |
"outputs": [], | |
"source": [ | |
"def load_dalle_bart_flax_params(path: str) -> Dict[str, numpy.ndarray]:\n", | |
"    \"\"\"Restore the Flax checkpoint stored under `path` and flatten its layout.\n", | |
"\n", | |
"    Reads `flax_model.msgpack`, renames the numbered submodule keys of the\n", | |
"    encoder and decoder layer stacks (`LayerNorm_0`, `FlaxBartAttention_0`,\n", | |
"    `GLU_0`, ...) to descriptive ones, merges each codec's `layers` entries into\n", | |
"    the codec dict, merges the `model` entries into the top level, and finally\n", | |
"    moves `lm_head` under `decoder`.\n", | |
"    \"\"\"\n", | |
"    # (old key, new key) pairs applied to every layer stack.\n", | |
"    shared_renames = (\n", | |
"        ('LayerNorm_0', 'pre_self_attn_layer_norm'),\n", | |
"        ('LayerNorm_1', 'self_attn_layer_norm'),\n", | |
"        ('FlaxBartAttention_0', 'self_attn'),\n", | |
"    )\n", | |
"    # Extra pairs that only the decoder stack carries.\n", | |
"    decoder_renames = (\n", | |
"        ('LayerNorm_2', 'pre_encoder_attn_layer_norm'),\n", | |
"        ('LayerNorm_3', 'encoder_attn_layer_norm'),\n", | |
"        ('FlaxBartAttention_1', 'encoder_attn'),\n", | |
"    )\n", | |
"    # Pairs applied inside the GLU sub-dict.\n", | |
"    glu_renames = (\n", | |
"        ('LayerNorm_0', 'ln0'),\n", | |
"        ('LayerNorm_1', 'ln1'),\n", | |
"        ('Dense_0', 'fc0'),\n", | |
"        ('Dense_1', 'fc1'),\n", | |
"        ('Dense_2', 'fc2'),\n", | |
"    )\n", | |
"\n", | |
"    checkpoint_file = os.path.join(path, \"flax_model.msgpack\")\n", | |
"    with open(checkpoint_file, \"rb\") as stream:\n", | |
"        restored = serialization.msgpack_restore(stream.read())\n", | |
"\n", | |
"    # Pass 1: rename the keys of each layer stack in place.\n", | |
"    for codec in ('encoder', 'decoder'):\n", | |
"        stack_key = 'FlaxBart{}Layers'.format(codec.title())\n", | |
"        stack = restored['model'][codec]['layers'][stack_key]\n", | |
"        renames = shared_renames\n", | |
"        if codec == 'decoder':\n", | |
"            renames = shared_renames + decoder_renames\n", | |
"        for old_key, new_key in renames:\n", | |
"            stack[new_key] = stack.pop(old_key)\n", | |
"        glu = stack.pop('GLU_0')\n", | |
"        stack['glu'] = glu\n", | |
"        for old_key, new_key in glu_renames:\n", | |
"            glu[new_key] = glu.pop(old_key)\n", | |
"\n", | |
"    # Pass 2: lift the contents of `layers` into the codec dict itself.\n", | |
"    for codec in ('encoder', 'decoder'):\n", | |
"        codec_params = restored['model'][codec]\n", | |
"        hoisted = codec_params.pop('layers')\n", | |
"        merged = dict(codec_params)\n", | |
"        merged.update(hoisted)\n", | |
"        restored['model'][codec] = merged\n", | |
"\n", | |
"    # Lift the contents of `model` to the top level.\n", | |
"    model_tree = restored.pop('model')\n", | |
"    flat = dict(restored)\n", | |
"    flat.update(model_tree)\n", | |
"\n", | |
"    flat['decoder']['lm_head'] = flat.pop('lm_head')\n", | |
"\n", | |
"    return flat" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"id": "qSR7lFawNk-o", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "123ef291-6df7-4d89-fabe-22c55c66a207" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"parsing metadata from ../pretrained/dalle_bart_mini\n" | |
] | |
} | |
], | |
"source": [ | |
"from typing import Tuple, List\n", | |
"def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]:\n", | |
"    \"\"\"Load config.json, vocab.json and merges.txt from the directory `path`.\n", | |
"\n", | |
"    Args:\n", | |
"        path: directory that must contain config.json, flax_model.msgpack,\n", | |
"            vocab.json and merges.txt.\n", | |
"\n", | |
"    Returns:\n", | |
"        (config, vocab, merges): the parsed config.json, the parsed vocab.json,\n", | |
"        and the lines of merges.txt after its first line.\n", | |
"\n", | |
"    Raises:\n", | |
"        AssertionError: if any of the four required files is missing.\n", | |
"    \"\"\"\n", | |
"    print(\"parsing metadata from {}\".format(path))\n", | |
"    for name in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']:\n", | |
"        required = os.path.join(path, name)\n", | |
"        assert os.path.exists(required), \"missing required file: {}\".format(required)\n", | |
"    # Read as UTF-8 explicitly so the result does not depend on the locale default.\n", | |
"    with open(os.path.join(path, 'config.json'), 'r', encoding='utf-8') as f:\n", | |
"        config = json.load(f)\n", | |
"    with open(os.path.join(path, 'vocab.json'), 'r', encoding='utf-8') as f:\n", | |
"        vocab = json.load(f)\n", | |
"    with open(os.path.join(path, 'merges.txt'), 'r', encoding='utf-8') as f:\n", | |
"        # The first line is skipped (presumably a version header -- confirm).\n", | |
"        merges = f.read().split(\"\\n\")[1:]\n", | |
"    # Drop only the empty entry produced by a trailing newline, so the last merge\n", | |
"    # line survives when the file does not end with one.\n", | |
"    if merges and merges[-1] == '':\n", | |
"        merges.pop()\n", | |
"    return config, vocab, merges\n", | |
"\n", | |
"model_name = 'mini' # or 'mega'\n", | |
"model_path = '../pretrained/dalle_bart_{}'.format(model_name)\n", | |
"config, vocab, merges = load_dalle_bart_metadata(model_path)\n", | |
"params_dalle_bart = load_dalle_bart_flax_params(model_path)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"id": "V0LCW-ViDiSg" | |
}, | |
"outputs": [], | |
"source": [ | |
"from models.dalle_bart_encoder_flax import DalleBartEncoderFlax\n", | |
"\n", | |
"# Bind the restored encoder parameters to the Flax encoder module.\n", | |
"# The params are indexed rather than popped so params_dalle_bart stays intact and\n", | |
"# re-running this cell does not raise KeyError.\n", | |
"dalle_bart_encoder_flax_model = DalleBartEncoderFlax(\n", | |
"    attention_head_count=config['encoder_attention_heads'],\n", | |
"    embed_count=config['d_model'],\n", | |
"    glu_embed_count=config['encoder_ffn_dim'],\n", | |
"    text_token_count=config['max_text_length'],\n", | |
"    text_vocab_count=config['encoder_vocab_size'],\n", | |
"    layer_count=config['encoder_layers']\n", | |
").bind({'params': params_dalle_bart['encoder']})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": true, | |
"id": "CEUUtqNNF7L5", | |
"jupyter": { | |
"outputs_hidden": true | |
}, | |
"tags": [], | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "92722961-58a6-4c0a-bd97-1d421bd07660" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray([[[ 0.3001483 , -0.18762903, -0.2455164 , ..., -0.10054623,\n", | |
" -0.20922032, -0.32941774],\n", | |
" [-0.8780152 , 0.2660395 , -0.42892987, ..., -0.02442475,\n", | |
" 0.66716063, -1.035105 ],\n", | |
" [-0.15164408, 0.76809967, -0.9765533 , ..., 0.24457088,\n", | |
" 0.3005694 , -1.0586988 ],\n", | |
" ...,\n", | |
" [-0.57238525, 0.629053 , -0.6097891 , ..., 0.41932285,\n", | |
" 0.9070637 , -0.7128931 ],\n", | |
" [-0.12852874, 0.15357572, 0.03543653, ..., 0.7043331 ,\n", | |
" 1.0506325 , -0.38532436],\n", | |
" [-0.84500086, 0.0360685 , -0.48725915, ..., -0.02016719,\n", | |
" -0.45556316, -0.78608793]],\n", | |
"\n", | |
" [[ 0.16950229, 0.2292239 , -0.8539504 , ..., -0.49584222,\n", | |
" 0.47400555, -0.02326 ],\n", | |
" [ 0.7447071 , 0.20772544, -0.4697358 , ..., -1.0065404 ,\n", | |
" 1.1750246 , -0.02699879],\n", | |
" [ 1.2857782 , 0.5954191 , -1.0264751 , ..., -0.26644182,\n", | |
" 0.04588521, 0.42185357],\n", | |
" ...,\n", | |
" [ 0.19511828, 0.33729747, -0.69204015, ..., -0.21395059,\n", | |
" 0.50520253, 0.34687594],\n", | |
" [ 0.26091772, 0.662871 , -0.49031842, ..., -0.1388961 ,\n", | |
" 0.39078167, -0.12085283],\n", | |
" [ 0.1288566 , 0.297544 , -0.80136037, ..., -0.38650972,\n", | |
" 0.44429514, 0.14380088]]], dtype=float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 10 | |
} | |
], | |
"source": [ | |
"dalle_bart_encoder_flax_model(numpy.array([[0,880,3319,241,208,58,21843,899,2595,30419,185,58,3441,2566,7308,208,58,2349,91,99,1396,128,58,789,1955,11,1189,241,2,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],[0,2,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]]))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"id": "kVF1_1Gqnwwi" | |
}, | |
"outputs": [], | |
"source": [ | |
"import tensorflow as tf" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "-74vChXbStPY" | |
}, | |
"outputs": [], | |
"source": [ | |
"## You can uncomment this to confirm that tf.saved_model.save works fine (it also needs `from jax.experimental import jax2tf`):\n", | |
"# my_model = tf.Module()\n", | |
"# dalle_bart_encoder_flax_model_tf = jax2tf.convert(dalle_bart_encoder_flax_model, enable_xla=False)\n", | |
"# my_model.f = tf.function(dalle_bart_encoder_flax_model_tf, autograph=False, input_signature=[\n", | |
"# tf.TensorSpec(shape=[2, 64], dtype=tf.int32, name=\"text_tokens\"),\n", | |
"# ])\n", | |
"# tf.saved_model.save(my_model, '/content/dalle_bart_encoder_flax_model', options=tf.saved_model.SaveOptions(experimental_custom_gradients=False))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"id": "yfH6UYUTtcQh", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 641 | |
}, | |
"outputId": "c4aee95c-76e9-4dd6-bf5a-4e0eafb9a4e1" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "error", | |
"ename": "ValueError", | |
"evalue": "ignored", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mUnfilteredStackTrace\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-13-63d8ff885349>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 5\u001b[0m ]])\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mtflite_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconverter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/content/dalle_bart_encoder_flax_model.tflite'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'wb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/lite.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 802\u001b[0m \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 803\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_convert_and_export_metrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 804\u001b[0m \u001b[0;31m# pylint: enable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/lite.py\u001b[0m in \u001b[0;36m_convert_and_export_metrics\u001b[0;34m(self, convert_func, *args, **kwargs)\u001b[0m\n\u001b[1;32m 788\u001b[0m \u001b[0mstart_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_time\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 789\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconvert_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 790\u001b[0m \u001b[0melapsed_time_ms\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_time\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m1000\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/lite.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1471\u001b[0m hlo_proto = xla_compuation(\n\u001b[0;32m-> 1472\u001b[0;31m *ordered_inputs).as_serialized_hlo_module_proto()\n\u001b[0m\u001b[1;32m 1473\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# pylint: disable=broad-except\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mcomputation_maker\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 881\u001b[0m \u001b[0mstack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menter_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mextend_axis_env\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maxis_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 882\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_to_jaxpr_dynamic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjaxtree_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mavals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 883\u001b[0m \u001b[0mjaxpr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdispatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_outfeed_rewriter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjaxpr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 311\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mTraceAnnotation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdecorator_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 312\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 313\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_jaxpr_dynamic\u001b[0;34m(fun, in_avals, debug_info, keep_inputs)\u001b[0m\n\u001b[1;32m 1877\u001b[0m jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(\n\u001b[0;32m-> 1878\u001b[0;31m fun, main, in_avals, keep_inputs=keep_inputs)\n\u001b[0m\u001b[1;32m 1879\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_subjaxpr_dynamic\u001b[0;34m(fun, main, in_avals, keep_inputs)\u001b[0m\n\u001b[1;32m 1891\u001b[0m \u001b[0min_tracers_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mt\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeep\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0min_tracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeep_inputs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mkeep\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1892\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_wrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0min_tracers_\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1893\u001b[0m \u001b[0mout_tracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mcall_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 168\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 169\u001b[0m \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/transforms.py\u001b[0m in \u001b[0;36mwrapped_fn\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1227\u001b[0m \u001b[0mfull_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34mf'{module_name}{method_suffix}'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1228\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnamed_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclass_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfull_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1229\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/lib/python3.7/contextlib.py\u001b[0m in \u001b[0;36minner\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_recreate_cm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 75\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minner\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36mwrapped_module_method\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 349\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 350\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_wrapped_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 351\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36m_call_wrapped_method\u001b[0;34m(self, fun, args, kwargs)\u001b[0m\n\u001b[1;32m 647\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 648\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 649\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/content/min-dalle/min_dalle/models/dalle_bart_encoder_flax.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, text_tokens)\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[0mencoder_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayernorm_embedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mencoder_state\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 149\u001b[0;31m \u001b[0mencoder_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mencoder_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 150\u001b[0m \u001b[0mencoder_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfinal_ln\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mencoder_state\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/transforms.py\u001b[0m in \u001b[0;36mwrapped_fn\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1227\u001b[0m \u001b[0mfull_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34mf'{module_name}{method_suffix}'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1228\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnamed_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclass_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfull_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1229\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/lib/python3.7/contextlib.py\u001b[0m in \u001b[0;36minner\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_recreate_cm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 75\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minner\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/transforms.py\u001b[0m in \u001b[0;36mwrapped_fn\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 312\u001b[0m \u001b[0mmodule_scopes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_module_scopes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 313\u001b[0;31m \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrafo_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodule_scopes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 314\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mret\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/core/lift.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(scope_tree, *args, **kwargs)\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mscope\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mscopes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 118\u001b[0;31m \u001b[0mscope\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate_trace_level\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 119\u001b[0m \u001b[0mscope\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_populate_collections\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/core/scope.py\u001b[0m in \u001b[0;36m_validate_trace_level\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 473\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_validate_trace_level\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 474\u001b[0;31m \u001b[0mtracers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck_trace_level\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_level\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 475\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/core/tracers.py\u001b[0m in \u001b[0;36mcheck_trace_level\u001b[0;34m(base_level)\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlevel\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mbase_level\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 36\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mJaxTransformError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;31mUnfilteredStackTrace\u001b[0m: flax.errors.JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.JaxTransformError)\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------", | |
"\nThe above exception was the direct cause of the following exception:\n", | |
"\u001b[0;31mJaxTransformError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/lite.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1471\u001b[0m hlo_proto = xla_compuation(\n\u001b[0;32m-> 1472\u001b[0;31m *ordered_inputs).as_serialized_hlo_module_proto()\n\u001b[0m\u001b[1;32m 1473\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# pylint: disable=broad-except\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/lib/python3.7/contextlib.py\u001b[0m in \u001b[0;36minner\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_recreate_cm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 75\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minner\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/content/min-dalle/min_dalle/models/dalle_bart_encoder_flax.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, text_tokens)\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[0mencoder_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayernorm_embedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mencoder_state\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 149\u001b[0;31m \u001b[0mencoder_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mencoder_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 150\u001b[0m \u001b[0mencoder_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfinal_ln\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mencoder_state\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/lib/python3.7/contextlib.py\u001b[0m in \u001b[0;36minner\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_recreate_cm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 75\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minner\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/core/tracers.py\u001b[0m in \u001b[0;36mcheck_trace_level\u001b[0;34m(base_level)\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlevel\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mbase_level\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 36\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mJaxTransformError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;31mJaxTransformError\u001b[0m: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.JaxTransformError)", | |
"\nDuring handling of the above exception, another exception occurred:\n", | |
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-13-63d8ff885349>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'text_tokens'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m64\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"int32\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m ]])\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mtflite_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconverter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/content/dalle_bart_encoder_flax_model.tflite'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'wb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtflite_model\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/lite.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 801\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 802\u001b[0m \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 803\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_convert_and_export_metrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 804\u001b[0m \u001b[0;31m# pylint: enable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 805\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/lite.py\u001b[0m in \u001b[0;36m_convert_and_export_metrics\u001b[0;34m(self, convert_func, *args, **kwargs)\u001b[0m\n\u001b[1;32m 787\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_save_conversion_params_metric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 788\u001b[0m \u001b[0mstart_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_time\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 789\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconvert_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 790\u001b[0m \u001b[0melapsed_time_ms\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_time\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m1000\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 791\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/lite.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1472\u001b[0m *ordered_inputs).as_serialized_hlo_module_proto()\n\u001b[1;32m 1473\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# pylint: disable=broad-except\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1474\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Failed to convert the given Jax function to hlo.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1475\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1476\u001b[0m \u001b[0;31m# We need to set the hlo proto, and here we use serialized proto format\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mValueError\u001b[0m: Failed to convert the given Jax function to hlo." | |
] | |
} | |
], | |
"source": [ | |
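"# This doesn't work: converter.convert() raises ValueError (Failed to convert the given Jax function to hlo.), which is raised while handling the underlying flax.errors.JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.JaxTransformError)\n", | |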
"\n", | |
"converter = tf.lite.TFLiteConverter.experimental_from_jax([dalle_bart_encoder_flax_model], [[\n", | |
" ('text_tokens', jnp.zeros((2, 64), dtype=\"int32\")),\n", | |
"]])\n", | |
"tflite_model = converter.convert()\n", | |
"with open('/content/dalle_bart_encoder_flax_model.tflite', 'wb') as f:\n", | |
" f.write(tflite_model)" | |
] | |
} | |
], | |
"metadata": { | |
"colab": { | |
"collapsed_sections": [], | |
"machine_shape": "hm", | |
"name": "jax-flax-transformations-cannot-be-mixed.ipynb", | |
"provenance": [], | |
"include_colab_link": true | |
}, | |
"gpuClass": "standard", | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment