Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save josephrocca/c1e2228a525b9a6644de68edab823730 to your computer and use it in GitHub Desktop.
Save josephrocca/c1e2228a525b9a6644de68edab823730 to your computer and use it in GitHub Desktop.
jax-flax-transformations-cannot-be-mixed.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/josephrocca/c1e2228a525b9a6644de68edab823730/jax-flax-transformations-cannot-be-mixed.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "snN3kT1PWRp3"
},
"outputs": [],
"source": [
"# This notebook is based on work done by @kuprel: https://github.com/kuprel/min-dalle"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"id": "th9znbh3dBKN",
"jupyter": {
"outputs_hidden": true
},
"tags": []
},
"outputs": [],
"source": [
"%pip install \"jax[cpu]==0.3.14\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "code",
"collapsed": true,
"id": "ix_xt4X1_6F4",
"jupyter": {
"outputs_hidden": true
},
"tags": []
},
"outputs": [],
"source": [
"# Fetch min-dalle at tag 0.1.1, install its dependencies and download the\n",
"# dalle-mini 'mini-1:v0' artifact into the repo's pretrained/ directory.\n",
"!git clone --depth 1 --branch 0.1.1 https://github.com/kuprel/min-dalle\n",
"!mkdir -p /content/min-dalle/pretrained/vqgan/\n",
"%pip install torch flax==0.4.2 wandb\n",
"!wandb login --anonymously\n",
"!wandb artifact get --root=/content/min-dalle/pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0"
]
},
{
"cell_type": "code",
"source": [
"%cd /content/min-dalle/min_dalle"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0IBOcqR7nOeT",
"outputId": "d704cb35-cdfe-4388-87c1-03901e61d7a1"
},
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content/min-dalle/min_dalle\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "ry9H5pKSQUy9"
},
"outputs": [],
"source": [
"from flax import traverse_util, serialization\n",
"from typing import Dict, Tuple, List\n",
"import jax\n",
"from jax import numpy as jnp\n",
"import numpy\n",
"import os\n",
"import json\n",
"from PIL import Image\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "0rFl-rxiNpvn"
},
"outputs": [],
"source": [
"def load_dalle_bart_flax_params(path: str) -> Dict[str, numpy.ndarray]:\n",
"    \"\"\"Load the DALL-E BART Flax checkpoint and rename its keys for min-dalle.\n",
"\n",
"    Reads `flax_model.msgpack` from `path`, renames the auto-generated\n",
"    submodule keys inside the encoder/decoder layer stacks, then flattens\n",
"    the `layers` and `model` nesting levels and moves `lm_head` under\n",
"    `decoder`.\n",
"\n",
"    Args:\n",
"        path: directory that contains `flax_model.msgpack`.\n",
"\n",
"    Returns:\n",
"        The restored parameter tree (nested dicts) with the renamed keys.\n",
"    \"\"\"\n",
"    with open(os.path.join(path, \"flax_model.msgpack\"), \"rb\") as f:\n",
"        params = serialization.msgpack_restore(f.read())\n",
"\n",
"    # (checkpoint key, min-dalle key) pairs, applied in the listed order.\n",
"    shared_renames = [\n",
"        ('LayerNorm_0', 'pre_self_attn_layer_norm'),\n",
"        ('LayerNorm_1', 'self_attn_layer_norm'),\n",
"        ('FlaxBartAttention_0', 'self_attn'),\n",
"    ]\n",
"    # These extra keys are only renamed for the decoder.\n",
"    decoder_renames = [\n",
"        ('LayerNorm_2', 'pre_encoder_attn_layer_norm'),\n",
"        ('LayerNorm_3', 'encoder_attn_layer_norm'),\n",
"        ('FlaxBartAttention_1', 'encoder_attn'),\n",
"    ]\n",
"    # Keys renamed inside the 'glu' sub-dict of each layer stack.\n",
"    glu_renames = [\n",
"        ('LayerNorm_0', 'ln0'),\n",
"        ('LayerNorm_1', 'ln1'),\n",
"        ('Dense_0', 'fc0'),\n",
"        ('Dense_1', 'fc1'),\n",
"        ('Dense_2', 'fc2'),\n",
"    ]\n",
"\n",
"    for codec in ['encoder', 'decoder']:\n",
"        k = 'FlaxBart{}Layers'.format(codec.title())\n",
"        P = params['model'][codec]['layers'][k]\n",
"        renames = list(shared_renames)\n",
"        if codec == 'decoder':\n",
"            renames += decoder_renames\n",
"        renames.append(('GLU_0', 'glu'))\n",
"        for old_key, new_key in renames:\n",
"            P[new_key] = P.pop(old_key)\n",
"        for old_key, new_key in glu_renames:\n",
"            P['glu'][new_key] = P['glu'].pop(old_key)\n",
"\n",
"    # Merge the contents of each 'layers' dict into its parent codec dict.\n",
"    for codec in ['encoder', 'decoder']:\n",
"        layers_params = params['model'][codec].pop('layers')\n",
"        params['model'][codec] = {\n",
"            **params['model'][codec],\n",
"            **layers_params\n",
"        }\n",
"\n",
"    # Merge the contents of 'model' into the top level.\n",
"    model_params = params.pop('model')\n",
"    params = {**params, **model_params}\n",
"\n",
"    params['decoder']['lm_head'] = params.pop('lm_head')\n",
"\n",
"    return params"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "qSR7lFawNk-o",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "123ef291-6df7-4d89-fabe-22c55c66a207"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"parsing metadata from ../pretrained/dalle_bart_mini\n"
]
}
],
"source": [
"from typing import Tuple, List\n",
"def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]:\n",
"    \"\"\"Read the model config, vocabulary and merges list from `path`.\n",
"\n",
"    Args:\n",
"        path: directory holding `config.json`, `flax_model.msgpack`,\n",
"            `vocab.json` and `merges.txt`.\n",
"\n",
"    Returns:\n",
"        `(config, vocab, merges)`: the parsed contents of the two JSON files\n",
"        and the newline-split contents of `merges.txt` with the first and\n",
"        last elements dropped (presumably a header line and the empty string\n",
"        after a trailing newline).\n",
"\n",
"    Raises:\n",
"        AssertionError: if any of the four expected files is missing.\n",
"    \"\"\"\n",
"    print(\"parsing metadata from {}\".format(path))\n",
"    for name in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']:\n",
"        file_path = os.path.join(path, name)\n",
"        assert os.path.exists(file_path), \"missing file: {}\".format(file_path)\n",
"    # Read as UTF-8 explicitly so the result does not depend on the\n",
"    # platform's default text encoding.\n",
"    with open(os.path.join(path, 'config.json'), 'r', encoding='utf-8') as f:\n",
"        config = json.load(f)\n",
"    with open(os.path.join(path, 'vocab.json'), 'r', encoding='utf-8') as f:\n",
"        vocab = json.load(f)\n",
"    with open(os.path.join(path, 'merges.txt'), 'r', encoding='utf-8') as f:\n",
"        merges = f.read().split(\"\\n\")[1:-1]\n",
"    return config, vocab, merges\n",
"\n",
"model_name = 'mini' # or 'mega'\n",
"model_path = '../pretrained/dalle_bart_{}'.format(model_name)\n",
"config, vocab, merges = load_dalle_bart_metadata(model_path)\n",
"params_dalle_bart = load_dalle_bart_flax_params(model_path)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "V0LCW-ViDiSg"
},
"outputs": [],
"source": [
"from models.dalle_bart_encoder_flax import DalleBartEncoderFlax\n",
"\n",
"# Build the encoder from the checkpoint config and bind the encoder params.\n",
"# The params are indexed (not popped) so that re-running this cell does not\n",
"# raise a KeyError and `params_dalle_bart` is left intact.\n",
"dalle_bart_encoder_flax_model = DalleBartEncoderFlax(\n",
"    attention_head_count = config['encoder_attention_heads'],\n",
"    embed_count = config['d_model'],\n",
"    glu_embed_count = config['encoder_ffn_dim'],\n",
"    text_token_count = config['max_text_length'],\n",
"    text_vocab_count = config['encoder_vocab_size'],\n",
"    layer_count = config['encoder_layers']\n",
").bind({'params': params_dalle_bart['encoder']})"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true,
"id": "CEUUtqNNF7L5",
"jupyter": {
"outputs_hidden": true
},
"tags": [],
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "92722961-58a6-4c0a-bd97-1d421bd07660"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([[[ 0.3001483 , -0.18762903, -0.2455164 , ..., -0.10054623,\n",
" -0.20922032, -0.32941774],\n",
" [-0.8780152 , 0.2660395 , -0.42892987, ..., -0.02442475,\n",
" 0.66716063, -1.035105 ],\n",
" [-0.15164408, 0.76809967, -0.9765533 , ..., 0.24457088,\n",
" 0.3005694 , -1.0586988 ],\n",
" ...,\n",
" [-0.57238525, 0.629053 , -0.6097891 , ..., 0.41932285,\n",
" 0.9070637 , -0.7128931 ],\n",
" [-0.12852874, 0.15357572, 0.03543653, ..., 0.7043331 ,\n",
" 1.0506325 , -0.38532436],\n",
" [-0.84500086, 0.0360685 , -0.48725915, ..., -0.02016719,\n",
" -0.45556316, -0.78608793]],\n",
"\n",
" [[ 0.16950229, 0.2292239 , -0.8539504 , ..., -0.49584222,\n",
" 0.47400555, -0.02326 ],\n",
" [ 0.7447071 , 0.20772544, -0.4697358 , ..., -1.0065404 ,\n",
" 1.1750246 , -0.02699879],\n",
" [ 1.2857782 , 0.5954191 , -1.0264751 , ..., -0.26644182,\n",
" 0.04588521, 0.42185357],\n",
" ...,\n",
" [ 0.19511828, 0.33729747, -0.69204015, ..., -0.21395059,\n",
" 0.50520253, 0.34687594],\n",
" [ 0.26091772, 0.662871 , -0.49031842, ..., -0.1388961 ,\n",
" 0.39078167, -0.12085283],\n",
" [ 0.1288566 , 0.297544 , -0.80136037, ..., -0.38650972,\n",
" 0.44429514, 0.14380088]]], dtype=float32)"
]
},
"metadata": {},
"execution_count": 10
}
],
"source": [
"dalle_bart_encoder_flax_model(numpy.array([[0,880,3319,241,208,58,21843,899,2595,30419,185,58,3441,2566,7308,208,58,2349,91,99,1396,128,58,789,1955,11,1189,241,2,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],[0,2,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]]))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "kVF1_1Gqnwwi"
},
"outputs": [],
"source": [
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-74vChXbStPY"
},
"outputs": [],
"source": [
"## You can uncomment this to confirm that tf.saved_model.save works fine\n",
"## (it needs jax2tf in scope, e.g. `from jax.experimental import jax2tf`):\n",
"# my_model = tf.Module()\n",
"# dalle_bart_encoder_flax_model_tf = jax2tf.convert(dalle_bart_encoder_flax_model, enable_xla=False)\n",
"# my_model.f = tf.function(dalle_bart_encoder_flax_model_tf, autograph=False, input_signature=[\n",
"# tf.TensorSpec(shape=[2, 64], dtype=tf.int32, name=\"text_tokens\"),\n",
"# ])\n",
"# tf.saved_model.save(my_model, '/content/dalle_bart_encoder_flax_model', options=tf.saved_model.SaveOptions(experimental_custom_gradients=False))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "yfH6UYUTtcQh",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 641
},
"outputId": "c4aee95c-76e9-4dd6-bf5a-4e0eafb9a4e1"
},
"outputs": [
{
"output_type": "error",
"ename": "ValueError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mUnfilteredStackTrace\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-13-63d8ff885349>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 5\u001b[0m ]])\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mtflite_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconverter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/content/dalle_bart_encoder_flax_model.tflite'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'wb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/lite.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 802\u001b[0m \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 803\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_convert_and_export_metrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 804\u001b[0m \u001b[0;31m# pylint: enable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/lite.py\u001b[0m in \u001b[0;36m_convert_and_export_metrics\u001b[0;34m(self, convert_func, *args, **kwargs)\u001b[0m\n\u001b[1;32m 788\u001b[0m \u001b[0mstart_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_time\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 789\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconvert_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 790\u001b[0m \u001b[0melapsed_time_ms\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_time\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m1000\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/lite.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1471\u001b[0m hlo_proto = xla_compuation(\n\u001b[0;32m-> 1472\u001b[0;31m *ordered_inputs).as_serialized_hlo_module_proto()\n\u001b[0m\u001b[1;32m 1473\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# pylint: disable=broad-except\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mcomputation_maker\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 881\u001b[0m \u001b[0mstack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menter_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mextend_axis_env\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maxis_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 882\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_to_jaxpr_dynamic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjaxtree_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mavals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 883\u001b[0m \u001b[0mjaxpr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdispatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_outfeed_rewriter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjaxpr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 311\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mTraceAnnotation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdecorator_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 312\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 313\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_jaxpr_dynamic\u001b[0;34m(fun, in_avals, debug_info, keep_inputs)\u001b[0m\n\u001b[1;32m 1877\u001b[0m jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(\n\u001b[0;32m-> 1878\u001b[0;31m fun, main, in_avals, keep_inputs=keep_inputs)\n\u001b[0m\u001b[1;32m 1879\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_subjaxpr_dynamic\u001b[0;34m(fun, main, in_avals, keep_inputs)\u001b[0m\n\u001b[1;32m 1891\u001b[0m \u001b[0min_tracers_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mt\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeep\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0min_tracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeep_inputs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mkeep\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1892\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_wrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0min_tracers_\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1893\u001b[0m \u001b[0mout_tracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mcall_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 168\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 169\u001b[0m \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/transforms.py\u001b[0m in \u001b[0;36mwrapped_fn\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1227\u001b[0m \u001b[0mfull_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34mf'{module_name}{method_suffix}'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1228\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnamed_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclass_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfull_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1229\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/lib/python3.7/contextlib.py\u001b[0m in \u001b[0;36minner\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_recreate_cm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 75\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minner\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36mwrapped_module_method\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 349\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 350\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_wrapped_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 351\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36m_call_wrapped_method\u001b[0;34m(self, fun, args, kwargs)\u001b[0m\n\u001b[1;32m 647\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 648\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 649\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/content/min-dalle/min_dalle/models/dalle_bart_encoder_flax.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, text_tokens)\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[0mencoder_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayernorm_embedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mencoder_state\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 149\u001b[0;31m \u001b[0mencoder_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mencoder_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 150\u001b[0m \u001b[0mencoder_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfinal_ln\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mencoder_state\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/transforms.py\u001b[0m in \u001b[0;36mwrapped_fn\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1227\u001b[0m \u001b[0mfull_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34mf'{module_name}{method_suffix}'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1228\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnamed_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclass_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfull_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1229\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/lib/python3.7/contextlib.py\u001b[0m in \u001b[0;36minner\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_recreate_cm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 75\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minner\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/transforms.py\u001b[0m in \u001b[0;36mwrapped_fn\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 312\u001b[0m \u001b[0mmodule_scopes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_module_scopes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 313\u001b[0;31m \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrafo_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodule_scopes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 314\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mret\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/core/lift.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(scope_tree, *args, **kwargs)\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mscope\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mscopes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 118\u001b[0;31m \u001b[0mscope\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate_trace_level\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 119\u001b[0m \u001b[0mscope\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_populate_collections\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/core/scope.py\u001b[0m in \u001b[0;36m_validate_trace_level\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 473\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_validate_trace_level\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 474\u001b[0;31m \u001b[0mtracers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck_trace_level\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_level\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 475\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/core/tracers.py\u001b[0m in \u001b[0;36mcheck_trace_level\u001b[0;34m(base_level)\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlevel\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mbase_level\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 36\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mJaxTransformError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mUnfilteredStackTrace\u001b[0m: flax.errors.JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.JaxTransformError)\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[0;31mJaxTransformError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/lite.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1471\u001b[0m hlo_proto = xla_compuation(\n\u001b[0;32m-> 1472\u001b[0;31m *ordered_inputs).as_serialized_hlo_module_proto()\n\u001b[0m\u001b[1;32m 1473\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# pylint: disable=broad-except\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/lib/python3.7/contextlib.py\u001b[0m in \u001b[0;36minner\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_recreate_cm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 75\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minner\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/content/min-dalle/min_dalle/models/dalle_bart_encoder_flax.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, text_tokens)\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[0mencoder_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayernorm_embedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mencoder_state\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 149\u001b[0;31m \u001b[0mencoder_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mencoder_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 150\u001b[0m \u001b[0mencoder_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfinal_ln\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mencoder_state\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/lib/python3.7/contextlib.py\u001b[0m in \u001b[0;36minner\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_recreate_cm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 75\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minner\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/core/tracers.py\u001b[0m in \u001b[0;36mcheck_trace_level\u001b[0;34m(base_level)\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlevel\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mbase_level\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 36\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mJaxTransformError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mJaxTransformError\u001b[0m: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.JaxTransformError)",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-13-63d8ff885349>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'text_tokens'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m64\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"int32\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m ]])\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mtflite_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconverter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/content/dalle_bart_encoder_flax_model.tflite'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'wb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtflite_model\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/lite.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 801\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 802\u001b[0m \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 803\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_convert_and_export_metrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 804\u001b[0m \u001b[0;31m# pylint: enable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 805\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/lite.py\u001b[0m in \u001b[0;36m_convert_and_export_metrics\u001b[0;34m(self, convert_func, *args, **kwargs)\u001b[0m\n\u001b[1;32m 787\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_save_conversion_params_metric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 788\u001b[0m \u001b[0mstart_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_time\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 789\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconvert_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 790\u001b[0m \u001b[0melapsed_time_ms\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_time\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m1000\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 791\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/lite.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1472\u001b[0m *ordered_inputs).as_serialized_hlo_module_proto()\n\u001b[1;32m 1473\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# pylint: disable=broad-except\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1474\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Failed to convert the given Jax function to hlo.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1475\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1476\u001b[0m \u001b[0;31m# We need to set the hlo proto, and here we use serialized proto format\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: Failed to convert the given Jax function to hlo."
]
}
],
"source": [
"# This doesn't work: flax.errors.JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.JaxTransformError)\n",
"\n",
"converter = tf.lite.TFLiteConverter.experimental_from_jax([dalle_bart_encoder_flax_model], [[\n",
" ('text_tokens', jnp.zeros((2, 64), dtype=\"int32\")),\n",
"]])\n",
"tflite_model = converter.convert()\n",
"with open('/content/dalle_bart_encoder_flax_model.tflite', 'wb') as f:\n",
" f.write(tflite_model)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"machine_shape": "hm",
"name": "jax-flax-transformations-cannot-be-mixed.ipynb",
"provenance": [],
"include_colab_link": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment