Last active
June 17, 2023 11:34
-
-
Save buttercutter/1593ed1ae13e56b50c05f1d60c296204 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "PiSi90gspEQP" | |
}, | |
"source": [ | |
"# Easy GPT-Q + LoRA in JAX ([github](https://github.com/davisyoshida/easy-lora-and-gptq))\n", | |
"\n", | |
"[Davis Yoshida](https://github.com/davisyoshida/)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "hfxALa1so2JD" | |
}, | |
"source": [ | |
"This notebook shows how to combine two JAX tools/transforms I wrote: [Lorax](https://github.com/davisyoshida/lorax) and [JAX-GPTQ](https://github.com/davisyoshida/jax-gptq). I've been using the combination to run LLaMA finetunes on a single GPU.\n", | |
"\n", | |
"They're both applicable to basically any JAX function, which conveniently includes many HuggingFace models!\n", | |
"\n", | |
"The procedure is as follows:\n", | |
"\n", | |
"1. Quantize the weights of the model we want to use\n", | |
"2. Use Lorax to transform the original model function `F(params, inputs)` to one that takes a tuple of the original params and the low rank LoRA params: `F_lora(param_tuple, inputs)`\n", | |
"3. Wrap `F_lora` in the `use_quantized` transform so that it knows how to handle arguments that are int8 matrices packing two 4-bit parameters per byte.\n", | |
"4. Train the model, updating only the low rank params and leaving the larger 4-bit model weights frozen.\n", | |
"\n", | |
"I'd love feedback on one or both of these tools, so please let me know on their GitHub repos if you have any suggestions. JAX-GPTQ in particular is still at a very early stage." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"#### XLA Runtime OOM Prevention" | |
], | |
"metadata": { | |
"id": "SYw-sN1-eX3n" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import os\n", | |
"\n", | |
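"# Cap XLA's default (BFC) allocator at 90% of GPU memory; this only applies to the BFC allocator, so it has no effect once the platform allocator below is selected\n", | |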
"os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\".9\"\n", | |
"\n", | |
"# Disable preallocation of memory\n", | |
"os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"]=\"false\"\n", | |
"\n", | |
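"# Use the platform allocator instead of the default BFC allocator: it allocates exactly what is needed and frees memory when no longer used (slower, but avoids holding on to GPU memory)\n", | |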
"os.environ[\"XLA_PYTHON_CLIENT_ALLOCATOR\"]=\"platform\"" | |
], | |
"metadata": { | |
"id": "3DPHwXufeYGC" | |
}, | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "0Y6JeyF45yd_" | |
}, | |
"source": [ | |
"### Setup" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true, | |
"id": "ljjNpQvkrhsA", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
}, | |
"outputId": "473ec96b-283c-4e54-c846-6962e9a05ddf" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"Collecting git+https://github.com/davisyoshida/jax-gptq.git\n", | |
" Cloning https://github.com/davisyoshida/jax-gptq.git to /tmp/pip-req-build-ulx_pxtq\n", | |
" Running command git clone --filter=blob:none --quiet https://github.com/davisyoshida/jax-gptq.git /tmp/pip-req-build-ulx_pxtq\n", | |
" Resolved https://github.com/davisyoshida/jax-gptq.git to commit 8b8ff0fd23b4a7732f1c5dca98d7275045194d3c\n", | |
" Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
"Building wheels for collected packages: jax-gptq\n", | |
" Building wheel for jax-gptq (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for jax-gptq: filename=jax_gptq-0.0.1-py3-none-any.whl size=16385 sha256=886622fcb6c0ae1727f6f6d96653b215684c41a9dd99a243d2a15fb40377ef17\n", | |
" Stored in directory: /tmp/pip-ephem-wheel-cache-gtatuanf/wheels/ff/5e/fb/dec939c953c916b7437c0ce0839617a79dc06e0a2fd85138a2\n", | |
"Successfully built jax-gptq\n", | |
"Installing collected packages: jax-gptq\n", | |
"Successfully installed jax-gptq-0.0.1\n", | |
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"Collecting jax-lorax\n", | |
" Downloading jax_lorax-0.1.2-py3-none-any.whl (8.4 kB)\n", | |
"Requirement already satisfied: jax<0.5.0,>=0.4.6 in /usr/local/lib/python3.10/dist-packages (from jax-lorax) (0.4.10)\n", | |
"Requirement already satisfied: jaxlib<0.5.0,>=0.4.6 in /usr/local/lib/python3.10/dist-packages (from jax-lorax) (0.4.10+cuda11.cudnn86)\n", | |
"Requirement already satisfied: ml-dtypes>=0.1.0 in /usr/local/lib/python3.10/dist-packages (from jax<0.5.0,>=0.4.6->jax-lorax) (0.1.0)\n", | |
"Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.10/dist-packages (from jax<0.5.0,>=0.4.6->jax-lorax) (1.22.4)\n", | |
"Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax<0.5.0,>=0.4.6->jax-lorax) (3.3.0)\n", | |
"Requirement already satisfied: scipy>=1.7 in /usr/local/lib/python3.10/dist-packages (from jax<0.5.0,>=0.4.6->jax-lorax) (1.10.1)\n", | |
"Installing collected packages: jax-lorax\n", | |
"Successfully installed jax-lorax-0.1.2\n", | |
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"Collecting accelerate\n", | |
" Downloading accelerate-0.20.3-py3-none-any.whl (227 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m227.6/227.6 kB\u001b[0m \u001b[31m18.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate) (1.22.4)\n", | |
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (23.1)\n", | |
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n", | |
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate) (6.0)\n", | |
"Requirement already satisfied: torch>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.0.1+cu118)\n", | |
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate) (3.12.0)\n", | |
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate) (4.5.0)\n", | |
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate) (1.11.1)\n", | |
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate) (3.1)\n", | |
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate) (3.1.2)\n", | |
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate) (2.0.0)\n", | |
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.6.0->accelerate) (3.25.2)\n", | |
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.6.0->accelerate) (16.0.5)\n", | |
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.6.0->accelerate) (2.1.2)\n", | |
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.6.0->accelerate) (1.3.0)\n", | |
"Installing collected packages: accelerate\n", | |
"Successfully installed accelerate-0.20.3\n", | |
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"Requirement already satisfied: jax in /usr/local/lib/python3.10/dist-packages (0.4.10)\n", | |
"Collecting jax\n", | |
" Downloading jax-0.4.12.tar.gz (1.3 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m46.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", | |
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", | |
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", | |
"Requirement already satisfied: jaxlib in /usr/local/lib/python3.10/dist-packages (0.4.10+cuda11.cudnn86)\n", | |
"Collecting jaxlib\n", | |
" Downloading jaxlib-0.4.12-cp310-cp310-manylinux2014_x86_64.whl (71.4 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m71.4/71.4 MB\u001b[0m \u001b[31m57.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: ml-dtypes>=0.1.0 in /usr/local/lib/python3.10/dist-packages (from jax) (0.1.0)\n", | |
"Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.10/dist-packages (from jax) (1.22.4)\n", | |
"Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax) (3.3.0)\n", | |
"Requirement already satisfied: scipy>=1.7 in /usr/local/lib/python3.10/dist-packages (from jax) (1.10.1)\n", | |
"Building wheels for collected packages: jax\n", | |
" Building wheel for jax (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for jax: filename=jax-0.4.12-py3-none-any.whl size=1498447 sha256=69a1a5291f5e76bde6b76257dfa972cfee44386b666ef4100cbeb25f8f70229a\n", | |
" Stored in directory: /tmp/pip-ephem-wheel-cache-_6j78e9k/wheels/e8/48/6d/8fc5366c9f000bd18db799e801d5e41c6a7f55d73fd3038b7e\n", | |
"Successfully built jax\n", | |
"Installing collected packages: jaxlib, jax\n", | |
" Attempting uninstall: jaxlib\n", | |
" Found existing installation: jaxlib 0.4.10+cuda11.cudnn86\n", | |
" Uninstalling jaxlib-0.4.10+cuda11.cudnn86:\n", | |
" Successfully uninstalled jaxlib-0.4.10+cuda11.cudnn86\n", | |
" Attempting uninstall: jax\n", | |
" Found existing installation: jax 0.4.10\n", | |
" Uninstalling jax-0.4.10:\n", | |
" Successfully uninstalled jax-0.4.10\n", | |
"Successfully installed jax-0.4.12 jaxlib-0.4.12\n", | |
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"\u001b[31mERROR: Could not find a version that satisfies the requirement bitsandbytes-cuda117==0.26.0 (from versions: 0.26.0.post2)\u001b[0m\u001b[31m\n", | |
"\u001b[0m\u001b[31mERROR: No matching distribution found for bitsandbytes-cuda117==0.26.0\u001b[0m\u001b[31m\n", | |
"\u001b[0mLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"Collecting transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1\n", | |
" Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.0/7.0 MB\u001b[0m \u001b[31m66.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (3.12.0)\n", | |
"Collecting huggingface-hub<1.0,>=0.11.0 (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading huggingface_hub-0.15.1-py3-none-any.whl (236 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m236.8/236.8 kB\u001b[0m \u001b[31m24.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.22.4)\n", | |
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (23.1)\n", | |
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (6.0)\n", | |
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2022.10.31)\n", | |
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2.27.1)\n", | |
"Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m102.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (4.65.0)\n", | |
"Requirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (8.4.0)\n", | |
"Requirement already satisfied: librosa in /usr/local/lib/python3.10/dist-packages (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.10.0.post2)\n", | |
"Collecting pyctcdecode>=0.4.0 (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading pyctcdecode-0.5.0-py2.py3-none-any.whl (39 kB)\n", | |
"Collecting phonemizer (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading phonemizer-3.2.1-py3-none-any.whl (90 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m90.6/90.6 kB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting kenlm (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading kenlm-0.1.tar.gz (424 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m425.0/425.0 kB\u001b[0m \u001b[31m36.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
"Collecting sentencepiece!=0.1.92,>=0.1.91 (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m58.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting protobuf<=3.20.2 (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading protobuf-3.20.2-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m44.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting decord==0.6.0 (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading decord-0.6.0-py3-none-manylinux2010_x86_64.whl (13.6 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.6/13.6 MB\u001b[0m \u001b[31m91.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting av==9.2.0 (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading av-9.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (28.8 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m28.8/28.8 MB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting onnxconverter-common (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading onnxconverter_common-1.13.0-py2.py3-none-any.whl (83 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m83.8/83.8 kB\u001b[0m \u001b[31m10.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting tf2onnx (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading tf2onnx-1.14.0-py3-none-any.whl (451 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m451.2/451.2 kB\u001b[0m \u001b[31m36.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting onnxruntime>=1.4.0 (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading onnxruntime-1.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.9 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.9/5.9 MB\u001b[0m \u001b[31m100.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting onnxruntime-tools>=1.4.2 (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading onnxruntime_tools-1.7.0-py3-none-any.whl (212 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.7/212.7 kB\u001b[0m \u001b[31m24.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting ftfy (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading ftfy-6.1.1-py3-none-any.whl (53 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m53.1/53.1 kB\u001b[0m \u001b[31m6.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting deepspeed>=0.8.3 (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading deepspeed-0.9.4.tar.gz (808 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m808.8/808.8 kB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
"Requirement already satisfied: accelerate>=0.10.0 in /usr/local/lib/python3.10/dist-packages (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.20.3)\n", | |
"Collecting timm (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading timm-0.9.2-py3-none-any.whl (2.2 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m95.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.10.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (5.9.5)\n", | |
"Requirement already satisfied: torch>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.10.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2.0.1+cu118)\n", | |
"Collecting hjson (from deepspeed>=0.8.3->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading hjson-3.1.0-py3-none-any.whl (54 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m54.0/54.0 kB\u001b[0m \u001b[31m6.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting ninja (from deepspeed>=0.8.3->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading ninja-1.11.1-py2.py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (145 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m146.0/146.0 kB\u001b[0m \u001b[31m17.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: py-cpuinfo in /usr/local/lib/python3.10/dist-packages (from deepspeed>=0.8.3->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (9.0.0)\n", | |
"Requirement already satisfied: pydantic<2.0.0 in /usr/local/lib/python3.10/dist-packages (from deepspeed>=0.8.3->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.10.7)\n", | |
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2023.4.0)\n", | |
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (4.5.0)\n", | |
"Collecting coloredlogs (from onnxruntime>=1.4.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m5.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: flatbuffers in /usr/local/lib/python3.10/dist-packages (from onnxruntime>=1.4.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (23.3.3)\n", | |
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from onnxruntime>=1.4.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.11.1)\n", | |
"Collecting onnx (from onnxruntime-tools>=1.4.2->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading onnx-1.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.6 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.6/14.6 MB\u001b[0m \u001b[31m91.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting py3nvml (from onnxruntime-tools>=1.4.2->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading py3nvml-0.2.7-py3-none-any.whl (55 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m55.5/55.5 kB\u001b[0m \u001b[31m6.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting pygtrie<3.0,>=2.1 (from pyctcdecode>=0.4.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading pygtrie-2.5.0-py3-none-any.whl (25 kB)\n", | |
"Collecting hypothesis<7,>=6.14 (from pyctcdecode>=0.4.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading hypothesis-6.78.2-py3-none-any.whl (416 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m416.8/416.8 kB\u001b[0m \u001b[31m33.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: wcwidth>=0.2.5 in /usr/local/lib/python3.10/dist-packages (from ftfy->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.2.6)\n", | |
"Requirement already satisfied: audioread>=2.1.9 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (3.0.0)\n", | |
"Requirement already satisfied: scipy>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.10.1)\n", | |
"Requirement already satisfied: scikit-learn>=0.20.0 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.2.2)\n", | |
"Requirement already satisfied: joblib>=0.14 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.2.0)\n", | |
"Requirement already satisfied: decorator>=4.3.0 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (4.4.2)\n", | |
"Requirement already satisfied: numba>=0.51.0 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.56.4)\n", | |
"Requirement already satisfied: soundfile>=0.12.1 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.12.1)\n", | |
"Requirement already satisfied: pooch<1.7,>=1.0 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.6.0)\n", | |
"Requirement already satisfied: soxr>=0.3.2 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.3.5)\n", | |
"Requirement already satisfied: lazy-loader>=0.1 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.2)\n", | |
"Requirement already satisfied: msgpack>=1.0 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.0.5)\n", | |
"Collecting segments (from phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading segments-2.2.1-py2.py3-none-any.whl (15 kB)\n", | |
"Requirement already satisfied: attrs>=18.1 in /usr/local/lib/python3.10/dist-packages (from phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (23.1.0)\n", | |
"Collecting dlinfo (from phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading dlinfo-1.2.1-py3-none-any.whl (3.6 kB)\n", | |
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.26.15)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2022.12.7)\n", | |
"Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2.0.12)\n", | |
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (3.4)\n", | |
"Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from tf2onnx->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.16.0)\n", | |
"Collecting flatbuffers (from onnxruntime>=1.4.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading flatbuffers-2.0.7-py2.py3-none-any.whl (26 kB)\n", | |
"Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (from timm->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.15.2+cu118)\n", | |
"Collecting safetensors (from timm->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading safetensors-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m76.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: sortedcontainers<3.0.0,>=2.1.0 in /usr/local/lib/python3.10/dist-packages (from hypothesis<7,>=6.14->pyctcdecode>=0.4.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2.4.0)\n", | |
"Requirement already satisfied: exceptiongroup>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from hypothesis<7,>=6.14->pyctcdecode>=0.4.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.1.1)\n", | |
"Requirement already satisfied: llvmlite<0.40,>=0.39.0dev0 in /usr/local/lib/python3.10/dist-packages (from numba>=0.51.0->librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.39.1)\n", | |
"Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from numba>=0.51.0->librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (67.7.2)\n", | |
"Requirement already satisfied: appdirs>=1.3.0 in /usr/local/lib/python3.10/dist-packages (from pooch<1.7,>=1.0->librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.4.4)\n", | |
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.20.0->librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (3.1.0)\n", | |
"Requirement already satisfied: cffi>=1.0 in /usr/local/lib/python3.10/dist-packages (from soundfile>=0.12.1->librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.15.1)\n", | |
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate>=0.10.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (3.1)\n", | |
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate>=0.10.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (3.1.2)\n", | |
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate>=0.10.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2.0.0)\n", | |
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.6.0->accelerate>=0.10.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (3.25.2)\n", | |
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.6.0->accelerate>=0.10.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (16.0.5)\n", | |
"Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime>=1.4.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m9.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting xmltodict (from py3nvml->onnxruntime-tools>=1.4.2->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading xmltodict-0.13.0-py2.py3-none-any.whl (10.0 kB)\n", | |
"Collecting clldutils>=1.7.3 (from segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading clldutils-3.19.0-py2.py3-none-any.whl (1.7 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m53.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting csvw>=1.5.6 (from segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading csvw-3.1.3-py2.py3-none-any.whl (56 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.7/56.7 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->onnxruntime>=1.4.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.3.0)\n", | |
"Requirement already satisfied: pycparser in /usr/local/lib/python3.10/dist-packages (from cffi>=1.0->soundfile>=0.12.1->librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2.21)\n", | |
"Requirement already satisfied: python-dateutil in /usr/local/lib/python3.10/dist-packages (from clldutils>=1.7.3->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2.8.2)\n", | |
"Requirement already satisfied: tabulate>=0.7.7 in /usr/local/lib/python3.10/dist-packages (from clldutils>=1.7.3->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.8.10)\n", | |
"Collecting colorlog (from clldutils>=1.7.3->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading colorlog-6.7.0-py2.py3-none-any.whl (11 kB)\n", | |
"Collecting pylatexenc (from clldutils>=1.7.3->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading pylatexenc-2.10.tar.gz (162 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m162.6/162.6 kB\u001b[0m \u001b[31m17.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
"Requirement already satisfied: markdown in /usr/local/lib/python3.10/dist-packages (from clldutils>=1.7.3->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (3.4.3)\n", | |
"Requirement already satisfied: lxml in /usr/local/lib/python3.10/dist-packages (from clldutils>=1.7.3->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (4.9.2)\n", | |
"Requirement already satisfied: markupsafe in /usr/local/lib/python3.10/dist-packages (from clldutils>=1.7.3->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2.1.2)\n", | |
"Requirement already satisfied: babel in /usr/local/lib/python3.10/dist-packages (from csvw>=1.5.6->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2.12.1)\n", | |
"Collecting colorama (from csvw>=1.5.6->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)\n", | |
"Collecting isodate (from csvw>=1.5.6->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading isodate-0.6.1-py2.py3-none-any.whl (41 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m41.7/41.7 kB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from csvw>=1.5.6->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (4.3.3)\n", | |
"Collecting language-tags (from csvw>=1.5.6->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading language_tags-1.2.0-py3-none-any.whl (213 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m213.4/213.4 kB\u001b[0m \u001b[31m21.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting rdflib (from csvw>=1.5.6->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading rdflib-6.3.2-py3-none-any.whl (528 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m528.1/528.1 kB\u001b[0m \u001b[31m30.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting rfc3986<2 (from csvw>=1.5.6->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n", | |
" Downloading rfc3986-1.5.0-py2.py3-none-any.whl (31 kB)\n", | |
"Requirement already satisfied: uritemplate>=3.0.0 in /usr/local/lib/python3.10/dist-packages (from csvw>=1.5.6->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (4.1.1)\n", | |
"Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema->csvw>=1.5.6->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.19.3)\n", | |
"Requirement already satisfied: pyparsing<4,>=2.1.0 in /usr/local/lib/python3.10/dist-packages (from rdflib->csvw>=1.5.6->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (3.0.9)\n", | |
"Building wheels for collected packages: deepspeed, kenlm, pylatexenc\n", | |
" Building wheel for deepspeed (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for deepspeed: filename=deepspeed-0.9.4-py3-none-any.whl size=843978 sha256=d1b1f1ff90e1ccdc20f2f3931eff74d83b341f45a48572cb038023d19c875ab1\n", | |
" Stored in directory: /root/.cache/pip/wheels/2d/ae/38/1d1c49ac8687c5808b3732e3541b6c896459fb8404763eb98b\n", | |
" Building wheel for kenlm (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for kenlm: filename=kenlm-0.1-cp310-cp310-linux_x86_64.whl size=3003959 sha256=ac0503b6460b64d96d1a828bb322b5ea6ef3f3825947743e5f9fc1f9ed896c9b\n", | |
" Stored in directory: /root/.cache/pip/wheels/4e/3a/01/9105a071c30781823efbd96a58279c16f948a87cafb1144042\n", | |
" Building wheel for pylatexenc (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for pylatexenc: filename=pylatexenc-2.10-py3-none-any.whl size=136820 sha256=c46a5f2915e96aaa8aa771710b5ac9f24cd18e381034243353393cb92352f0d8\n", | |
" Stored in directory: /root/.cache/pip/wheels/d3/31/8b/e09b0386afd80cfc556c00408c9aeea5c35c4d484a9c762fd5\n", | |
"Successfully built deepspeed kenlm pylatexenc\n", | |
"Installing collected packages: tokenizers, sentencepiece, safetensors, rfc3986, pylatexenc, pygtrie, ninja, language-tags, kenlm, hjson, flatbuffers, dlinfo, av, xmltodict, protobuf, isodate, hypothesis, humanfriendly, ftfy, decord, colorlog, colorama, rdflib, pyctcdecode, py3nvml, onnx, huggingface-hub, coloredlogs, clldutils, transformers, tf2onnx, onnxruntime-tools, onnxruntime, onnxconverter-common, csvw, segments, phonemizer, timm, deepspeed\n", | |
" Attempting uninstall: flatbuffers\n", | |
" Found existing installation: flatbuffers 23.3.3\n", | |
" Uninstalling flatbuffers-23.3.3:\n", | |
" Successfully uninstalled flatbuffers-23.3.3\n", | |
" Attempting uninstall: protobuf\n", | |
" Found existing installation: protobuf 3.20.3\n", | |
" Uninstalling protobuf-3.20.3:\n", | |
" Successfully uninstalled protobuf-3.20.3\n", | |
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", | |
"tensorflow 2.12.0 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3, but you have protobuf 3.20.2 which is incompatible.\n", | |
"tensorflow-metadata 1.13.1 requires protobuf<5,>=3.20.3, but you have protobuf 3.20.2 which is incompatible.\u001b[0m\u001b[31m\n", | |
"\u001b[0mSuccessfully installed av-9.2.0 clldutils-3.19.0 colorama-0.4.6 coloredlogs-15.0.1 colorlog-6.7.0 csvw-3.1.3 decord-0.6.0 deepspeed-0.9.4 dlinfo-1.2.1 flatbuffers-2.0.7 ftfy-6.1.1 hjson-3.1.0 huggingface-hub-0.15.1 humanfriendly-10.0 hypothesis-6.78.2 isodate-0.6.1 kenlm-0.1 language-tags-1.2.0 ninja-1.11.1 onnx-1.14.0 onnxconverter-common-1.13.0 onnxruntime-1.15.0 onnxruntime-tools-1.7.0 phonemizer-3.2.1 protobuf-3.20.2 py3nvml-0.2.7 pyctcdecode-0.5.0 pygtrie-2.5.0 pylatexenc-2.10 rdflib-6.3.2 rfc3986-1.5.0 safetensors-0.3.1 segments-2.2.1 sentencepiece-0.1.99 tf2onnx-1.14.0 timm-0.9.2 tokenizers-0.13.3 transformers-4.28.1 xmltodict-0.13.0\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"application/vnd.colab-display-data+json": { | |
"pip_warning": { | |
"packages": [ | |
"google" | |
] | |
} | |
} | |
}, | |
"metadata": {} | |
} | |
], | |
"source": [ | |
"!pip install --upgrade --no-cache-dir git+https://github.com/davisyoshida/jax-gptq.git\n", | |
"!pip install --upgrade --no-cache-dir jax-lorax\n", | |
"#!pip install --upgrade --no-cache-dir transformers\n", | |
"#!pip install --upgrade --no-cache-dir bitsandbytes-cuda110 bitsandbytes\n", | |
"!pip install --upgrade --no-cache-dir accelerate\n", | |
"\n", | |
"!pip install --upgrade --no-cache-dir jax jaxlib\n", | |
"\n", | |
"#!pip uninstall --yes bitsandbytes-cuda110 bitsandbytes transformers\n", | |
"!pip install bitsandbytes-cuda117==0.26.0\n", | |
"!pip install transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"id": "75-T_R0Ms9qD", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "2c63e202-443a-4b65-8665-06ecbfe0cac2" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" | |
] | |
} | |
], | |
"source": [ | |
"from functools import partial\n", | |
"import jax\n", | |
"import jax.numpy as jnp\n", | |
"import numpy as np\n", | |
"import optax\n", | |
"import torch\n", | |
"\n", | |
"import transformers\n", | |
"from transformers import (\n", | |
" CONFIG_MAPPING,\n", | |
" FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,\n", | |
" AutoConfig,\n", | |
" AutoTokenizer,\n", | |
" FlaxAutoModelForCausalLM,\n", | |
" HfArgumentParser,\n", | |
" TrainingArguments,\n", | |
" is_tensorboard_available,\n", | |
")\n", | |
"\n", | |
"from tqdm import trange\n", | |
"\n", | |
"import lorax\n", | |
"import jax_gptq\n", | |
"\n", | |
"#gpu = jax.devices('gpu')[0]\n", | |
"cpu = jax.devices('cpu')[0]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "GQuDSjz7svdL" | |
}, | |
"source": [ | |
"## Toy Example\n", | |
"\n", | |
"### Model/Data setup\n", | |
"\n", | |
"First we'll define an MLP and make some parameters for it:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from transformers import LongT5Config, FlaxT5ForConditionalGeneration\n", | |
"from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer\n", | |
"\n", | |
"from transformers import BitsAndBytesConfig\n", | |
"\n", | |
"\n", | |
"nf4_config = BitsAndBytesConfig(\n", | |
" load_in_4bit=True,\n", | |
" bnb_4bit_quant_type=\"nf4\",\n", | |
" bnb_4bit_use_double_quant=True,\n", | |
" bnb_4bit_compute_dtype=torch.bfloat16\n", | |
")\n", | |
"\n", | |
"# Load the LongT5-XL model with its configuration\n", | |
"model_id = \"google/long-t5-tglobal-xl\"\n", | |
"config = LongT5Config.from_pretrained(model_id)\n", | |
"#model = AutoModelForSeq2SeqLM.from_pretrained(model_id, load_in_4bit=True, device_map=\"auto\")\n", | |
"model = AutoModelForSeq2SeqLM.from_pretrained(model_id, quantization_config=nf4_config)\n", | |
"tokenizer = AutoTokenizer.from_pretrained(model_id)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 344, | |
"referenced_widgets": [ | |
"dda94429aac549438c343dd2edc24e0a", | |
"ee547e0a59eb422d975f02daaf672629", | |
"989350afa53349d39342297bfb4eca95", | |
"9954285510cb4ebab4deb9d6558dd5ed", | |
"663b81ccdf0c4b02af24b1cd81b7daa6", | |
"6bf5f6cd0df14197b8da4303b43cb655", | |
"693fb73b9dc0411eab41c5fede4b611c", | |
"db0e6b46120844dc8620fbd2a4076d0d", | |
"9872d010314e492f9a85cbb59edb8cc3", | |
"1478a80283ff43f193bd9a458a6c8274", | |
"199682514a0e43b5a13244a818722d90", | |
"cdb5e7259f57470296128e76fdf5e6ea", | |
"9b31610503d84238b9e7a2321ff1b10b", | |
"32d7477699eb461584db4072a7c97052", | |
"6c4eafec00674c838c628394c7c68faa", | |
"a63b31cd45b24408b5a9068d1ecd1b4b", | |
"715506a0e5694e7ca6061ee99bd92382", | |
"3c4dbdf570eb46b690cd500476020064", | |
"bdd5160d52de4f9895f8e4dd7f7ec20a", | |
"c4eee8d4242045bba3d38597dbb414b3", | |
"2c47cb97a18f4ee89bd945b1ca11347d", | |
"4a9f0ebd9d7b4192a25ef75250c03deb", | |
"7d0234ed0d35487e826565df214a2cc9", | |
"46ad2aad7e474c73be0bfb1fef55a4bd", | |
"6d811cfd512f4916b83640f9297bc1bd", | |
"8c416aa8a81447cba3527d287aeb2f35", | |
"8e073507368745088aa356b23f044f0b", | |
"902022c949e448eeb8ce3e31d1ef4153", | |
"8bcfb6b0b56d4046a839c0f9ae4c16b1", | |
"7de7565229de41149b770fe25714a847", | |
"52e158b04b8c4858929bf3739f7c78df", | |
"a12f0daf615547faa649107033a93011", | |
"3067249f4bc64398906fceea8d1e5934", | |
"5fe6035fe6e8426580d632bcedc3d19b", | |
"fe81c32260ca47c7bcd4fd6230ed3c81", | |
"ec4a1c28bbb54799a2278ee1e1d65383", | |
"68cc8181fd484af1b9f42b1c9b5f8d75", | |
"0d20b32672e34c7295634090e6cebf27", | |
"24f189322a134a20900084de1e08c376", | |
"aae1c621a5b94960ace82d52253ccaf7", | |
"ff2779f06f954e35a5eec9a0a02684f1", | |
"c00fb1fde3f74d82ac6b3b5bc19a9796", | |
"d50444fed9a746ff8281239c4c1a3058", | |
"1110d48d7a664208bcd57f3c67aecccd", | |
"c22c1289367046babeb0c9409c181f8b", | |
"eb3f2e923426469e9b92a51187f60f60", | |
"83784cb918974e4c8664f78561ea8cf8", | |
"008aa06de0fd4ec789d3d42e83dbd16e", | |
"33217d952d544284937852eeb17323fc", | |
"c83d30547b08413f90300abdf318b123", | |
"2b8df3f5200e4c9fa59f057c4586a700", | |
"8fc75f12da864868a4d0ab9c1949d1ac", | |
"b00196f8517c415ea5388bf1a72ffb37", | |
"09ae4873193f41baa133879a8d3f5905", | |
"bd9db9a757db4ccfa68154524ac95dbb", | |
"410478d1c6164c679702f1b959ec6dc6", | |
"34ae822850454b4b8f92dae185943fe4", | |
"44bd417858fb4982bd97cc07b3f03b5d", | |
"84ddc0bd3bfd4df9a300a0ea01dafc06", | |
"60e20079809641db96f76aa61980ac74", | |
"a82fed6b81274263a1c88e41063356f4", | |
"14c39b554476480d8ed17a4188ffc261", | |
"1b9abcbc8e054a069d9bf0710329928b", | |
"d85f0a6e82b648b0899e7006b5a7e740", | |
"df82eb9aaa354a4395a5304ae038a828", | |
"96b5cf9d5b7846bfa61523ee1ec4db69", | |
"3458c2feac6b41ebbf6ed872d9112df0", | |
"d8cf12f088a94b4da262d8d2c134e89f", | |
"2973bec2ccff49168ba84ddc1f326e8a", | |
"58230fe17b7840c983a59a807a051f2c", | |
"b4687344280545029c6505de0d277b97", | |
"910e96c5ff324caba366ede006d164ab", | |
"d8138fd266e44ab3963a274ef7018b0c", | |
"0043c749b6ad4d27a251abbf5d3644ed", | |
"ccf2b81c4a774692980b024885d00872", | |
"47302718457149c9bd98b2ef0b702d52", | |
"83d6be00be514667a08472ba48a7d42f", | |
"6d2432a98bf84a77a215f03e69a8a79f", | |
"9ecd6788ffb04d37ae547f518246fc6b", | |
"da31333b1c844b56830916b794ce9833", | |
"a90c9881a7d0420cb7e230f2b19f25f4", | |
"7fc0552e04b247f5a8b5c8dc7ebfbfe1", | |
"5d21f67dc7aa4297b7115299c2589da7", | |
"5562aa475ee846a89442b793da39c802", | |
"46fd98456a604139a948a97056fb6842", | |
"51bebf1ec8fe4e798bbba4bbf158bd6d", | |
"2f2b07d552364493965c7811042d73c5", | |
"fab5d066dc794d1dbb17b282081b1d92", | |
"d20b4ddee92c4d22a81bd819e8314fa3", | |
"556d33b982cc4989a03fc4b29bc7e84a", | |
"03675f31e76a4135a2ae239f3ee42655", | |
"8231b974ae924b68959d606e073d66ad", | |
"4bfd7fde14514f41bea4edcb878cfc3c", | |
"d439f8d48f754ed1976392cadfaac95f", | |
"86b9a19a4ecf4e52974e1dc84335aca2", | |
"1b6c1e021fa64948ae533a11f029bd50", | |
"71b903a4156149bbb714e1f3841f13b2", | |
"a011aa10bd374418b030f7cd2f602255", | |
"e52faa2810604688844ddf624b0a1400" | |
] | |
}, | |
"id": "YKcA0xmzRIas", | |
"outputId": "1d09422a-3405-40a3-c1e2-f3fed346f513" | |
}, | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"Downloading (…)lve/main/config.json: 0%| | 0.00/896 [00:00<?, ?B/s]" | |
], | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "dda94429aac549438c343dd2edc24e0a" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"[2023-06-15 06:49:37,016] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"Downloading (…)model.bin.index.json: 0%| | 0.00/55.4k [00:00<?, ?B/s]" | |
], | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "cdb5e7259f57470296128e76fdf5e6ea" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"Downloading shards: 0%| | 0/2 [00:00<?, ?it/s]" | |
], | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "7d0234ed0d35487e826565df214a2cc9" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"Downloading (…)l-00001-of-00002.bin: 0%| | 0.00/9.45G [00:00<?, ?B/s]" | |
], | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "5fe6035fe6e8426580d632bcedc3d19b" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"Downloading (…)l-00002-of-00002.bin: 0%| | 0.00/1.95G [00:00<?, ?B/s]" | |
], | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "c22c1289367046babeb0c9409c181f8b" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"WARNING:accelerate.utils.modeling:The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]" | |
], | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "410478d1c6164c679702f1b959ec6dc6" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"Downloading (…)neration_config.json: 0%| | 0.00/147 [00:00<?, ?B/s]" | |
], | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "3458c2feac6b41ebbf6ed872d9112df0" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"Downloading spiece.model: 0%| | 0.00/792k [00:00<?, ?B/s]" | |
], | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "6d2432a98bf84a77a215f03e69a8a79f" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"Downloading (…)/main/tokenizer.json: 0%| | 0.00/1.39M [00:00<?, ?B/s]" | |
], | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "d20b4ddee92c4d22a81bd819e8314fa3" | |
} | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from google.colab import drive\n", | |
"drive.mount('/content/drive')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "ypIb3RHPM77-", | |
"outputId": "4fa2a63a-6f41-41ac-d416-0cd6979a29b2" | |
}, | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Mounted at /content/drive\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Save the quantized model to disc for future use in TPU\n", | |
"!mkdir -p /content/checkpoints\n", | |
"\n", | |
"model.save_pretrained(\n", | |
" \"/content/checkpoints\",\n", | |
" commit_message=f\"Saving weights and logs\",\n", | |
")" | |
], | |
"metadata": { | |
"id": "McT8hhNNlEXq" | |
}, | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!cp -rf /content/checkpoints/ /content/drive/MyDrive/" | |
], | |
"metadata": { | |
"id": "5qQ0iUToOQo8" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!zip -r longT5-xl-quantized.zip /content/checkpoints/" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "3nwB94DunQ3k", | |
"outputId": "3252d182-db7d-4926-8b97-59a9da041f14" | |
}, | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
" adding: content/checkpoints/ (stored 0%)\n", | |
" adding: content/checkpoints/pytorch_model.bin.index.json (deflated 96%)\n", | |
" adding: content/checkpoints/pytorch_model-00002-of-00002.bin (deflated 7%)\n", | |
" adding: content/checkpoints/config.json (deflated 48%)\n", | |
" adding: content/checkpoints/pytorch_model-00001-of-00002.bin (deflated 7%)\n", | |
" adding: content/checkpoints/generation_config.json (deflated 29%)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Reference : https://github.com/davisyoshida/lorax/blob/master/examples/huggingface_gpt2.py\n", | |
"\n", | |
"import warnings\n", | |
"\n", | |
"import jax\n", | |
"import jax.numpy as jnp\n", | |
"import optax\n", | |
"from transformers import FlaxGPT2LMHeadModel\n", | |
"\n", | |
"from lorax import simple_spec, init_lora, lora, LORA_FULL, merge_params\n", | |
"\n", | |
"def main():\n", | |
" #model = FlaxGPT2LMHeadModel.from_pretrained('gpt2')\n", | |
"\n", | |
" # Wrap the forward pass in so that lorax knows which params to LoRA-fy (it only does the first argument by default)\n", | |
" @lora\n", | |
" def lora_forward(params, input_ids):\n", | |
" return model(input_ids, params=params)\n", | |
"\n", | |
" # This function defines a spec which tells lorax how each parameter should be handled\n", | |
" def decision_fn(path, param):\n", | |
" if 'embedding' in path:\n", | |
" print(f'Fully finetuning param {path}')\n", | |
" return LORA_FULL\n", | |
" dim = 32\n", | |
" print(f'Using LoRA with dim={dim} for param {path}')\n", | |
" return dim\n", | |
"\n", | |
" # Create a pytree with the same shape as params indicating how each parameter should be handled\n", | |
" params, dropout_rng, *_ = model.parameters()\n", | |
"\n", | |
" # Convert the generator object to a list of arrays\n", | |
" params = list(params)\n", | |
"\n", | |
" # Convert the list of tensors to a tuple of tensors\n", | |
" params = tuple(params)\n", | |
" lora_spec = simple_spec(params, decision_fn=decision_fn, tune_vectors=True)\n", | |
"\n", | |
" # Split the parameters up into tunable and frozen ones, and initialize a pair of LoRA matrices for each parameter\n", | |
" # which had a spec value other than LORA_FULL or LORA_FREEZE\n", | |
" freeze_params, tune_params = init_lora(model.parameters(), lora_spec, jax.random.PRNGKey(0))\n", | |
"\n", | |
" optimizer = optax.adamw(learning_rate=1e-4, weight_decay=1e-4)\n", | |
"\n", | |
" # Make sure to only pass the tunable parameters to the optimizer\n", | |
" opt_state = optimizer.init(tune_params)\n", | |
"\n", | |
" # The loss function should take the tunable and frozen params separately so\n", | |
" # you can differentiate w.r.t. the tunable ones only\n", | |
" def loss_fn(tunable_params, frozen_params, batch):\n", | |
" input_ids = batch[:, :-1]\n", | |
" logits = lora_forward((frozen_params, tunable_params), input_ids).logits\n", | |
"\n", | |
" logprobs = jax.nn.log_softmax(logits)\n", | |
" target_logprobs = jnp.take_along_axis(logprobs, batch[:, 1:, None], axis=-1)\n", | |
" return -jnp.mean(target_logprobs)\n", | |
"\n", | |
" @jax.jit\n", | |
" def update_fn(tunable_params, frozen_params, opt_state, batch):\n", | |
" loss, grads = jax.value_and_grad(loss_fn)(tunable_params, frozen_params, batch)\n", | |
" updates, new_opt_state = optimizer.update(grads, opt_state, params=tunable_params)\n", | |
"\n", | |
" new_tunable_params = optax.apply_updates(tunable_params, updates)\n", | |
" return new_tunable_params, new_opt_state, loss\n", | |
"\n", | |
" # Train on a dummy batch to demo loss going down\n", | |
" example_data = jax.random.randint(jax.random.PRNGKey(0), (4, 128), 0, 50257)\n", | |
" for _ in range(100):\n", | |
" tune_params, opt_state, loss = update_fn(tune_params, freeze_params, opt_state, example_data)\n", | |
" print(loss)\n", | |
"\n", | |
" final_predictions = lora_forward((freeze_params, tune_params), example_data).logits\n", | |
" merged_params = merge_params(freeze_params, tune_params)\n", | |
"\n", | |
" orig_model_predictions = model(example_data, params=merged_params).logits\n", | |
"\n", | |
" gap = jnp.max(jnp.abs(final_predictions - orig_model_predictions))\n", | |
" print(f'Max prediction gap: {gap:.3e}')\n", | |
"\n", | |
"if __name__ == '__main__':\n", | |
" main()" | |
], | |
"metadata": { | |
"id": "R4arRIjcL_F2", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 788 | |
}, | |
"outputId": "95b74d2b-b819-441a-9f95-9f55b1843e5d" | |
}, | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", | |
"\u001b[31m│\u001b[0m in \u001b[92m<cell line: 80>\u001b[0m:\u001b[94m81\u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m in \u001b[92mmain\u001b[0m:\u001b[94m41\u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2;33m/usr/local/lib/python3.10/dist-packages/lorax/\u001b[0m\u001b[1;33mhelpers.py\u001b[0m:\u001b[94m42\u001b[0m in \u001b[92minit_lora\u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m39 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m40 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mreturn\u001b[0m ( \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m41 \u001b[0m\u001b[2m│ │ \u001b[0mjax.tree_map(freeze_getter, param_tree, spec, is_leaf=is_leaf), \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m42 \u001b[2m│ │ \u001b[0mjax.tree_util.tree_map_with_path(tune_getter, param_tree, spec, is_leaf=is_leaf) \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m43 \u001b[0m\u001b[2m│ \u001b[0m) \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m44 \u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m45 \u001b[0m\u001b[94mdef\u001b[0m \u001b[92msimple_spec\u001b[0m(params, decision_fn=\u001b[94mNone\u001b[0m, tune_vectors=\u001b[94mFalse\u001b[0m, is_leaf=\u001b[94mNone\u001b[0m): \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2;33m/usr/local/lib/python3.10/dist-packages/jax/_src/\u001b[0m\u001b[1;33mtree_util.py\u001b[0m:\u001b[94m788\u001b[0m in \u001b[92mtree_map_with_path\u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m785 \u001b[0m\u001b[2m \u001b[0mkeypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf) \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m786 \u001b[0m\u001b[2m \u001b[0mkeypath_leaves = \u001b[96mlist\u001b[0m(\u001b[96mzip\u001b[0m(*keypath_leaves)) \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m787 \u001b[0m\u001b[2m \u001b[0mall_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) \u001b[94mfor\u001b[0m r \u001b[95min\u001b[0m rest] \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m788 \u001b[2m \u001b[0m\u001b[94mreturn\u001b[0m treedef.unflatten(f(*xs) \u001b[94mfor\u001b[0m xs \u001b[95min\u001b[0m \u001b[96mzip\u001b[0m(*all_keypath_leaves)) \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m789 \u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m790 \u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m791 \u001b[0m\u001b[94mdef\u001b[0m \u001b[92m_child_keys\u001b[0m(pytree: Any) -> KeyPath: \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2;33m/usr/local/lib/python3.10/dist-packages/jax/_src/\u001b[0m\u001b[1;33mtree_util.py\u001b[0m:\u001b[94m788\u001b[0m in \u001b[92m<genexpr>\u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m785 \u001b[0m\u001b[2m \u001b[0mkeypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf) \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m786 \u001b[0m\u001b[2m \u001b[0mkeypath_leaves = \u001b[96mlist\u001b[0m(\u001b[96mzip\u001b[0m(*keypath_leaves)) \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m787 \u001b[0m\u001b[2m \u001b[0mall_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) \u001b[94mfor\u001b[0m r \u001b[95min\u001b[0m rest] \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m788 \u001b[2m \u001b[0m\u001b[94mreturn\u001b[0m treedef.unflatten(f(*xs) \u001b[94mfor\u001b[0m xs \u001b[95min\u001b[0m \u001b[96mzip\u001b[0m(*all_keypath_leaves)) \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m789 \u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m790 \u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m791 \u001b[0m\u001b[94mdef\u001b[0m \u001b[92m_child_keys\u001b[0m(pytree: Any) -> KeyPath: \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2;33m/usr/local/lib/python3.10/dist-packages/lorax/\u001b[0m\u001b[1;33mhelpers.py\u001b[0m:\u001b[94m20\u001b[0m in \u001b[92mtune_getter\u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m17 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m spec_val == LORA_FULL: \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m18 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m param \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m19 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m20 \u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mlen\u001b[0m(param.shape) == \u001b[94m1\u001b[0m: \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m21 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mraise\u001b[0m \u001b[96mValueError\u001b[0m(\u001b[33mf\u001b[0m\u001b[33m'\u001b[0m\u001b[33mVectors must either be frozen or fully tuned, but got spe\u001b[0m \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m22 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mlen\u001b[0m(param.shape) == \u001b[94m2\u001b[0m: \u001b[31m│\u001b[0m\n", | |
"\u001b[31m│\u001b[0m \u001b[2m23 \u001b[0m\u001b[2m│ │ │ \u001b[0mb_dim, a_dim = param.shape \u001b[31m│\u001b[0m\n", | |
"\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", | |
"\u001b[1;91mAttributeError: \u001b[0m\u001b[32m'generator'\u001b[0m object has no attribute \u001b[32m'shape'\u001b[0m\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800000; text-decoration-color: #800000\">╭─────────────────────────────── </span><span style=\"color: #800000; text-decoration-color: #800000; font-weight: bold\">Traceback </span><span style=\"color: #bf7f7f; text-decoration-color: #bf7f7f; font-weight: bold\">(most recent call last)</span><span style=\"color: #800000; text-decoration-color: #800000\"> ────────────────────────────────╮</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\"><cell line: 80></span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">81</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">main</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">41</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #bfbf7f; text-decoration-color: #bfbf7f\">/usr/local/lib/python3.10/dist-packages/lorax/</span><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">helpers.py</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">42</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">init_lora</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">39 │ </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">40 │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">return</span> ( <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">41 │ │ </span>jax.tree_map(freeze_getter, param_tree, spec, is_leaf=is_leaf), <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span>42 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ </span>jax.tree_util.tree_map_with_path(tune_getter, param_tree, spec, is_leaf=is_leaf) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">43 │ </span>) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">44 </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">45 </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">def</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00\">simple_spec</span>(params, decision_fn=<span style=\"color: #0000ff; text-decoration-color: #0000ff\">None</span>, tune_vectors=<span style=\"color: #0000ff; text-decoration-color: #0000ff\">False</span>, is_leaf=<span style=\"color: #0000ff; text-decoration-color: #0000ff\">None</span>): <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #bfbf7f; text-decoration-color: #bfbf7f\">/usr/local/lib/python3.10/dist-packages/jax/_src/</span><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">tree_util.py</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">788</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">tree_map_with_path</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">785 </span>keypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">786 </span>keypath_leaves = <span style=\"color: #00ffff; text-decoration-color: #00ffff\">list</span>(<span style=\"color: #00ffff; text-decoration-color: #00ffff\">zip</span>(*keypath_leaves)) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">787 </span>all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) <span style=\"color: #0000ff; text-decoration-color: #0000ff\">for</span> r <span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">in</span> rest] <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span>788 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">return</span> treedef.unflatten(f(*xs) <span style=\"color: #0000ff; text-decoration-color: #0000ff\">for</span> xs <span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">in</span> <span style=\"color: #00ffff; text-decoration-color: #00ffff\">zip</span>(*all_keypath_leaves)) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">789 </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">790 </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">791 </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">def</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00\">_child_keys</span>(pytree: Any) -> KeyPath: <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #bfbf7f; text-decoration-color: #bfbf7f\">/usr/local/lib/python3.10/dist-packages/jax/_src/</span><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">tree_util.py</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">788</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\"><genexpr></span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">785 </span>keypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">786 </span>keypath_leaves = <span style=\"color: #00ffff; text-decoration-color: #00ffff\">list</span>(<span style=\"color: #00ffff; text-decoration-color: #00ffff\">zip</span>(*keypath_leaves)) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">787 </span>all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) <span style=\"color: #0000ff; text-decoration-color: #0000ff\">for</span> r <span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">in</span> rest] <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span>788 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">return</span> treedef.unflatten(f(*xs) <span style=\"color: #0000ff; text-decoration-color: #0000ff\">for</span> xs <span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">in</span> <span style=\"color: #00ffff; text-decoration-color: #00ffff\">zip</span>(*all_keypath_leaves)) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">789 </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">790 </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">791 </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">def</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00\">_child_keys</span>(pytree: Any) -> KeyPath: <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #bfbf7f; text-decoration-color: #bfbf7f\">/usr/local/lib/python3.10/dist-packages/lorax/</span><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">helpers.py</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">20</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">tune_getter</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">17 │ │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">if</span> spec_val == LORA_FULL: <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">18 │ │ │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">return</span> param <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">19 │ │ </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span>20 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">if</span> <span style=\"color: #00ffff; text-decoration-color: #00ffff\">len</span>(param.shape) == <span style=\"color: #0000ff; text-decoration-color: #0000ff\">1</span>: <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">21 │ │ │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">raise</span> <span style=\"color: #00ffff; text-decoration-color: #00ffff\">ValueError</span>(<span style=\"color: #808000; text-decoration-color: #808000\">f'Vectors must either be frozen or fully tuned, but got spe</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">22 │ │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">if</span> <span style=\"color: #00ffff; text-decoration-color: #00ffff\">len</span>(param.shape) == <span style=\"color: #0000ff; text-decoration-color: #0000ff\">2</span>: <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">23 │ │ │ </span>b_dim, a_dim = param.shape <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n", | |
"<span style=\"color: #800000; text-decoration-color: #800000\">╰──────────────────────────────────────────────────────────────────────────────────────────────────╯</span>\n", | |
"<span style=\"color: #ff0000; text-decoration-color: #ff0000; font-weight: bold\">AttributeError: </span><span style=\"color: #008000; text-decoration-color: #008000\">'generator'</span> object has no attribute <span style=\"color: #008000; text-decoration-color: #008000\">'shape'</span>\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "Djyo_reAs26R" | |
}, | |
"outputs": [], | |
"source": [ | |
"'''\n", | |
"# Initialize the model parameters using JAX's PRNG key\n", | |
"rng_key = jax.random.PRNGKey(0)\n", | |
"input_ids = jnp.array([[1, 2, 3, 4, 5]])\n", | |
"decoder_input_ids = jnp.array([[1, 2, 3, 4, 5]])\n", | |
"params = model.parameters()\n", | |
"'''\n", | |
"\n", | |
"# Modify my_model to use the LongT5-XL model instead of the custom model defined earlier\n", | |
"def my_model(params, x):\n", | |
" logits = model(input_ids=x, params=params, train=True).logits\n", | |
" return jnp.mean(logits)\n", | |
"\n", | |
"# Define a loss function for the LongT5-XL model\n", | |
"@jax.jit\n", | |
"def compute_loss(params, input_ids, decoder_input_ids, labels):\n", | |
" logits = model(\n", | |
" input_ids=input_ids,\n", | |
" decoder_input_ids=decoder_input_ids,\n", | |
" params=params,\n", | |
" train=True\n", | |
" ).logits\n", | |
"\n", | |
"# Transform the loss function to get the gradients\n", | |
"grad_fn = jax.value_and_grad(compute_loss)\n", | |
"\n", | |
"# Define an optimizer to update the parameters using the gradients\n", | |
"optimizer = optax.adam(learning_rate=1e-3)\n", | |
"\n", | |
"# Define a train step function which combines the loss function and optimizer update, does the forward and backward pass, and returns the updated parameters\n", | |
"@jax.jit\n", | |
"def train_step(params, x, y, optimizer):\n", | |
" grads, loss = grad_fn(params, x, y)\n", | |
" updates, optimizer_state = optimizer.update(grads, optimizer_state)\n", | |
" new_params = optax.apply_updates(params, updates)\n", | |
" return new_params, loss, optimizer_state\n", | |
"\n", | |
"# Define a batch generator function using get_batches() from stackoverflow.com\n", | |
"def generate_batch(batch_size, rng, DIM=512):\n", | |
" # Generate a batch of input-output pairs\n", | |
" X_batch = jax.random.normal(rng, (batch_size, DIM))\n", | |
" Y_batch = jax.random.randint(rng, (batch_size,), 0, 2, dtype=jnp.int32)\n", | |
"\n", | |
" return X_batch, Y_batch\n", | |
"\n", | |
"# Initialize the optimizer state and the PRNG key\n", | |
"optimizer_state = optimizer.init(params)\n", | |
"rng = jax.random.PRNGKey(0)\n", | |
"\n", | |
"# Train the model\n", | |
"num_steps = 50\n", | |
"batch_size = 4\n", | |
"\n", | |
"for i in range(num_steps):\n", | |
" # Generate a batch of input-output pairs\n", | |
" x_batch, y_batch = generate_batch(batch_size, rng)\n", | |
"\n", | |
" # Update the parameters and optimizer state\n", | |
" params, loss, optimizer_state = train_step(params, x_batch, y_batch, optimizer_state)\n", | |
"\n", | |
" # Print the loss every 10 steps\n", | |
" if i % 10 == 0:\n", | |
" print(f'Step {i}, Loss: {loss}')\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "RlCLAmjBvhnA" | |
}, | |
"source": [ | |
"GPT-Q needs input data for quantization. For an actual model we'd use real data, but here we'll just generate some random inputs." |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "6govTMOZvgSC" | |
}, | |
"outputs": [], | |
"source": [ | |
"quant_data = [jax.random.normal(key, (batch_size, DIM)) for key in jax.random.split(data_key, 64)]\n", | |
"\n", | |
"# We'll save an output for later comparison since the quantization process will delete the original params\n", | |
"original_output = my_model(params, quant_data[0])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Rjdb3h46vtsi" | |
}, | |
"source": [ | |
"### Run GPT-Q to get the quantized weights\n", | |
"That's all for the setup; we can now run GPT-Q directly (without any changes to the original model code):" |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true, | |
"id": "L1Mw9ZLpvrLa" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Note that this may free the buffers associated with some or all of the parameters and the data to save VRAM\n", | |
"# I'd also recommend you put the params on the CPU, since `quantize()` will move the params to th GPU when necessary\n", | |
"quantized_params = jax_gptq.quantize(my_model, params, quant_data)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "2NhVv8egwDQu" | |
}, | |
"source": [ | |
"The matrices have been quantized but the biases have been left alone:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "bWwXzTJyubbH" | |
}, | |
"outputs": [], | |
"source": [ | |
" print(f'W type: {type(quantized_params[0][\"w\"])}')\n", | |
" print(f'B type: {type(quantized_params[0][\"b\"])}')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "QwYLTr6WwapB" | |
}, | |
"source": [ | |
"**Note**: The quantization procedure depends on the parameter being used in a matrix multiplication. Currently JAX-GPTQ supports general dot operations (including ones using tensors with any number of dimensions larger than 1), and convolutions with kernels of spatial size 1.\n", | |
"\n", | |
"### Applying the quantized weights\n", | |
"We can now run the quantized model without any code changes. All that's necessary is using `jax_gptq.use_quantized` to transform the function so it knows how to handle `QuantizedMatrix` values." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "I6aLdXqawQFs" | |
}, | |
"outputs": [], | |
"source": [ | |
"quantized_params = jax.device_put(quantized_params, gpu) # Move the params to the GPU\n", | |
"\n", | |
"# Originally:\n", | |
"# my_model(params, inputs)\n", | |
"# After:\n", | |
"# jax_gptq(my_model)(params, inputs)\n", | |
"quant_output = jax_gptq.use_quantized(my_model)(quantized_params, quant_data[0])\n", | |
"\n", | |
"print(f'Output of quantized network: {quant_output:.3e}')\n", | |
"print(f'Original output: {original_output:.3e}')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "1vXkTTctx7Vo" | |
}, | |
"source": [ | |
"### Train with LoRA\n", | |
"\n", | |
"Now that we've compressed our model to 4 bits (and change) per parameter, we can add full-precision LoRA parameters for finetuning.\n", |
"\n", | |
"The one gotcha about combining the two is that Lorax doesn't know that QuantizedMatrix values are pytree leaves, so you need to give the Lorax functions an `is_leaf` predicate." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "l95MirHdzNo9" | |
}, | |
"source": [ | |
"**Initialization:** The `init_lora` function expects a pytree describing which parameters should get LoRA parameters, which should be fully trained, and which should be left frozen. `lorax.simple_spec` is a helper function for making these specs." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "HKkhcjx9zJy6" | |
}, | |
"outputs": [], | |
"source": [ | |
"def is_leaf(x):\n", | |
" return isinstance(x, jax_gptq.QuantizedMatrix)\n", | |
"\n", | |
"lora_spec = lorax.simple_spec(\n", | |
" params=quantized_params,\n", | |
" decision_fn=lambda pytree_path, arr: 4, # Just ignore the inputs and specify an inner rank of 4 for all params\n", | |
" tune_vectors=False, # Tell Lorax to put all the biases in the frozen params tree instead of the tunable params tree\n", | |
" is_leaf=is_leaf\n", | |
")\n", | |
"\n", | |
"# Lorax splits the parameters into two pytrees:\n", | |
"# freeze_params: Anything which received the value lorax.LORA_FREEZE in the spec\n", | |
"# train_params: Pairs of two narrow matrices for values which got positive integers as spec values, or the full parameter if the value lorax.LORA_FULL was in the spec\n", | |
"freeze_params, train_params = lorax.init_lora(quantized_params, lora_spec, jax.random.PRNGKey(1234), is_leaf=is_leaf)\n", | |
"\n", | |
"def merge_quantized_with_lora(q_params, lora_freeze):\n", | |
" return jax.tree_map(\n", | |
" lambda quant, from_lora: quant if isinstance(quant, jax_gptq.QuantizedMatrix) else from_lora,\n", | |
" q_params,\n", | |
" lora_freeze,\n", | |
" is_leaf=lambda x: isinstance(x, jax_gptq.QuantizedMatrix) # Tell tree_map to treat QuantizedMatrix as a single value instead of a non-leaf node\n", | |
" )\n", | |
"# Now we put the actual quantized params back\n", | |
"#freeze_params = merge_quantized_with_lora(quantized_params, freeze_params)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "-ebT9GXp16v4" | |
}, | |
"source": [ | |
"The `lorax.lora` transform converts a function from expecting a single pytree in the specified argument to expecting a tuple of two pytrees. It composes with other JAX transforms such as `jax_gptq.use_quantized`, so we can use both at once with no modifications to our model code." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "1XjjuQcq1oSq" | |
}, | |
"outputs": [], | |
"source": [ | |
"combined_params = (freeze_params, train_params)\n", | |
"\n", | |
"my_model_with_lora_and_quantized_weights = jax_gptq.use_quantized(lorax.lora(my_model))\n", | |
"\n", | |
"# The differences from the original `my_model` function are:\n", | |
"# 1. The params argument now expects a tuple of (frozen_params, trainable_params)\n", | |
"# 2. It knows how to compute with quantized weights\n", | |
"quantized_plus_lorax_output = my_model_with_lora_and_quantized_weights(combined_params, quant_data[0])\n", | |
"\n", | |
"print(f'GPTQ + Lorax output: {quantized_plus_lorax_output:.3e}')\n", | |
"print(f'GPTQ only: {quant_output:.3e}')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "aIywP5qQ3KEH" | |
}, | |
"source": [ | |
"The above values are identical since LoRA initializes one of each pair of matrices as zeros.\n", | |
"\n", | |
"Let's look at the size of each pytree:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "nqQwBPjh2ttl" | |
}, | |
"outputs": [], | |
"source": [ | |
"count_params = partial(jax.tree_util.tree_reduce,\n", | |
" lambda acc, param: acc + (param.size if isinstance(param, jnp.ndarray) else 0),\n", | |
" initializer=0\n", | |
")\n", | |
"\n", | |
"print(f'{count_params(freeze_params):.3e} frozen params')\n", | |
"print(f'{count_params(train_params):.3e} trainable params')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "0CJ58F005g-c" | |
}, | |
"source": [ | |
"Training with this function is no different from training any other JAX function; just make sure to differentiate your loss with respect to the trainable parameters only (see the next section for an example)." |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "m_lDOLnw5zoC" | |
}, | |
"source": [ | |
"## GPT-Q-ing + LoRA-ing HuggingFace's Flax GPT-2\n", | |
"I developed these transforms for use with my Haiku models, but since all JAX models are pure functions at the end of the day, it shouldn't matter what framework you use. Lorax supports matmuls and other matmul-like operations such as embedding lookups and 1-D convs.\n", | |
"\n", | |
"This is a minimal example of applying the combination to `gpt2-medium`, but it's basically model agnostic.\n", | |
"\n", | |
"First let's get the model:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "czS5kDWO6XTv" | |
}, | |
"outputs": [], | |
"source": [ | |
"from transformers import AutoTokenizer, FlaxAutoModelForCausalLM" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "VnfmpQ6f6Yal" | |
}, | |
"outputs": [], | |
"source": [ | |
"model_name = 'gpt2-medium'\n", | |
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n", | |
"model, params = FlaxAutoModelForCausalLM.from_pretrained(model_name, _do_init=False)\n", | |
"params = jax.device_put(params, cpu)\n", | |
"\n", | |
"# Because the embedding table is reused as the output linear layer, it'll get quantized at the end of the process, but that will seriously screw up the embedding lookup step, so we'll just save it for later here\n", | |
"orig_embedding_table = np.asarray(params['transformer']['wte']['embedding'])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "evCyWa787m_N" | |
}, | |
"source": [ | |
"The GPT-Q paper used real text data for quantization, but for this demo I'll just generate some random values." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "ao_vTWAf7Tw-" | |
}, | |
"outputs": [], | |
"source": [ | |
"QUANT_BATCH_SIZE = 4\n", | |
"QUANT_EXAMPLE_LENGTH = 64 # I'd recommend making this bigger, but needs to be small to not crash colab\n", | |
"\n", | |
"quantization_data = []\n", | |
"key = jax.random.PRNGKey(0)\n", | |
"for _ in range(32):\n", | |
" batch = jax.random.randint(key, (QUANT_BATCH_SIZE, QUANT_EXAMPLE_LENGTH), 0, 50256)\n", | |
" quantization_data.append(batch)\n", | |
" key, = jax.random.split(key, 1)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "0x_pT_fT8Co8" | |
}, | |
"source": [ | |
"HuggingFace's models don't have quite the right call signature, so we'll make a wrapper which takes `(params, inputs)` as its arguments:" |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true, | |
"id": "yddz4OUN8Bvt" | |
}, | |
"outputs": [], | |
"source": [ | |
"def apply_model(params, batch):\n", | |
" return model(batch, params=params)\n", | |
"\n", | |
"quantized_params = jax_gptq.quantize(apply_model, params, quantization_data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "ehblO3I98akJ" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Replace the quantized embedding table with the original one\n", | |
"quantized_params['transformer']['wte']['embedding'] = jnp.asarray(orig_embedding_table)\n", | |
"quantized_params = jax.device_put(quantized_params, gpu)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "WYiCG5fE9yKT" | |
}, | |
"source": [ | |
"### Finetuning GPT-2 with Lorax\n", | |
"\n", | |
"Same as [above](https://colab.research.google.com/drive/18rkULbWqk7mNZDx7Scx-JS3p_s45mgok#scrollTo=HKkhcjx9zJy6&line=3&uniqifier=1), we get the original param structure to tell Lorax how to initialize the LoRA params, then merge the quantized params back in after." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "FKS_dfll93sO" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Get pre-quantization param tree (some nodes will just be abstract values)\n", | |
"orig_params_or_shapes = jax_gptq.utils.quantized_params_to_shaped_arrays(quantized_params)\n", | |
"\n", | |
"# Tell Lorax which leaves should be frozen/fully trained/LoRA trained\n", | |
"spec = lorax.simple_spec(\n", | |
" orig_params_or_shapes,\n", | |
" lambda path, arr: 16 if any(pattern in path for pattern in ['c_attn', 'mlp']) else lorax.LORA_FREEZE,\n", | |
" tune_vectors=True\n", | |
")\n", | |
"\n", | |
"# Initialize parameters\n", | |
"key, init_key = jax.random.split(key)\n", | |
"freeze_params, train_params = lorax.init_lora(\n", | |
" orig_params_or_shapes,\n", | |
" spec,\n", | |
" init_key\n", | |
")\n", | |
"\n", | |
"# Put the quantized params back into the frozen param tree\n", | |
"freeze_params = merge_quantized_with_lora(quantized_params, freeze_params)\n", | |
"combined_params = freeze_params, train_params" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "T8bJwqN2Bfqh" | |
}, | |
"source": [ | |
"Now we can just transform the `apply_model` function and it will use both LoRA and 4-bit quantized parameters" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "glARn7Z0BX4g" | |
}, | |
"outputs": [], | |
"source": [ | |
"quantized_plus_lora_fn = jax_gptq.use_quantized(lorax.lora(apply_model))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Y1G-d0yDBn8y" | |
}, | |
"source": [ | |
"### Training\n", | |
"Training isn't actually any different from normal training, since you can just think of `freeze_params` as being a constant argument, but here's a demo for completness.\n", | |
"\n", | |
"First I'll define a toy corpus which demonstrates Alan's love of cats and Grace's dislike of them." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "I3fdjSioBvDO" | |
}, | |
"outputs": [], | |
"source": [ | |
"CATS = ['lions', 'tigers', 'cheetahs', 'cats', 'ocelots', 'kittens']\n", | |
"DOGS = ['wolves', 'dogs', 'coyotes', 'huskies', 'poodles', 'puppies']\n", | |
"\n", | |
"CAT_LOVER = 'Alan'\n", | |
"DOG_LOVER = 'Grace'\n", | |
"\n", | |
"dataset = []\n", | |
"for name, polarity in [(CAT_LOVER, True), (DOG_LOVER, False)]:\n", | |
" liked, disliked = (CATS, DOGS) if polarity else (DOGS, CATS)\n", | |
" for kind in liked:\n", | |
" dataset.append(f'{name}: {kind}? I love them!')\n", | |
" dataset.append(f'{name}: Hey look at those {kind}, that\\'s pretty cool')\n", | |
"\n", | |
" for kind in disliked:\n", | |
" dataset.append(f'{name}: {kind}? I hate them!')\n", | |
" dataset.append(f'{name}: Oh no, some {kind}! How scary!')\n", | |
"\n", | |
"tokenized_data = [jnp.asarray(tokenizer.encode(ex)) for ex in dataset]\n", | |
"max_len = max(ex.shape[0] for ex in tokenized_data)\n", | |
"# Pad the data to speed up jitting. Not worrying about masking due to laziness.\n", | |
"tokenized_data = [jnp.pad(ex, (0, max_len - ex.shape[0])) for ex in tokenized_data]\n", | |
"\n", | |
"jitted_model = jax.jit(quantized_plus_lora_fn)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "NZFLWJgxYqfh" | |
}, | |
"outputs": [], | |
"source": [ | |
"def make_prediction(params, prefix):\n", | |
" tokens = jnp.asarray(tokenizer.encode(prefix))\n", | |
" logits = jitted_model(params, tokens[None]).logits\n", | |
"\n", | |
" logprobs = jnp.exp(jax.nn.log_softmax(logits[0, -1]))\n", | |
" pred_probs, pred_words = jax.lax.top_k(logprobs, 5)\n", | |
"\n", | |
" print(f'Predictions for: \"{prefix}\"')\n", | |
" for i, (word_id, prob) in enumerate(zip(pred_words, pred_probs), 1):\n", | |
" print(f'{i}. {tokenizer.decode([word_id])} - {prob:.2%}')\n", | |
" print()\n", | |
"\n", | |
"test_examples = [\n", | |
" f'{CAT_LOVER}: jaguars? I',\n", | |
" f'{DOG_LOVER}: jaguars? I'\n", | |
"]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "yT7hOBnYS-AC" | |
}, | |
"source": [ | |
"Let's look at the next word predictions of the unmodified model:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "eew7ihGJTD85" | |
}, | |
"outputs": [], | |
"source": [ | |
"for ex in test_examples:\n", | |
" make_prediction(combined_params, ex)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "BrSL1MgSDXfO" | |
}, | |
"source": [ | |
"Next we set up a standard training loop. The only difference is that we keep the train/freeze params separate for the optimizer. There's no differences needed for the quantization.\n", | |
"\n", | |
"I'll just train with a batch size of 1 here since I don't want to bother with masking, but the transformed model function is fully compatible with vmap etc." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "52QdkmIxDHk-" | |
}, | |
"outputs": [], | |
"source": [ | |
"def loss_fn(train_params, freeze_params, seq):\n", | |
" inputs = seq[:-1]\n", | |
" targets = seq[1:]\n", | |
"\n", | |
" combined_params = (freeze_params, train_params)\n", | |
" logits = quantized_plus_lora_fn(combined_params, inputs[None]).logits[0]\n", | |
" logprobs = jax.nn.log_softmax(logits)\n", | |
" losses = -jnp.take_along_axis(logprobs, targets[:, None], axis=-1)\n", | |
" return jnp.mean(losses)\n", | |
"\n", | |
"optimizer = optax.adamw(learning_rate=1e-4, weight_decay=1e-4)\n", | |
"opt_state = optimizer.init(combined_params[1])\n", | |
"\n", | |
"@jax.jit\n", | |
"def update_fn(combined_params, opt_state, example):\n", | |
" freeze_params, train_params = combined_params\n", | |
"\n", | |
" # The main thing is that we have to split up the params here so that JAX knows what to differentiate with respect to\n", | |
" loss, grads = jax.value_and_grad(loss_fn)(train_params, freeze_params, example)\n", | |
"\n", | |
" updates, opt_state = optimizer.update(grads, opt_state, params=train_params)\n", | |
" new_train_params = optax.apply_updates(train_params, updates)\n", | |
" return (freeze_params, new_train_params), opt_state, loss" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "cj2d1xIqFJw3" | |
}, | |
"outputs": [], | |
"source": [ | |
"bar = trange(50)\n", | |
"for epoch in bar:\n", | |
" key, = jax.random.split(key, 1)\n", | |
" permutation = jax.random.permutation(key, jnp.arange(len(dataset)))\n", | |
" total_loss = 0\n", | |
" for index in permutation:\n", | |
" example = tokenized_data[index]\n", | |
" combined_params, opt_state, loss = update_fn(combined_params, opt_state, example)\n", | |
" total_loss += loss\n", | |
" bar.set_description(f'Epoch {epoch} - Loss: {total_loss / len(tokenized_data):.3e}')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "IMFZwE8qeSUl" | |
}, | |
"source": [ | |
"The trained LoRA parameters give us a model which predicts that Alan will love jaguars, and Grace will hate them:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "GIgThnapFQS6" | |
}, | |
"outputs": [], | |
"source": [ | |
"for example in test_examples:\n", | |
" make_prediction(combined_params, example)\n", | |
" print()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "92W8jCjQeZ9J" | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"colab": { | |
"gpuType": "T4", | |
"provenance": [] | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.10" | |
}, | |
"widgets": { | |
"application/vnd.jupyter.widget-state+json": { | |
"dda94429aac549438c343dd2edc24e0a": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HBoxModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HBoxView", | |
"box_style": "", | |
"children": [ | |
"IPY_MODEL_ee547e0a59eb422d975f02daaf672629", | |
"IPY_MODEL_989350afa53349d39342297bfb4eca95", | |
"IPY_MODEL_9954285510cb4ebab4deb9d6558dd5ed" | |
], | |
"layout": "IPY_MODEL_663b81ccdf0c4b02af24b1cd81b7daa6" | |
} | |
}, | |
"ee547e0a59eb422d975f02daaf672629": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_6bf5f6cd0df14197b8da4303b43cb655", | |
"placeholder": "", | |
"style": "IPY_MODEL_693fb73b9dc0411eab41c5fede4b611c", | |
"value": "Downloading (…)lve/main/config.json: 100%" | |
} | |
}, | |
"989350afa53349d39342297bfb4eca95": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "FloatProgressModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "ProgressView", | |
"bar_style": "success", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_db0e6b46120844dc8620fbd2a4076d0d", | |
"max": 896, | |
"min": 0, | |
"orientation": "horizontal", | |
"style": "IPY_MODEL_9872d010314e492f9a85cbb59edb8cc3", | |
"value": 896 | |
} | |
}, | |
"9954285510cb4ebab4deb9d6558dd5ed": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_1478a80283ff43f193bd9a458a6c8274", | |
"placeholder": "", | |
"style": "IPY_MODEL_199682514a0e43b5a13244a818722d90", | |
"value": " 896/896 [00:00<00:00, 13.1kB/s]" | |
} | |
}, | |
"663b81ccdf0c4b02af24b1cd81b7daa6": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"6bf5f6cd0df14197b8da4303b43cb655": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"693fb73b9dc0411eab41c5fede4b611c": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"db0e6b46120844dc8620fbd2a4076d0d": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"9872d010314e492f9a85cbb59edb8cc3": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "ProgressStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"bar_color": null, | |
"description_width": "" | |
} | |
}, | |
"1478a80283ff43f193bd9a458a6c8274": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"199682514a0e43b5a13244a818722d90": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"cdb5e7259f57470296128e76fdf5e6ea": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HBoxModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HBoxView", | |
"box_style": "", | |
"children": [ | |
"IPY_MODEL_9b31610503d84238b9e7a2321ff1b10b", | |
"IPY_MODEL_32d7477699eb461584db4072a7c97052", | |
"IPY_MODEL_6c4eafec00674c838c628394c7c68faa" | |
], | |
"layout": "IPY_MODEL_a63b31cd45b24408b5a9068d1ecd1b4b" | |
} | |
}, | |
"9b31610503d84238b9e7a2321ff1b10b": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_715506a0e5694e7ca6061ee99bd92382", | |
"placeholder": "", | |
"style": "IPY_MODEL_3c4dbdf570eb46b690cd500476020064", | |
"value": "Downloading (…)model.bin.index.json: 100%" | |
} | |
}, | |
"32d7477699eb461584db4072a7c97052": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "FloatProgressModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "ProgressView", | |
"bar_style": "success", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_bdd5160d52de4f9895f8e4dd7f7ec20a", | |
"max": 55432, | |
"min": 0, | |
"orientation": "horizontal", | |
"style": "IPY_MODEL_c4eee8d4242045bba3d38597dbb414b3", | |
"value": 55432 | |
} | |
}, | |
"6c4eafec00674c838c628394c7c68faa": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_2c47cb97a18f4ee89bd945b1ca11347d", | |
"placeholder": "", | |
"style": "IPY_MODEL_4a9f0ebd9d7b4192a25ef75250c03deb", | |
"value": " 55.4k/55.4k [00:00<00:00, 691kB/s]" | |
} | |
}, | |
"a63b31cd45b24408b5a9068d1ecd1b4b": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"715506a0e5694e7ca6061ee99bd92382": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"3c4dbdf570eb46b690cd500476020064": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"bdd5160d52de4f9895f8e4dd7f7ec20a": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"c4eee8d4242045bba3d38597dbb414b3": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "ProgressStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"bar_color": null, | |
"description_width": "" | |
} | |
}, | |
"2c47cb97a18f4ee89bd945b1ca11347d": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"4a9f0ebd9d7b4192a25ef75250c03deb": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"7d0234ed0d35487e826565df214a2cc9": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HBoxModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HBoxView", | |
"box_style": "", | |
"children": [ | |
"IPY_MODEL_46ad2aad7e474c73be0bfb1fef55a4bd", | |
"IPY_MODEL_6d811cfd512f4916b83640f9297bc1bd", | |
"IPY_MODEL_8c416aa8a81447cba3527d287aeb2f35" | |
], | |
"layout": "IPY_MODEL_8e073507368745088aa356b23f044f0b" | |
} | |
}, | |
"46ad2aad7e474c73be0bfb1fef55a4bd": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_902022c949e448eeb8ce3e31d1ef4153", | |
"placeholder": "", | |
"style": "IPY_MODEL_8bcfb6b0b56d4046a839c0f9ae4c16b1", | |
"value": "Downloading shards: 100%" | |
} | |
}, | |
"6d811cfd512f4916b83640f9297bc1bd": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "FloatProgressModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "ProgressView", | |
"bar_style": "success", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_7de7565229de41149b770fe25714a847", | |
"max": 2, | |
"min": 0, | |
"orientation": "horizontal", | |
"style": "IPY_MODEL_52e158b04b8c4858929bf3739f7c78df", | |
"value": 2 | |
} | |
}, | |
"8c416aa8a81447cba3527d287aeb2f35": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_a12f0daf615547faa649107033a93011", | |
"placeholder": "", | |
"style": "IPY_MODEL_3067249f4bc64398906fceea8d1e5934", | |
"value": " 2/2 [04:42<00:00, 125.83s/it]" | |
} | |
}, | |
"8e073507368745088aa356b23f044f0b": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"902022c949e448eeb8ce3e31d1ef4153": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"8bcfb6b0b56d4046a839c0f9ae4c16b1": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"7de7565229de41149b770fe25714a847": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"52e158b04b8c4858929bf3739f7c78df": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "ProgressStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"bar_color": null, | |
"description_width": "" | |
} | |
}, | |
"a12f0daf615547faa649107033a93011": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"3067249f4bc64398906fceea8d1e5934": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"5fe6035fe6e8426580d632bcedc3d19b": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HBoxModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HBoxView", | |
"box_style": "", | |
"children": [ | |
"IPY_MODEL_fe81c32260ca47c7bcd4fd6230ed3c81", | |
"IPY_MODEL_ec4a1c28bbb54799a2278ee1e1d65383", | |
"IPY_MODEL_68cc8181fd484af1b9f42b1c9b5f8d75" | |
], | |
"layout": "IPY_MODEL_0d20b32672e34c7295634090e6cebf27" | |
} | |
}, | |
"fe81c32260ca47c7bcd4fd6230ed3c81": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_24f189322a134a20900084de1e08c376", | |
"placeholder": "", | |
"style": "IPY_MODEL_aae1c621a5b94960ace82d52253ccaf7", | |
"value": "Downloading (…)l-00001-of-00002.bin: 100%" | |
} | |
}, | |
"ec4a1c28bbb54799a2278ee1e1d65383": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "FloatProgressModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "ProgressView", | |
"bar_style": "success", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_ff2779f06f954e35a5eec9a0a02684f1", | |
"max": 9449929179, | |
"min": 0, | |
"orientation": "horizontal", | |
"style": "IPY_MODEL_c00fb1fde3f74d82ac6b3b5bc19a9796", | |
"value": 9449929179 | |
} | |
}, | |
"68cc8181fd484af1b9f42b1c9b5f8d75": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_d50444fed9a746ff8281239c4c1a3058", | |
"placeholder": "", | |
"style": "IPY_MODEL_1110d48d7a664208bcd57f3c67aecccd", | |
"value": " 9.45G/9.45G [03:47<00:00, 41.6MB/s]" | |
} | |
}, | |
"0d20b32672e34c7295634090e6cebf27": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"24f189322a134a20900084de1e08c376": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"aae1c621a5b94960ace82d52253ccaf7": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"ff2779f06f954e35a5eec9a0a02684f1": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"c00fb1fde3f74d82ac6b3b5bc19a9796": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "ProgressStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"bar_color": null, | |
"description_width": "" | |
} | |
}, | |
"d50444fed9a746ff8281239c4c1a3058": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"1110d48d7a664208bcd57f3c67aecccd": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"c22c1289367046babeb0c9409c181f8b": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HBoxModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HBoxView", | |
"box_style": "", | |
"children": [ | |
"IPY_MODEL_eb3f2e923426469e9b92a51187f60f60", | |
"IPY_MODEL_83784cb918974e4c8664f78561ea8cf8", | |
"IPY_MODEL_008aa06de0fd4ec789d3d42e83dbd16e" | |
], | |
"layout": "IPY_MODEL_33217d952d544284937852eeb17323fc" | |
} | |
}, | |
"eb3f2e923426469e9b92a51187f60f60": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_c83d30547b08413f90300abdf318b123", | |
"placeholder": "", | |
"style": "IPY_MODEL_2b8df3f5200e4c9fa59f057c4586a700", | |
"value": "Downloading (…)l-00002-of-00002.bin: 100%" | |
} | |
}, | |
"83784cb918974e4c8664f78561ea8cf8": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "FloatProgressModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "ProgressView", | |
"bar_style": "success", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_8fc75f12da864868a4d0ab9c1949d1ac", | |
"max": 1949494999, | |
"min": 0, | |
"orientation": "horizontal", | |
"style": "IPY_MODEL_b00196f8517c415ea5388bf1a72ffb37", | |
"value": 1949494999 | |
} | |
}, | |
"008aa06de0fd4ec789d3d42e83dbd16e": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_09ae4873193f41baa133879a8d3f5905", | |
"placeholder": "", | |
"style": "IPY_MODEL_bd9db9a757db4ccfa68154524ac95dbb", | |
"value": " 1.95G/1.95G [00:53<00:00, 44.3MB/s]" | |
} | |
}, | |
"33217d952d544284937852eeb17323fc": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"c83d30547b08413f90300abdf318b123": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"2b8df3f5200e4c9fa59f057c4586a700": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"8fc75f12da864868a4d0ab9c1949d1ac": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"b00196f8517c415ea5388bf1a72ffb37": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "ProgressStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"bar_color": null, | |
"description_width": "" | |
} | |
}, | |
"09ae4873193f41baa133879a8d3f5905": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"bd9db9a757db4ccfa68154524ac95dbb": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"410478d1c6164c679702f1b959ec6dc6": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HBoxModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HBoxView", | |
"box_style": "", | |
"children": [ | |
"IPY_MODEL_34ae822850454b4b8f92dae185943fe4", | |
"IPY_MODEL_44bd417858fb4982bd97cc07b3f03b5d", | |
"IPY_MODEL_84ddc0bd3bfd4df9a300a0ea01dafc06" | |
], | |
"layout": "IPY_MODEL_60e20079809641db96f76aa61980ac74" | |
} | |
}, | |
"34ae822850454b4b8f92dae185943fe4": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_a82fed6b81274263a1c88e41063356f4", | |
"placeholder": "", | |
"style": "IPY_MODEL_14c39b554476480d8ed17a4188ffc261", | |
"value": "Loading checkpoint shards: 100%" | |
} | |
}, | |
"44bd417858fb4982bd97cc07b3f03b5d": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "FloatProgressModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "ProgressView", | |
"bar_style": "success", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_1b9abcbc8e054a069d9bf0710329928b", | |
"max": 2, | |
"min": 0, | |
"orientation": "horizontal", | |
"style": "IPY_MODEL_d85f0a6e82b648b0899e7006b5a7e740", | |
"value": 2 | |
} | |
}, | |
"84ddc0bd3bfd4df9a300a0ea01dafc06": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_df82eb9aaa354a4395a5304ae038a828", | |
"placeholder": "", | |
"style": "IPY_MODEL_96b5cf9d5b7846bfa61523ee1ec4db69", | |
"value": " 2/2 [01:32<00:00, 43.54s/it]" | |
} | |
}, | |
"60e20079809641db96f76aa61980ac74": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"a82fed6b81274263a1c88e41063356f4": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"14c39b554476480d8ed17a4188ffc261": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"1b9abcbc8e054a069d9bf0710329928b": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"d85f0a6e82b648b0899e7006b5a7e740": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "ProgressStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"bar_color": null, | |
"description_width": "" | |
} | |
}, | |
"df82eb9aaa354a4395a5304ae038a828": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"96b5cf9d5b7846bfa61523ee1ec4db69": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"3458c2feac6b41ebbf6ed872d9112df0": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HBoxModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HBoxView", | |
"box_style": "", | |
"children": [ | |
"IPY_MODEL_d8cf12f088a94b4da262d8d2c134e89f", | |
"IPY_MODEL_2973bec2ccff49168ba84ddc1f326e8a", | |
"IPY_MODEL_58230fe17b7840c983a59a807a051f2c" | |
], | |
"layout": "IPY_MODEL_b4687344280545029c6505de0d277b97" | |
} | |
}, | |
"d8cf12f088a94b4da262d8d2c134e89f": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_910e96c5ff324caba366ede006d164ab", | |
"placeholder": "", | |
"style": "IPY_MODEL_d8138fd266e44ab3963a274ef7018b0c", | |
"value": "Downloading (…)neration_config.json: 100%" | |
} | |
}, | |
"2973bec2ccff49168ba84ddc1f326e8a": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "FloatProgressModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "ProgressView", | |
"bar_style": "success", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_0043c749b6ad4d27a251abbf5d3644ed", | |
"max": 147, | |
"min": 0, | |
"orientation": "horizontal", | |
"style": "IPY_MODEL_ccf2b81c4a774692980b024885d00872", | |
"value": 147 | |
} | |
}, | |
"58230fe17b7840c983a59a807a051f2c": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_47302718457149c9bd98b2ef0b702d52", | |
"placeholder": "", | |
"style": "IPY_MODEL_83d6be00be514667a08472ba48a7d42f", | |
"value": " 147/147 [00:00<00:00, 4.48kB/s]" | |
} | |
}, | |
"b4687344280545029c6505de0d277b97": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"910e96c5ff324caba366ede006d164ab": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"d8138fd266e44ab3963a274ef7018b0c": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"0043c749b6ad4d27a251abbf5d3644ed": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"ccf2b81c4a774692980b024885d00872": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "ProgressStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"bar_color": null, | |
"description_width": "" | |
} | |
}, | |
"47302718457149c9bd98b2ef0b702d52": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"83d6be00be514667a08472ba48a7d42f": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"6d2432a98bf84a77a215f03e69a8a79f": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HBoxModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HBoxView", | |
"box_style": "", | |
"children": [ | |
"IPY_MODEL_9ecd6788ffb04d37ae547f518246fc6b", | |
"IPY_MODEL_da31333b1c844b56830916b794ce9833", | |
"IPY_MODEL_a90c9881a7d0420cb7e230f2b19f25f4" | |
], | |
"layout": "IPY_MODEL_7fc0552e04b247f5a8b5c8dc7ebfbfe1" | |
} | |
}, | |
"9ecd6788ffb04d37ae547f518246fc6b": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_5d21f67dc7aa4297b7115299c2589da7", | |
"placeholder": "", | |
"style": "IPY_MODEL_5562aa475ee846a89442b793da39c802", | |
"value": "Downloading spiece.model: 100%" | |
} | |
}, | |
"da31333b1c844b56830916b794ce9833": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "FloatProgressModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "ProgressView", | |
"bar_style": "success", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_46fd98456a604139a948a97056fb6842", | |
"max": 791656, | |
"min": 0, | |
"orientation": "horizontal", | |
"style": "IPY_MODEL_51bebf1ec8fe4e798bbba4bbf158bd6d", | |
"value": 791656 | |
} | |
}, | |
"a90c9881a7d0420cb7e230f2b19f25f4": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_2f2b07d552364493965c7811042d73c5", | |
"placeholder": "", | |
"style": "IPY_MODEL_fab5d066dc794d1dbb17b282081b1d92", | |
"value": " 792k/792k [00:00<00:00, 3.68MB/s]" | |
} | |
}, | |
"7fc0552e04b247f5a8b5c8dc7ebfbfe1": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"5d21f67dc7aa4297b7115299c2589da7": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"5562aa475ee846a89442b793da39c802": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"46fd98456a604139a948a97056fb6842": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"51bebf1ec8fe4e798bbba4bbf158bd6d": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "ProgressStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"bar_color": null, | |
"description_width": "" | |
} | |
}, | |
"2f2b07d552364493965c7811042d73c5": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"fab5d066dc794d1dbb17b282081b1d92": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"d20b4ddee92c4d22a81bd819e8314fa3": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HBoxModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HBoxView", | |
"box_style": "", | |
"children": [ | |
"IPY_MODEL_556d33b982cc4989a03fc4b29bc7e84a", | |
"IPY_MODEL_03675f31e76a4135a2ae239f3ee42655", | |
"IPY_MODEL_8231b974ae924b68959d606e073d66ad" | |
], | |
"layout": "IPY_MODEL_4bfd7fde14514f41bea4edcb878cfc3c" | |
} | |
}, | |
"556d33b982cc4989a03fc4b29bc7e84a": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_d439f8d48f754ed1976392cadfaac95f", | |
"placeholder": "", | |
"style": "IPY_MODEL_86b9a19a4ecf4e52974e1dc84335aca2", | |
"value": "Downloading (…)/main/tokenizer.json: 100%" | |
} | |
}, | |
"03675f31e76a4135a2ae239f3ee42655": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "FloatProgressModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "ProgressView", | |
"bar_style": "success", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_1b6c1e021fa64948ae533a11f029bd50", | |
"max": 1389353, | |
"min": 0, | |
"orientation": "horizontal", | |
"style": "IPY_MODEL_71b903a4156149bbb714e1f3841f13b2", | |
"value": 1389353 | |
} | |
}, | |
"8231b974ae924b68959d606e073d66ad": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_a011aa10bd374418b030f7cd2f602255", | |
"placeholder": "", | |
"style": "IPY_MODEL_e52faa2810604688844ddf624b0a1400", | |
"value": " 1.39M/1.39M [00:00<00:00, 15.9MB/s]" | |
} | |
}, | |
"4bfd7fde14514f41bea4edcb878cfc3c": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"d439f8d48f754ed1976392cadfaac95f": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"86b9a19a4ecf4e52974e1dc84335aca2": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"1b6c1e021fa64948ae533a11f029bd50": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"71b903a4156149bbb714e1f3841f13b2": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "ProgressStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"bar_color": null, | |
"description_width": "" | |
} | |
}, | |
"a011aa10bd374418b030f7cd2f602255": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"e52faa2810604688844ddf624b0a1400": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
} | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict | |
from copy import deepcopy | |
from functools import partial, reduce | |
import numpy as np | |
import warnings | |
import jax | |
import jax.numpy as jnp | |
from jax._src.core import Literal | |
from jax.util import safe_map | |
from tqdm import tqdm | |
from .gptq import gptq, pack_matrix, QuantizedMatrix | |
def tree_size_bytes(tree):
    """
    Return the total number of bytes held by the array leaves of `tree`.

    Every leaf must expose `.size` and `.itemsize` (numpy / JAX arrays do).
    An empty tree yields 0.
    """
    leaves = jax.tree_util.tree_leaves(tree)
    return sum(leaf.size * leaf.itemsize for leaf in leaves)
def quantize(
    fn,
    params,
    inputs,
    block_size=128,
    actorder=False,
    damping=0.01,
    use_quantized_activations=True,
    use_fp64=False,
    use_params_fp32=False
):
    """
    Run the GPT-Q algorithm on a function to produce quantized versions of its parameters

    Arguments:
        fn: The function to be transformed. It should take two arguments:
            1. A pytree of parameters to be quantized. This corresponds to the `params` pytree from libraries such as Flax/Haiku
            2. A pytree of other arguments. If the original model takes more than one extra argument, you can write a wrapper which takes a tuple as the second argument. TODO: handle varargs
        params: The params pytree. Buffers in this tree may be freed to save memory, so do not re-use it after calling this function.
        inputs: A non-empty list of batches of inputs. If your model needs to be vmapped to handle batches, do that before calling quantize.
        block_size, actorder, damping, use_fp64: Forwarded to the GPT-Q solver.
        use_quantized_activations: If True, activations downstream of a quantized op are computed
            with the quantized weight rather than the original one.
        use_params_fp32: Trace and evaluate with floating-point params upcast to float32.

    Returns:
        A pytree with the structure of `params` in which each quantized leaf is replaced by its
        packed quantized version; the remaining leaves are the host-resident copies of the originals.

    Raises:
        ValueError: if `inputs` is empty.
    """
    if len(inputs) == 0:
        # Without this check the failure would be an opaque IndexError on `inputs[0]` below.
        raise ValueError('`inputs` must contain at least one batch of calibration inputs')

    with jax.disable_jit():
        jaxpr_args = (params, inputs[0])
        if use_params_fp32:
            # Trace with float32 stand-ins for the floating-point leaves.
            jaxpr_args = jax.tree_util.tree_map(
                lambda x: jax.ShapeDtypeStruct(x.shape, jnp.float32) if x.dtype.kind == 'f' else x,
                jaxpr_args
            )
        closed_jaxpr = jax.make_jaxpr(fn)(*jaxpr_args)

    # Stage params and calibration data in host memory; _eval_and_quantize moves
    # only what each segment needs onto the compute device.
    cpu = jax.devices('cpu')[0]
    params = jax.device_put(params, cpu)
    inputs = jax.device_put(inputs, cpu)

    param_args, param_struct = jax.tree_util.tree_flatten(params)

    # Transpose [batch][leaf] -> [leaf][batch] so every non-param jaxpr input
    # receives a list holding its value for each batch.
    input_args = [jax.tree_util.tree_leaves(inp) for inp in inputs]
    input_args = [list(arg) for arg in zip(*input_args)]

    # (params, inputs) are flattened in that order, so the param leaves are the
    # first len(param_args) jaxpr inputs.
    argnums = set(range(len(param_args)))

    result = _eval_and_quantize(
        closed_jaxpr.jaxpr,
        closed_jaxpr.literals,
        argnums,
        *param_args,
        *input_args,
        block_size=block_size,
        actorder=actorder,
        damping=damping,
        use_quantized_activations=use_quantized_activations,
        use_fp64=use_fp64,
        use_params_fp32=use_params_fp32
    )
    for ind, quantized_param in result.items():
        param_args[ind] = quantized_param

    return jax.tree_util.tree_unflatten(param_struct, param_args)
def _get_delete_points(jaxpr): | |
deps = defaultdict(set) | |
for i, eqn in enumerate(jaxpr.eqns): | |
for var in set(v for v in eqn.invars if not isinstance(v, Literal)): | |
deps[var].add(i) | |
deps = dict(deps) | |
delete_vars = [] | |
for i, eqn in enumerate(jaxpr.eqns): | |
eqn_delete = [] | |
for var in set(v for v in eqn.invars if not isinstance(v, Literal)): | |
deps[var].remove(i) | |
if not deps[var]: | |
eqn_delete.append(var) | |
del deps[var] | |
delete_vars.append(eqn_delete) | |
return delete_vars | |
def _maybe_delete(val): | |
if not val.is_deleted(): | |
val.device_buffer.delete() | |
def _eval_and_quantize(
    jaxpr,
    consts,
    argnums,
    *args,
    block_size=128,
    actorder=False,
    damping=0.01,
    use_quantized_activations=True,
    use_fp64=False,
    use_params_fp32=False
):
    """
    Evaluate `jaxpr` segment by segment on every calibration batch, running GPT-Q on each
    matmul-like equation whose weight operand is one of the parameters in `argnums`.

    Arguments:
        jaxpr: The jaxpr of the function being quantized.
        consts: Values for `jaxpr.constvars`, in order.
        argnums: Set of positions in `args` holding parameters (quantization targets).
        *args: One entry per `jaxpr.invars`. Entries at positions in `argnums` are single
            parameter arrays; every other entry is a list with that input's value for each batch.
        block_size, actorder, damping, use_fp64: Forwarded to `gptq`.
        use_quantized_activations: If True, the output of a quantized op is computed with the
            quantized weight; otherwise with the original weight.
        use_params_fp32: Upcast floating-point parameters to float32 before evaluation.

    Returns:
        Dict mapping positions in `args` to the packed, host-resident quantized parameter.

    Raises:
        ValueError: if no non-parameter (batched) input is present.
        NotImplementedError: if one op has more than one quantization target operand.
    """
    # Parameters and quantized results live in host memory. Computation runs on the
    # default device: GPU/TPU when present, otherwise the CPU itself.
    cpu = jax.devices('cpu')[0]
    accel = jax.devices()[0]
    # When the default device *is* the CPU, device_put between `cpu` and `accel` may hand
    # back aliases of its input, so deleting the "source" would also destroy the copy.
    # Only free such copies eagerly when the devices really differ; otherwise leave it to GC.
    separate_devices = accel != cpu

    def free(x):
        # Release the memory behind `x` now instead of waiting for garbage collection.
        if not x.is_deleted():
            x.delete()

    quantized_results = {}
    name_to_pos = {}

    batched_arg = next((a for i, a in enumerate(args) if i not in argnums), None)
    if batched_arg is None:
        raise ValueError('At least one non-parameter input is required to calibrate quantization')
    n_batches = len(batched_arg)

    # One activation environment per calibration batch
    envs = [{} for _ in range(n_batches)]

    # Map from var name to a tuple of value, original_name, and a stack of transformations to map it back to orig param shape
    param_env = {}

    for index, name in enumerate(jaxpr.invars):
        if index in argnums:
            param_env[name] = (args[index], name, ())
            name_to_pos[name] = index
        else:
            for i in range(n_batches):
                envs[i][name] = args[index][i]

    delete_points = _get_delete_points(jaxpr)
    const_env = {name: val for name, val in zip(jaxpr.constvars, consts)}

    pos = 0
    with tqdm(desc='Quantizing') as bar:
        while True:
            bar.update(1)
            next_pos, needed_names, matmul_handler, updated_param_env = update_params_to_next_matmul(
                eqns=jaxpr.eqns,
                start_pos=pos,
                delete_points=delete_points,
                param_env=param_env,
                env=envs[0]
            )
            if next_pos is None:
                break

            # ---- Run every equation before the next quantizable op on each batch ----
            block_param_env = {
                name: jax.device_put(param_env[name][0], accel)
                for name in needed_names if name in param_env
            }
            if use_params_fp32:
                for k, v in block_param_env.items():
                    if v.dtype.kind == 'f':
                        block_param_env[k] = v.astype(jnp.float32)

            print(f'Current env size: {tree_size_bytes(envs):.2e} bytes')
            print(f'Current param env size: {tree_size_bytes(block_param_env):.2e} bytes')

            segment_eqns = jaxpr.eqns[pos:next_pos]

            # If a parameter has been transformed keep it in the param env instead of the individual envs
            drop_env_keys = set(k for k in updated_param_env if k not in param_env)

            block_fn = jax.jit(partial(run_segment, segment_eqns, pos, delete_points, drop_env_keys))
            for i, env in enumerate(envs):
                device_env = jax.device_put(env, accel)
                envs[i] = block_fn(block_param_env, device_env, const_env)

            if separate_devices:
                for param in block_param_env.values():
                    free(param)
            del block_param_env

            param_env = updated_param_env

            # ---- Quantize the weight operand of the matmul-like op ----
            matmul_eqn = jaxpr.eqns[next_pos]

            if sum(argname in param_env for argname in matmul_eqn.invars) > 1:
                raise NotImplementedError('Currently only one quantize target is supported per op')

            quantize_argname = next(argname for argname in matmul_eqn.invars if argname in param_env)

            # The parameter is passed as one array, activations as a per-batch list
            all_args = [
                param_env[argname][0] if argname in param_env else [env[argname] for env in envs]
                for argname in matmul_eqn.invars
            ]
            all_args = [jax.device_put(arg, accel) for arg in all_args]

            handler_coro = matmul_handler(all_args)
            w, xs = next(handler_coro)

            quantized_w, quantize_params = gptq(
                W=w,
                xs=xs,
                block_size=block_size,
                actorder=actorder,
                damping=damping,
                use_fp64=use_fp64
            )
            assert quantized_w.shape == w.shape

            # Send the result back so the handler can restore the operand's original layout.
            try:
                handler_coro.send((quantized_w, quantize_params['scale'], quantize_params['zero']))
            except StopIteration as e:
                quantized_w, quantize_params['scale'], quantize_params['zero'], contraction_axis = e.value
            else:
                # Explicit raise (not `assert`) so the check survives `python -O`.
                raise AssertionError('Handler should have stopped')

            # ---- Evaluate the op itself on every batch ----
            outvars = matmul_eqn.outvars
            do_eval = jax.jit(partial(eval_eqn, matmul_eqn))

            matmul_w_arg = quantized_w if use_quantized_activations else param_env[quantize_argname][0]
            if use_params_fp32:
                matmul_w_arg = matmul_w_arg.astype(jnp.float32)
            matmul_w_arg = jax.device_put(matmul_w_arg, accel)

            for env in envs:
                device_args = [
                    matmul_w_arg if argname == quantize_argname else env[argname]
                    for argname in matmul_eqn.invars
                ]
                device_args = jax.device_put(device_args, accel)
                results = do_eval(*device_args)

                if separate_devices and tree_size_bytes(results) > 1e8:
                    # This should offload stuff like the final logits to the CPU
                    cpu_results = jax.device_put(results, cpu)
                    jax.tree_util.tree_map(free, results)
                    results = cpu_results

                if matmul_eqn.primitive.multiple_results:
                    for outvar, value in zip(outvars, results):
                        env[outvar] = value
                else:
                    env[outvars[0]] = results

                for name in delete_points[next_pos]:
                    if name in env:
                        free(env[name])
                        del env[name]

            # ---- Record the result and update the parameter environment ----
            # TODO: Instead of catching duplicate quantizations here avoid doing the calculation in the first place
            orig_w, orig_name, inv_transforms = param_env[quantize_argname]
            write_arg = name_to_pos[orig_name]
            reused_later = quantize_argname not in delete_points[next_pos]

            if write_arg not in quantized_results:
                packed_result = pack_matrix(quantized_w, quantize_params, contraction_axis)
                # Transforms were recorded in application order, so undo them last-first.
                un_transformed = reduce(lambda x, f: f(x), reversed(inv_transforms), packed_result)
                quantized_results[write_arg] = jax.device_put(un_transformed, cpu)

                if reused_later:
                    # Later ops read the quantized weight in place of the original.
                    cpu_quantized_w = jax.device_put(quantized_w, cpu)
                    param_env[quantize_argname] = cpu_quantized_w, orig_name, inv_transforms
                    free(orig_w)

            if not reused_later:
                # Last use of this parameter: release it whether or not this was a duplicate.
                free(orig_w)
                del param_env[quantize_argname]

            # On a CPU-only host the copy stored in param_env may alias `quantized_w`.
            if separate_devices or not reused_later:
                free(quantized_w)

            pos = next_pos + 1

    return quantized_results
def update_params_to_next_matmul(eqns, start_pos, delete_points, param_env, env):
    """
    Scan forward from `start_pos` to the next equation that can be quantized.

    Along the way, equations listed in PARAM_TRANSFORMS that act on a multi-dimensional
    target parameter are evaluated eagerly. The transformed value becomes a parameter
    itself, and the transform's inverse is appended to its transform stack.

    Arguments:
        eqns: Equations of the jaxpr being quantized.
        start_pos: Index of the first equation to consider.
        delete_points: Output of `_get_delete_points` for the same jaxpr.
        param_env: Map var -> (value, original var, tuple of inverse transforms). Not mutated.
        env: Activation environment of one batch. Unused; kept for interface compatibility.

    Returns:
        `(index, needed_names, handler, new_param_env)`:
          - `index` is the position of the quantizable equation;
          - `needed_names` is the set of variables read by the equations before it;
          - `handler` is the matmul handler partially applied to that equation;
          - `new_param_env` is the updated copy of `param_env`.
        When no quantizable equation remains, returns `(None, needed_names, None, None)`.
    """
    new_param_env = dict(param_env)
    needed_names = set()

    for i, eqn in enumerate(eqns[start_pos:], start_pos):
        invars = eqn.invars
        op_name = eqn.primitive.name

        if op_name in PARAM_TRANSFORMS:
            arg, = invars
            # A Literal operand (e.g. a constant being cast) is not a variable and is kept
            # out of name sets elsewhere in this module; let it take the generic path below.
            if not isinstance(arg, Literal):
                needed_names.add(arg)
                if arg in new_param_env and len(new_param_env[arg][0].shape) > 1:
                    val, orig_name, transforms = new_param_env[arg]
                    new_transform = PARAM_TRANSFORMS[op_name](eqn, val)
                    new_name, = eqn.outvars
                    new_val = eval_eqn(eqn, val)
                    new_param_env[new_name] = new_val, orig_name, (transforms + (new_transform,))
                    if arg in delete_points[i]: #TODO: Become certain that making this just a soft check was fine
                        del new_param_env[arg]
                    else:
                        warnings.warn(f'Transformation `{op_name}` is applied to a target parameter of shape {new_param_env[arg][0].shape} which is later reused. This may lead to this parameter not being quantized, or it being quantized poorly.')
                continue

        arg_shapes = [invar.aval for invar in invars]
        args_are_targets = [
            not isinstance(v, Literal) and v in new_param_env and len(new_param_env[v][0].shape) > 1
            for v in invars
        ]

        if any(args_are_targets):
            if op_name == 'pjit':
                warnings.warn('Quantization does not descend into pjit')
            elif op_name in PRIMITIVE_TO_MATMUL:
                predicate, handler = PRIMITIVE_TO_MATMUL[op_name]
                if predicate(eqn, args_are_targets, arg_shapes):
                    return i, needed_names, partial(handler, eqn, args_are_targets), new_param_env
            else:
                warnings.warn(f'Operation {eqn.primitive.name} not supported for quantization')

        # NOTE: the previous per-equation jax.eval_shape into an `env_shapes` dict was
        # never read, so that (costly) tracing has been removed.
        needed_names.update(v for v in invars if not isinstance(v, Literal))

    return None, needed_names, None, None
def run_segment(eqns, start_pos, delete_points, drop_env_keys, param_env, env, const_env):
    """
    Evaluate a contiguous run of jaxpr equations beginning at index `start_pos`.

    Operands are resolved in this order: literal value, then `param_env`, then the
    activation env, then `const_env`. The input `env` is copied, not mutated.
    After each equation, variables whose last use was that equation are dropped.
    After the whole run, every key in `drop_env_keys` is dropped too.
    Returns the resulting activation env.
    """
    local_env = dict(env)

    def lookup(var):
        if isinstance(var, Literal):
            return var.val
        for scope in (param_env, local_env):
            if var in scope:
                return scope[var]
        return const_env[var]

    for offset, eqn in enumerate(eqns):
        eqn_index = start_pos + offset
        outputs = eval_eqn(eqn, *safe_map(lookup, eqn.invars))
        if not eqn.primitive.multiple_results:
            outputs = [outputs]
        safe_map(local_env.__setitem__, eqn.outvars, outputs)
        for dead in delete_points[eqn_index]:
            local_env.pop(dead, None)

    for key in drop_env_keys:
        local_env.pop(key, None)
    return local_env
def dot_general_predicate(eqn, args_are_targets, args):
    """
    Decide whether a `dot_general` equation can be quantized by `handle_dot_general`.

    Supported: exactly one target operand, no batch dimensions, and a single
    contracting dimension on each side. Each rejection emits a warning.
    `args` (operand avals) is unused here.
    """
    contracting, batch = eqn.params['dimension_numbers']
    lhs_contract, rhs_contract = contracting
    lhs_batch, rhs_batch = batch

    def reject(reason):
        warnings.warn(reason)
        return False

    if sum(args_are_targets) > 1:
        return reject('Quantizing two parameters which are multiplied together is not supported')
    if lhs_batch or rhs_batch:
        return reject('Quantizing batched matmuls is not supported')
    if max(len(lhs_contract), len(rhs_contract)) > 1:
        return reject('Quantizing dots with more than one contraction is not supported')
    return True
@partial(jax.jit, static_argnums=(1, 2))
def permute_to_matrix(w, permutation, keep_first):
    """
    Transpose `w` by `permutation`, then flatten it to 2-D.

    With `keep_first` the first (post-transpose) axis is kept and the rest are
    merged into columns; otherwise the last axis is kept and the rest are merged
    into rows. `permutation` and `keep_first` are static (hashable) arguments.
    """
    transposed = jnp.transpose(w, permutation)
    if keep_first:
        target = (transposed.shape[0], -1)
    else:
        target = (-1, transposed.shape[-1])
    return transposed.reshape(target)
@partial(jax.jit, static_argnums=(1, 2))
def to_original_shape(w, shape, restore_permutation):
    """
    Undo `permute_to_matrix`.

    First reshape the 2-D `w` back to the permuted N-d `shape`, then transpose it
    by `restore_permutation` (the inverse of the original permutation).
    """
    unflattened = w.reshape(shape)
    return unflattened.transpose(restore_permutation)
def handle_dot_general(eqn, args_are_targets, args):
    """
    Generator adapting a `dot_general` to the matrix layout GPT-Q works on.

    Protocol:
      1. Yields `(w_matrix, xs)`:
         - `w_matrix` is the target operand with its contraction axis moved first and flattened to 2-D;
         - `xs` is the list of per-batch other operands with their contraction axis moved last
           (flattened to 2-D only when a permutation was needed).
      2. Expects `(quantized_w, scales, zeros)` to be sent back.
      3. Returns (as `StopIteration.value`) `(quantized_w, scales, zeros, contraction_axis)`:
         - `quantized_w` is restored to the target's original shape;
         - `scales` and `zeros` are shaped like the target minus its contraction axis.
    """
    lhs, rhs = args
    (lhs_dims, rhs_dims), _batch_dims = eqn.params['dimension_numbers']
    (lhs_contract,) = lhs_dims
    (rhs_contract,) = rhs_dims

    if args_are_targets[0]:
        w, xs, w_contract, x_contract = lhs, rhs, lhs_contract, rhs_contract
    else:
        w, xs, w_contract, x_contract = rhs, lhs, rhs_contract, lhs_contract

    orig_w_shape = w.shape
    w_ndim = len(orig_w_shape)

    w_permutation = None
    if w_contract != 0 or w_ndim > 2:
        w_permutation = (w_contract,) + tuple(d for d in range(w_ndim) if d != w_contract)
        w = permute_to_matrix(w, w_permutation, True)

    assert isinstance(xs, list)
    x_ndim = len(xs[0].shape)
    x_permutation = None
    if x_contract != x_ndim - 1:
        x_permutation = tuple(d for d in range(x_ndim) if d != x_contract) + (x_contract,)

    if x_permutation is None:
        prepared_xs = list(xs)
    else:
        prepared_xs = [permute_to_matrix(x, x_permutation, False) for x in xs]

    quantized_w, scales, zeros = yield w, prepared_xs

    if w_permutation:
        restore = tuple(np.argsort(w_permutation))
        permuted_shape = tuple(orig_w_shape[d] for d in w_permutation)
        quantized_w = to_original_shape(quantized_w, permuted_shape, restore)

    scale_shape = tuple(size for d, size in enumerate(orig_w_shape) if d != w_contract)
    return (
        quantized_w,
        jnp.reshape(scales, scale_shape),
        jnp.reshape(zeros, scale_shape),
        int(w_contract),
    )
def conv_predicate(eqn, args_are_targets, args):
    """
    Decide whether a `conv_general_dilated` equation can be quantized by `handle_conv`.

    Only the kernel (second operand) may be quantized. The conv must have:
      - unit window strides and unit kernel dilation;
      - a single feature group and a single batch group;
      - a 1x..x1 kernel.
    Each rejection emits a warning. `args` are the operand avals.
    """
    inp_is_target, kernel_is_target = args_are_targets
    if inp_is_target:
        warnings.warn('Only quantizing the kernel of a conv is supported, not the input')
    if not kernel_is_target:
        return False

    params = eqn.params
    requirements = (
        (lambda: all(v == 1 for v in params['window_strides']),
         'Currently only quantizing convs with stride 1 is supported'),
        (lambda: all(v == 1 for v in params['rhs_dilation']),
         'Currently only quantizing convs with dilation 1 is supported'),
        (lambda: params['feature_group_count'] == 1,
         'Currently only quantizing convs with feature group count 1 is supported'),
        (lambda: params['batch_group_count'] == 1,
         'Currently only quantizing convs with batch group count 1 is supported'),
    )
    for satisfied, message in requirements:
        if not satisfied():
            warnings.warn(message)
            return False

    # dimension_numbers[1] is the kernel spec: (out feature, in feature, *spatial)
    kernel_spatial_dims = params['dimension_numbers'][1][2:]
    kernel_shape = args[1].shape
    if any(kernel_shape[d] != 1 for d in kernel_spatial_dims):
        warnings.warn('Currently only quantizing convs with 1x..x1 kernels are supported')
        return False
    return True
def handle_conv(eqn, args_are_targets, args):
    """
    Generator adapting a 1x..x1 convolution (see `conv_predicate`) to GPT-Q's matrix form.

    Protocol (same shape as `handle_dot_general`):
      1. Yields `(kernel_matrix, inputs)`:
         - `kernel_matrix` is the kernel with its spatial axes squeezed out, laid out
           (in_features, out_features);
         - `inputs` is the list of per-batch conv inputs with the feature axis last
           (flattened to 2-D only when a permutation was needed).
      2. Expects `(quantized_matrix, scales, zeros)` to be sent back.
      3. Returns `(quantized_kernel, scales, zeros, contraction_axis)`:
         - `quantized_kernel` has the original kernel shape;
         - `scales` and `zeros` are shaped like the kernel minus its input-feature axis;
         - `contraction_axis` is that input-feature axis.

    Works for any number of spatial dimensions.
    """
    inps, kernel = args
    # Specs: input (batch, feature, *spatial); kernel (out feature, in feature, *spatial).
    # Star-unpacking so 2-D/3-D convs (more than one spatial dim) are accepted.
    inp_spec, kernel_spec, _ = eqn.params['dimension_numbers']
    inp_feature_dim = inp_spec[1]
    kernel_out_dim, kernel_in_dim, *kernel_spatial_dims = kernel_spec
    inp_ndim = len(inps[0].shape)

    flat_kernel = jnp.squeeze(kernel, tuple(kernel_spatial_dims))
    # Squeezing keeps the remaining two axes in positional order; GPT-Q wants (in, out).
    needs_transpose = kernel_out_dim < kernel_in_dim
    if needs_transpose:
        flat_kernel = flat_kernel.T

    inp_permutation = None
    if inp_feature_dim != inp_ndim - 1:
        inp_permutation = tuple([*(i for i in range(inp_ndim) if i != inp_feature_dim), inp_feature_dim])

    prepared_inps = []
    for inp in inps:
        if inp_permutation is not None:
            inp = permute_to_matrix(inp, inp_permutation, False)
        prepared_inps.append(inp)

    quantized_kernel, scales, zeros = yield flat_kernel, prepared_inps

    if needs_transpose:
        quantized_kernel = quantized_kernel.T

    # Re-insert the spatial axes in ascending order. scales/zeros cover every kernel axis
    # except the contraction (input-feature) axis, so positions past it shift down by one.
    for dim in sorted(kernel_spatial_dims):
        quantized_kernel = jnp.expand_dims(quantized_kernel, dim)
        scale_dim = dim if dim < kernel_in_dim else dim - 1
        scales = jnp.expand_dims(scales, scale_dim)
        zeros = jnp.expand_dims(zeros, scale_dim)

    return quantized_kernel, scales, zeros, int(kernel_in_dim)
def eval_eqn(eqn, *args):
    """Evaluate one jaxpr equation on `args` by binding its primitive with the equation's params."""
    primitive = eqn.primitive
    subfuns, bind_params = primitive.get_bind_params(eqn.params)
    return primitive.bind(*subfuns, *args, **bind_params)
# Matmul-like primitives that can be quantized, mapped to (predicate, handler).
# The predicate decides whether a given equation is supported. The handler is a
# generator that converts the weight operand to matrix form and back.
PRIMITIVE_TO_MATMUL = dict(
    dot_general=(dot_general_predicate, handle_dot_general),
    conv_general_dilated=(conv_predicate, handle_conv),
)
def inverse_transpose(eqn, arg):
    """
    Build the inverse of a `transpose` equation that was applied to a target parameter.

    If the jaxpr computes `y = transpose(x, permutation)` and `y` is what gets quantized,
    the returned function maps the QuantizedMatrix for `y` to one laid out like `x`.

    Arguments:
        eqn: The `transpose` equation; its `permutation` param is read.
        arg: The transposed value (unused; part of the PARAM_TRANSFORMS builder signature).

    Returns:
        A function QuantizedMatrix -> QuantizedMatrix.
    """
    permutation = tuple(int(p) for p in eqn.params['permutation'])
    unpermute = tuple(int(i) for i in np.argsort(permutation))

    def inverse(quantized_matrix):
        prev_contract_axis = quantized_matrix.contraction_axis
        # Axis k of y is axis permutation[k] of x. (Using unpermute[k] here was only
        # correct for self-inverse permutations such as a plain 2-D transpose.)
        new_contraction_axis = permutation[prev_contract_axis]
        new_int_weight = jax.lax.transpose(quantized_matrix.int_weight, permutation=unpermute)

        # scale/zero carry y's axes minus the contraction axis; reorder them into
        # x's axes minus the contraction axis, renumbering past the removed axis.
        unpermute_scale = tuple(
            i if i < prev_contract_axis else i - 1
            for i in unpermute
            if i != prev_contract_axis
        )
        new_scale = jax.lax.transpose(quantized_matrix.scale, permutation=unpermute_scale)
        new_zero = jax.lax.transpose(quantized_matrix.zero, permutation=unpermute_scale)
        return QuantizedMatrix(
            int_weight=new_int_weight,
            scale=new_scale,
            zero=new_zero,
            contraction_axis=new_contraction_axis
        )

    return inverse
def inverse_convert_type(eqn, arg):
    """
    Inverse of `convert_element_type` on a target parameter.

    A dtype cast does not change layout, so the quantized result passes through untouched.
    """
    def identity(quantized_matrix):
        return quantized_matrix

    return identity
# Primitives that may be applied to a parameter before it reaches a matmul, mapped to
# a builder `(eqn, value) -> inverse`. The inverse maps the quantized result back to
# the parameter's pre-transform layout.
PARAM_TRANSFORMS = dict(
    transpose=inverse_transpose,
    convert_element_type=inverse_convert_type,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment