@keyboardAnt
Last active November 1, 2023 19:29
lm_format_enforcer_vllm_integration.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/keyboardAnt/322c4263f231387cad089ed15b0394db/lm_format_enforcer_vllm_integration.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Wv1vqZgW-mLt"
},
"source": [
"# LM Format Enforcer Integration with vLLM\n",
"\n",
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/noamgat/lm-format-enforcer/blob/main/samples/colab_vllm_integration.ipynb\">\n",
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
"</a>\n",
"\n",
"This notebook shows how to integrate LM Format Enforcer with the vLLM library. vLLM does not currently have an API for token filtering, so we use monkey patching to add that functionality.\n",
"\n",
"## Setting up the Colab runtime (user action required)\n",
"\n",
"This Colab-friendly notebook demonstrates the enforcer on Llama 2. It runs on a free GPU on Google Colab.\n",
"Make sure that your runtime is set to GPU:\n",
"\n",
"Menu Bar -> Runtime -> Change runtime type -> T4 GPU (at the time of writing this notebook). [Guide here](https://www.codesansar.com/deep-learning/using-free-gpu-tpu-google-colab.htm)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7uUS5qUj-mLv"
},
"source": [
"## Gathering Hugging Face credentials (user action required)\n",
"\n",
"We begin by installing the dependencies. This demo uses Llama 2, so you will have to create a free Hugging Face account, request access to the Llama 2 model, create an access token, and paste it when the next cell prompts for it.\n",
"\n",
"Links:\n",
"\n",
"- [Request access to llama model](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf). See the \"Access Llama 2 on Hugging Face\" section.\n",
"- [Create huggingface access token](https://huggingface.co/settings/tokens)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8TnReyIj-mLv",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "725a41f8-6e10-4c42-8850-234de921671f"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Requirement already satisfied: vllm in /usr/local/lib/python3.10/dist-packages (0.2.1.post1)\n",
"Requirement already satisfied: lm-format-enforcer in /usr/local/lib/python3.10/dist-packages (0.4.3)\n",
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from vllm) (1.11.1.1)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from vllm) (5.9.5)\n",
"Requirement already satisfied: ray>=2.5.1 in /usr/local/lib/python3.10/dist-packages (from vllm) (2.7.1)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from vllm) (1.5.3)\n",
"Requirement already satisfied: pyarrow in /usr/local/lib/python3.10/dist-packages (from vllm) (9.0.0)\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from vllm) (0.1.99)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from vllm) (1.23.5)\n",
"Requirement already satisfied: torch==2.0.1 in /usr/local/lib/python3.10/dist-packages (from vllm) (2.0.1)\n",
"Requirement already satisfied: transformers>=4.34.0 in /usr/local/lib/python3.10/dist-packages (from vllm) (4.34.1)\n",
"Requirement already satisfied: xformers==0.0.22 in /usr/local/lib/python3.10/dist-packages (from vllm) (0.0.22)\n",
"Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from vllm) (0.104.1)\n",
"Requirement already satisfied: uvicorn[standard] in /usr/local/lib/python3.10/dist-packages (from vllm) (0.23.2)\n",
"Requirement already satisfied: pydantic<2 in /usr/local/lib/python3.10/dist-packages (from vllm) (1.10.13)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1->vllm) (3.12.4)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1->vllm) (4.8.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1->vllm) (1.12)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1->vllm) (3.2)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1->vllm) (3.1.2)\n",
"Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1->vllm) (11.7.99)\n",
"Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1->vllm) (11.7.99)\n",
"Requirement already satisfied: nvidia-cuda-cupti-cu11==11.7.101 in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1->vllm) (11.7.101)\n",
"Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1->vllm) (8.5.0.96)\n",
"Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1->vllm) (11.10.3.66)\n",
"Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1->vllm) (10.9.0.58)\n",
"Requirement already satisfied: nvidia-curand-cu11==10.2.10.91 in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1->vllm) (10.2.10.91)\n",
"Requirement already satisfied: nvidia-cusolver-cu11==11.4.0.1 in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1->vllm) (11.4.0.1)\n",
"Requirement already satisfied: nvidia-cusparse-cu11==11.7.4.91 in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1->vllm) (11.7.4.91)\n",
"Requirement already satisfied: nvidia-nccl-cu11==2.14.3 in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1->vllm) (2.14.3)\n",
"Requirement already satisfied: nvidia-nvtx-cu11==11.7.91 in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1->vllm) (11.7.91)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1->vllm) (2.0.0)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from nvidia-cublas-cu11==11.10.3.66->torch==2.0.1->vllm) (67.7.2)\n",
"Requirement already satisfied: wheel in /usr/local/lib/python3.10/dist-packages (from nvidia-cublas-cu11==11.10.3.66->torch==2.0.1->vllm) (0.41.2)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch==2.0.1->vllm) (3.27.7)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch==2.0.1->vllm) (17.0.4)\n",
"Requirement already satisfied: interegular>=0.3.2 in /usr/local/lib/python3.10/dist-packages (from lm-format-enforcer) (0.3.2)\n",
"Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm) (8.1.7)\n",
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm) (4.19.1)\n",
"Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm) (1.0.7)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm) (23.2)\n",
"Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm) (3.20.3)\n",
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm) (6.0.1)\n",
"Requirement already satisfied: aiosignal in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm) (1.3.1)\n",
"Requirement already satisfied: frozenlist in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm) (1.4.0)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm) (2.31.0)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.34.0->vllm) (0.17.3)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.34.0->vllm) (2023.6.3)\n",
"Requirement already satisfied: tokenizers<0.15,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.34.0->vllm) (0.14.1)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.34.0->vllm) (0.4.0)\n",
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.34.0->vllm) (4.66.1)\n",
"Requirement already satisfied: anyio<4.0.0,>=3.7.1 in /usr/local/lib/python3.10/dist-packages (from fastapi->vllm) (3.7.1)\n",
"Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->vllm) (0.27.0)\n",
"Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->vllm) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->vllm) (2023.3.post1)\n",
"Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn[standard]->vllm) (0.14.0)\n",
"Requirement already satisfied: httptools>=0.5.0 in /usr/local/lib/python3.10/dist-packages (from uvicorn[standard]->vllm) (0.6.1)\n",
"Requirement already satisfied: python-dotenv>=0.13 in /usr/local/lib/python3.10/dist-packages (from uvicorn[standard]->vllm) (1.0.0)\n",
"Requirement already satisfied: uvloop!=0.15.0,!=0.15.1,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from uvicorn[standard]->vllm) (0.19.0)\n",
"Requirement already satisfied: watchfiles>=0.13 in /usr/local/lib/python3.10/dist-packages (from uvicorn[standard]->vllm) (0.21.0)\n",
"Requirement already satisfied: websockets>=10.4 in /usr/local/lib/python3.10/dist-packages (from uvicorn[standard]->vllm) (12.0)\n",
"Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.10/dist-packages (from anyio<4.0.0,>=3.7.1->fastapi->vllm) (3.4)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<4.0.0,>=3.7.1->fastapi->vllm) (1.3.0)\n",
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<4.0.0,>=3.7.1->fastapi->vllm) (1.1.3)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers>=4.34.0->vllm) (2023.6.0)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->vllm) (1.16.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch==2.0.1->vllm) (2.1.3)\n",
"Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm) (23.1.0)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm) (2023.7.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm) (0.30.2)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm) (0.10.6)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm) (3.3.1)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm) (2023.7.22)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch==2.0.1->vllm) (1.3.0)\n",
"\n",
" A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out.\n",
" Setting a new token will erase the existing one.\n",
" To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .\n",
"Token: \n",
"Add token as git credential? (Y/n) \n",
"Token is valid (permission: read).\n",
"\u001b[1m\u001b[31mCannot authenticate through git-credential as no helper is defined on your machine.\n",
"You might have to re-authenticate when pushing to the Hugging Face Hub.\n",
"Run the following command in your terminal in case you want to set the 'store' credential helper as default.\n",
"\n",
"git config --global credential.helper store\n",
"\n",
"Read https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more details.\u001b[0m\n",
"Token has not been saved to git credential helper.\n",
"Your token has been saved to /root/.cache/huggingface/token\n",
"Login successful\n"
]
}
],
"source": [
"!pip install vllm lm-format-enforcer\n",
"!huggingface-cli login\n",
"\n",
"# When running from source / developing the library, use this instead\n",
"# %load_ext autoreload\n",
"# %autoreload 2\n",
"# import sys\n",
"# import os\n",
"# sys.path.append(os.path.abspath('..'))\n",
"## os.environ['CUDA_LAUNCH_BLOCKING'] = '1'"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sB4vlD90-mLw"
},
"source": [
"## Creating a custom sampler that filters tokens\n",
"\n",
"We introduce a subclass of vLLM's ```SamplingParams``` that also accepts a token filtering function with the same signature as the ```prefix_allowed_tokens_fn``` argument of Hugging Face Transformers' ```generate()```:\n",
"\n",
"```prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]]```\n",
"\n",
"We then introduce the function ```_apply_allowed_token_filters()```. For every request that carries a filter function, it sets the logits of all tokens the filter does not allow to negative infinity, so they can never be sampled.\n",
"\n",
"We hope that in future releases of vLLM, this (or similar) will be part of vLLM's ```Sampler``` class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tx_2Aol7-mLw"
},
"outputs": [],
"source": [
"import vllm\n",
"import torch\n",
"from typing import List, Callable, Optional\n",
"from vllm.sampling_params import SamplingParams\n",
"from vllm.model_executor.input_metadata import InputMetadata\n",
"\n",
"class SamplingParamsWithFilterFunction(SamplingParams):\n",
"    # Same signature as Hugging Face's prefix_allowed_tokens_fn; None disables filtering.\n",
"    logits_allowed_tokens_filter_function: Optional[Callable[[int, torch.Tensor], List[int]]] = None\n",
"\n",
"def _apply_allowed_token_filters(logits: torch.Tensor,\n",
"                                 input_metadata: InputMetadata) -> torch.Tensor:\n",
"    num_seqs, vocab_size = logits.shape\n",
"    logits_row_idx = 0\n",
"    # Each sequence group shares one SamplingParams; each sequence owns one logits row.\n",
"    for seq_ids, sampling_params in input_metadata.seq_groups:\n",
"        if isinstance(sampling_params, SamplingParamsWithFilterFunction):\n",
"            filter_function = sampling_params.logits_allowed_tokens_filter_function\n",
"        else:\n",
"            filter_function = None\n",
"        for seq_id in seq_ids:\n",
"            if filter_function is not None:\n",
"                # The filter receives only the tokens generated so far (not the prompt).\n",
"                output_token_ids = input_metadata.seq_data[seq_id].output_token_ids\n",
"                output_token_tensor = torch.tensor(output_token_ids, dtype=torch.long)\n",
"                allowed_tokens = filter_function(logits_row_idx, output_token_tensor)\n",
"                # Additive mask: 0 for allowed tokens, -inf for all others.\n",
"                logits_add_factor = torch.zeros(vocab_size, dtype=logits.dtype, device=logits.device)\n",
"                logits_add_factor[:] = float('-inf')\n",
"                logits_add_factor[allowed_tokens] = 0\n",
"                logits[logits_row_idx] += logits_add_factor\n",
"            logits_row_idx += 1\n",
"    assert logits_row_idx == num_seqs\n",
"    return logits\n",
"\n"
]
},
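{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a self-contained illustration of the masking step (CPU only, no vLLM or GPU required), the following cell applies the same additive ```-inf``` mask that ```_apply_allowed_token_filters()``` builds, but to a toy logits row. The 6-token vocabulary, the logits values, and the ```toy_filter``` function are hypothetical; ```toy_filter``` only mimics the ```prefix_allowed_tokens_fn``` signature. After the softmax, every disallowed token should have probability exactly zero."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from typing import List\n",
"\n",
"# Hypothetical filter with the prefix_allowed_tokens_fn signature:\n",
"# it ignores its inputs and always allows token ids 1 and 4.\n",
"def toy_filter(batch_id: int, input_ids: torch.Tensor) -> List[int]:\n",
"    return [1, 4]\n",
"\n",
"toy_vocab_size = 6\n",
"# One sequence (one row) of made-up logits.\n",
"toy_logits = torch.tensor([[2.0, 0.5, 3.0, -1.0, 1.0, 0.0]])\n",
"\n",
"# Additive mask: 0 for allowed tokens, -inf for everything else.\n",
"allowed = toy_filter(0, torch.tensor([], dtype=torch.long))\n",
"mask = torch.full((toy_vocab_size,), float('-inf'))\n",
"mask[allowed] = 0\n",
"toy_logits[0] += mask\n",
"\n",
"toy_probs = torch.softmax(toy_logits, dim=-1)\n",
"print(toy_probs)\n",
"\n",
"# Disallowed tokens get zero probability; the allowed ones still sum to 1.\n",
"assert torch.all(toy_probs[0, [0, 2, 3, 5]] == 0)\n",
"assert torch.isclose(toy_probs[0].sum(), torch.tensor(1.0))"
]
},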
{
"cell_type": "markdown",
"metadata": {
"id": "iS3Bk3-k-mLw"
},
"source": [
"In order to integrate this function with the ```Sampler``` class, we have to change its ```forward()``` function to call it. Since we are not modifying vLLM itself, we will do this with monkey patching.\n",
"\n",
"Other than the line\n",
"```\n",
"logits = _apply_allowed_token_filters(logits, input_metadata)\n",
"```\n",
"this is an exact copy of the original ```Sampler.forward()``` function from the installed vLLM version, so the patch may need updating when vLLM changes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9O_kyZ1J-mLw"
},
"outputs": [],
"source": [
"from vllm.model_executor.layers.sampler import SamplerOutput, _prune_hidden_states, _get_logits, _get_output_tokens, _get_penalties, _apply_penalties, _get_temperatures, _get_top_p_top_k, _apply_top_p_top_k, _sample, _get_logprobs, _build_sampler_output, _SAMPLING_EPS\n",
"\n",
"from typing import Optional\n",
"\n",
"def patched_forward(\n",
"    self,\n",
"    embedding: torch.Tensor,\n",
"    hidden_states: torch.Tensor,\n",
"    input_metadata: InputMetadata,\n",
"    embedding_bias: Optional[torch.Tensor] = None,\n",
") -> SamplerOutput:\n",
"    # Get the hidden states that we use for sampling.\n",
"    hidden_states = _prune_hidden_states(hidden_states, input_metadata)\n",
"\n",
"    # Get the logits for the next tokens.\n",
"    logits = _get_logits(hidden_states, embedding, embedding_bias,\n",
"                         self.vocab_size)\n",
"\n",
"    # Apply presence and frequency penalties.\n",
"    output_tokens = _get_output_tokens(input_metadata)\n",
"    assert len(output_tokens) == logits.shape[0]\n",
"    presence_penalties, frequency_penalties = _get_penalties(\n",
"        input_metadata)\n",
"    assert len(presence_penalties) == logits.shape[0]\n",
"    assert len(frequency_penalties) == logits.shape[0]\n",
"    logits = _apply_penalties(logits, output_tokens, presence_penalties,\n",
"                              frequency_penalties)\n",
"\n",
"    ### LM FORMAT ENFORCER MONKEY PATCH START\n",
"    logits = _apply_allowed_token_filters(logits, input_metadata)\n",
"    ### LM FORMAT ENFORCER MONKEY PATCH END\n",
"\n",
"    # Apply temperature scaling.\n",
"    temperatures = _get_temperatures(input_metadata)\n",
"    assert len(temperatures) == logits.shape[0]\n",
"    if any(t != 1.0 for t in temperatures):\n",
"        t = torch.tensor(temperatures,\n",
"                         dtype=logits.dtype,\n",
"                         device=logits.device)\n",
"        # Use in-place division to avoid creating a new tensor.\n",
"        logits.div_(t.unsqueeze(dim=1))\n",
"\n",
"    # Apply top-p and top-k truncation.\n",
"    top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)\n",
"    assert len(top_ps) == len(top_ks) == logits.shape[0]\n",
"    do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)\n",
"    do_top_k = any(k != self.vocab_size for k in top_ks)\n",
"    if do_top_p or do_top_k:\n",
"        logits = _apply_top_p_top_k(logits, top_ps, top_ks)\n",
"\n",
"    # We use float32 for probabilities and log probabilities.\n",
"    # Compute the probabilities.\n",
"    probs = torch.softmax(logits, dim=-1, dtype=torch.float)\n",
"    # Compute the log probabilities.\n",
"    # Use log_softmax to ensure numerical stability.\n",
"    logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)\n",
"\n",
"    # Sample the next tokens.\n",
"    sample_results = _sample(probs, logprobs, input_metadata)\n",
"    # Get the logprobs query results.\n",
"    prompt_logprobs, sample_logprobs = _get_logprobs(\n",
"        logprobs, input_metadata, sample_results)\n",
"    return _build_sampler_output(sample_results, input_metadata,\n",
"                                 prompt_logprobs, sample_logprobs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LoGzJpcv-mLx"
},
"source": [
"We load the model as is normally done with vLLM."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Vin9d_x0-mLx",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "cd46270b-70e7-4035-9d42-051cca9e93d2"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"INFO 10-31 15:50:34 llm_engine.py:72] Initializing an LLM engine with config: model='NousResearch/Llama-2-7b-chat-hf', tokenizer='NousResearch/Llama-2-7b-chat-hf', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=1, quantization=None, seed=0)\n",
"INFO 10-31 15:50:34 tokenizer.py:31] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.\n",
"INFO 10-31 15:51:52 llm_engine.py:207] # GPU blocks: 26, # CPU blocks: 512\n"
]
}
],
"source": [
"# model_id = 'meta-llama/Llama-2-7b-chat-hf'\n",
"# model_id = 'facebook/opt-125m'\n",
"model_id = \"NousResearch/Llama-2-7b-chat-hf\"\n",
"llm = vllm.LLM(model=model_id)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "k0ES1Dfg-mLx"
},
"source": [
"If the previous cell executed successfully, you have properly set up your Colab runtime and Hugging Face account!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IEDtbJfj-mLx"
},
"source": [
"A few helper functions to make display nicer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "i64sPcrt-mLx"
},
"outputs": [],
"source": [
"from IPython.display import display, Markdown\n",
"\n",
"def display_header(text):\n",
"    display(Markdown(f'**{text}**'))\n",
"\n",
"def display_content(text):\n",
"    display(Markdown(f'```\\n{text}\\n```'))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uMsbo8Zj-mLy"
},
"source": [
"## Setting up the prompt for the specific language model\n",
"\n",
"We set up the prompting style according to the [Llama2 demo](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/app.py). We simplify the implementation a bit as we don't need chat history for this demo."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fvTpXvij-mLy"
},
"outputs": [],
"source": [
"DEFAULT_SYSTEM_PROMPT = \"\"\"\\\n",
"You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\\n\\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\\\n",
"\"\"\"\n",
"\n",
"def get_prompt(message: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:\n",
"    return f'<s>[INST] <<SYS>>\\n{system_prompt}\\n<</SYS>>\\n\\n{message} [/INST]'"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MQRmHSYH-mLy"
},
"source": [
"## Activating the monkey patch and creating the generation function\n",
"\n",
"We monkey-patch the ```Sampler``` class with our custom ```forward()``` method, using ```unittest.mock```.\n",
"\n",
"We use our sampling params to send the specific filter function with each request, so different requests can have different format enforcers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gNFmH0rz-mLy"
},
"outputs": [],
"source": [
"from lmformatenforcer import CharacterLevelParser\n",
"from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn\n",
"from unittest import mock\n",
"\n",
"DEFAULT_MAX_NEW_TOKENS = 100\n",
"\n",
"def vllm_with_character_level_parser(llm: vllm.LLM, prompt: str, parser: Optional[CharacterLevelParser] = None) -> str:\n",
"    # Patch Sampler.forward only for the duration of this call.\n",
"    with mock.patch.object(vllm.model_executor.layers.sampler.Sampler, 'forward', patched_forward):\n",
"        prefix_function = build_transformers_prefix_allowed_tokens_fn(llm.get_tokenizer(), parser) if parser else None\n",
"        sampling_params = SamplingParamsWithFilterFunction()\n",
"        sampling_params.max_tokens = DEFAULT_MAX_NEW_TOKENS\n",
"        sampling_params.logits_allowed_tokens_filter_function = prefix_function\n",
"        result = llm.generate(prompt, sampling_params=sampling_params)\n",
"    return result[0].outputs[0].text"
]
},
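{
"cell_type": "markdown",
"metadata": {},
"source": [
"The ```with``` block scopes the patch: ```mock.patch.object``` replaces the class attribute on entry and restores the original on exit, even if ```llm.generate()``` raises. The following cell is a minimal sketch of that behavior on a hypothetical ```ToySampler``` class, a stand-in for vLLM's ```Sampler``` that is not part of any library. Because methods are looked up on the class, instances created before the patch (like the sampler inside an already-loaded ```vllm.LLM```) also see it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from unittest import mock\n",
"\n",
"class ToySampler:\n",
"    # Hypothetical stand-in for vllm.model_executor.layers.sampler.Sampler\n",
"    def forward(self, x):\n",
"        return f'original({x})'\n",
"\n",
"def toy_patched_forward(self, x):\n",
"    # Same signature as the method it replaces, including self\n",
"    return f'patched({x})'\n",
"\n",
"toy_sampler = ToySampler()  # created before patching, like the model inside vllm.LLM\n",
"with mock.patch.object(ToySampler, 'forward', toy_patched_forward):\n",
"    inside = toy_sampler.forward(1)\n",
"outside = toy_sampler.forward(1)\n",
"\n",
"print(inside, outside)\n",
"assert inside == 'patched(1)'\n",
"assert outside == 'original(1)'"
]
},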
{
"cell_type": "markdown",
"metadata": {
"id": "0e-K6Ij5-mLy"
},
"source": [
"## vLLM + JSON Use case\n",
"\n",
"Now we demonstrate using ```JsonSchemaParser```. We create a pydantic model, generate the schema from it, and use that to enforce the format.\n",
"The enforced output will always follow the schema, as long as generation is not cut off by the ```max_tokens``` limit first."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dZSMyRQH-mLy",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 646
},
"outputId": "4785dccb-cae7-4dc8-dbe9-89d4cd3b75d9"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.Markdown object>"
],
"text/markdown": "**Prompt:**"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.Markdown object>"
],
"text/markdown": "```\n<s>[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\nPlease give me information about Michael Jordan. You MUST answer using the following json schema: {\"title\": \"AnswerFormat\", \"type\": \"object\", \"properties\": {\"first_name\": {\"title\": \"First Name\", \"type\": \"string\"}, \"last_name\": {\"title\": \"Last Name\", \"type\": \"string\"}, \"year_of_birth\": {\"title\": \"Year Of Birth\", \"type\": \"integer\"}, \"num_seasons_in_nba\": {\"title\": \"Num Seasons In Nba\", \"type\": \"integer\"}}, \"required\": [\"first_name\", \"last_name\", \"year_of_birth\", \"num_seasons_in_nba\"]} [/INST]\n```"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.Markdown object>"
],
"text/markdown": "**Answer, With json schema enforcing:**"
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 1/1 [00:06<00:00, 6.24s/it]\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.Markdown object>"
],
"text/markdown": "```\n {\n\"first_name\": \"Michael\",\n\"last_name\": \"Jordan\",\n\"year_of_birth\": 1963,\n\"num_seasons_in_nba\": 15\n}\n```"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.Markdown object>"
],
"text/markdown": "**Answer, Without json schema enforcing:**"
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 1/1 [00:06<00:00, 6.25s/it]\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.Markdown object>"
],
"text/markdown": "```\n Of course! Here's the requested information about Michael Jordan in the format specified:\n\n{\n\"title\": \"AnswerFormat\",\n\"type\": \"object\",\n\"properties\": {\n\"first_name\": {\n\"title\": \"First Name\",\n\"type\": \"string\",\n\"example\": \"Michael\"\n},\n\"last_name\": {\n\"title\": \"Last Name\",\n\"type\": \"string\",\n\"example\": \"J\n```"
},
"metadata": {}
}
],
"source": [
"from lmformatenforcer import JsonSchemaParser\n",
"from pydantic import BaseModel\n",
"\n",
"from typing import List\n",
"\n",
"class AnswerFormat(BaseModel):\n",
"    first_name: str\n",
"    last_name: str\n",
"    year_of_birth: int\n",
"    num_seasons_in_nba: int\n",
"\n",
"question = 'Please give me information about Michael Jordan. You MUST answer using the following json schema: '\n",
"question_with_schema = f'{question}{AnswerFormat.schema_json()}'\n",
"prompt = get_prompt(question_with_schema)\n",
"\n",
"display_header(\"Prompt:\")\n",
"display_content(prompt)\n",
"\n",
"display_header(\"Answer, With json schema enforcing:\")\n",
"\n",
"result = vllm_with_character_level_parser(llm, prompt, JsonSchemaParser(AnswerFormat.schema()))\n",
"display_content(result)\n",
"\n",
"display_header(\"Answer, Without json schema enforcing:\")\n",
"result = vllm_with_character_level_parser(llm, prompt, None)\n",
"display_content(result)\n",
"\n"
]
},
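{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the enforced answer follows the schema, it can be loaded straight back into the pydantic model. The following cell is a small sketch that parses the enforced answer shown above (copied as a string literal with the newlines dropped, so the cell runs without generating again) and checks that the opening words of the unenforced answer fail validation. It assumes pydantic v1, which vLLM requires here (```pydantic<2``` in the install log); in v1, ```parse_raw()``` raises ```ValidationError``` when the text is not valid JSON."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pydantic import BaseModel, ValidationError\n",
"\n",
"class AnswerFormat(BaseModel):\n",
"    first_name: str\n",
"    last_name: str\n",
"    year_of_birth: int\n",
"    num_seasons_in_nba: int\n",
"\n",
"# Enforced answer from the cell above (newlines removed)\n",
"enforced = ' {\"first_name\": \"Michael\", \"last_name\": \"Jordan\", \"year_of_birth\": 1963, \"num_seasons_in_nba\": 15}'\n",
"answer = AnswerFormat.parse_raw(enforced)\n",
"print(answer)\n",
"assert answer.year_of_birth == 1963\n",
"\n",
"# Opening words of the unenforced answer: free text, not JSON\n",
"unenforced = \" Of course! Here's the requested information about Michael Jordan in the format specified:\"\n",
"try:\n",
"    AnswerFormat.parse_raw(unenforced)\n",
"    parsed_ok = True\n",
"except ValidationError:\n",
"    parsed_ok = False\n",
"assert not parsed_ok"
]
},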
{
"cell_type": "markdown",
"metadata": {
"id": "A6ozgJyu-mLy"
},
"source": [
"As you can see, the enforced output matches the required schema, while the unenforced output does not. We have successfully integrated with vLLM!"
]
},
{
"cell_type": "code",
"source": [
"def batch_vllm_with_character_level_parser(llm: vllm.LLM, prompts: list[str], parser: Optional[CharacterLevelParser] = None) -> list[str]:\n",
"    with mock.patch.object(vllm.model_executor.layers.sampler.Sampler, 'forward', patched_forward):\n",
"        prefix_function = build_transformers_prefix_allowed_tokens_fn(llm.get_tokenizer(), parser) if parser else None\n",
"        sampling_params = SamplingParamsWithFilterFunction()\n",
"        sampling_params.max_tokens = DEFAULT_MAX_NEW_TOKENS\n",
"        sampling_params.logits_allowed_tokens_filter_function = prefix_function\n",
"        results = llm.generate(prompts, sampling_params=sampling_params)\n",
"    return [r.outputs[0].text for r in results]"
],
"metadata": {
"id": "WYfMcAUAMo4J"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def get_prompts(names, AnswerFormat):\n",
"    return [f\"Please give me information about {name}. You MUST answer using the following json schema:\\n{AnswerFormat.schema_json()}\" for name in names]\n",
"\n",
"\n",
"names = \\\n",
" ['Michael Jordan',\n",
" 'Babe Ruth',\n",
" 'Muhammad Ali',\n",
" 'Jim Brown',\n",
" 'Wayne Gretzky',\n",
" 'Jesse Owens',\n",
" 'Jim Thorpe',\n",
" 'Willie Mays',\n",
" 'Jack Nicklaus',\n",
" 'Babe Didrikson',\n",
" 'Joe Louis',\n",
" 'Carl Lewis',\n",
" 'Wilt Chamberlain',\n",
" 'Hank Aaron',\n",
" 'Jackie Robinson',\n",
" 'Ted Williams',\n",
" 'Magic Johnson',\n",
" 'Bill Russell',\n",
" 'Martina Navratilova',\n",
" 'Ty Cobb',\n",
" 'Gordie Howe',\n",
" 'Joe DiMaggio',\n",
" 'Jackie Joyner-Kersee',\n",
" 'Sugar Ray Robinson',\n",
" 'Joe Montana',\n",
" 'Kareem Abdul-Jabbar',\n",
" 'Jerry Rice',\n",
" 'Red Grange',\n",
" 'Arnold Palmer',\n",
" 'Larry Bird',\n",
" 'Bobby Orr',\n",
" 'Johnny Unitas',\n",
" 'Mark Spitz',\n",
" 'Lou Gehrig',\n",
" 'Secretariat',\n",
" 'Oscar Robertson',\n",
" 'Mickey Mantle',\n",
" 'Ben Hogan',\n",
" 'Walter Payton',\n",
" 'Lawrence Taylor',\n",
" 'Wilma Rudolph',\n",
" 'Sandy Koufax',\n",
" 'Julius Erving',\n",
" 'Bobby Jones',\n",
" 'Bill Tilden',\n",
" 'Eric Heiden',\n",
" 'Edwin Moses',\n",
" 'Pete Sampras',\n",
" 'O.J. Simpson',\n",
" 'Chris Evert',\n",
" 'Rocky Marciano',\n",
" 'Jack Dempsey',\n",
" 'Rafer Johnson',\n",
" 'Greg Louganis',\n",
" 'Mario Lemieux',\n",
" 'Pete Rose',\n",
" 'Willie Shoemaker',\n",
" 'Elgin Baylor',\n",
" 'Billie Jean King',\n",
" 'Walter Johnson',\n",
" 'Stan Musial',\n",
" 'Jerry West',\n",
" 'Satchel Paige',\n",
" 'Sammy Baugh',\n",
" 'Althea Gibson',\n",
" 'Eddie Arcaro',\n",
" 'Bob Gibson',\n",
" 'Al Oerter',\n",
" 'Bonnie Blair',\n",
" 'Dick Butkus',\n",
" 'Roberto Clemente',\n",
" 'Bo Jackson',\n",
" 'Josh Gibson',\n",
" 'Deion Sanders',\n",
" 'Dan Marino',\n",
" 'Barry Sanders',\n",
" 'Cy Young',\n",
" 'Bob Mathias',\n",
" 'Gale Sayers',\n",
" 'A.J. Foyt',\n",
" 'Jimmy Connors',\n",
" 'Bobby Hull',\n",
" 'Honus Wagner',\n",
" \"Man o' War\",\n",
" 'Maurice Richard',\n",
" 'Otto Graham',\n",
" 'Henry Armstrong',\n",
" 'Joe Namath',\n",
" 'Rogers Hornsby',\n",
" 'Richard Petty',\n",
" 'Bob Beamon',\n",
" 'Mario Andretti',\n",
" 'Don Hutson',\n",
" 'Bob Cousy',\n",
" 'George Blanda',\n",
" 'Michael Johnson',\n",
" 'Citation',\n",
" 'Don Budge',\n",
" 'Sam Snead',\n",
" 'Jack Johnson']\n",
"\n",
"\n",
"def get_players(num_of_names: int) -> list[AnswerFormat | ValueError]:\n",
"    # Build one schema-constrained prompt per athlete name\n",
"    prompts = get_prompts(names[:num_of_names], AnswerFormat)\n",
"    # Generate all prompts in one vLLM call, enforcing the JSON schema with the character-level parser\n",
"    players_raw = batch_vllm_with_character_level_parser(llm, prompts, JsonSchemaParser(AnswerFormat.schema()))\n",
"    players = []\n",
"    for p in players_raw:\n",
"        try:\n",
"            players.append(AnswerFormat.parse_raw(p))\n",
"        except ValueError as e:\n",
"            # pydantic's ValidationError subclasses ValueError; keep the error in place of the failed result\n",
"            players.append(e)\n",
"    print()\n",
"    print(f\"The number of parsed players: {sum(isinstance(p, AnswerFormat) for p in players)}\")\n",
"    return players"
],
"metadata": {
"id": "ploHfeUzPIVW"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"players = get_players(3)\n",
"players"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RiAdvb2XNfta",
"outputId": "4d2cf002-e526-4a02-9972-f5c59cc7cb96"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 3/3 [00:13<00:00, 4.57s/it]"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 3\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[AnswerFormat(first_name='Michael', last_name='Jordan', year_of_birth=1963, num_seasons_in_nba=15),\n",
" AnswerFormat(first_name='George', last_name='Herman', year_of_birth=1895, num_seasons_in_nba=20),\n",
" AnswerFormat(first_name='Muhammad', last_name='Ali', year_of_birth=1942, num_seasons_in_nba=56)]"
]
},
"metadata": {},
"execution_count": 11
}
]
},
{
"cell_type": "code",
"source": [
"%%timeit\n",
"\n",
"get_players(1)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "OsADQytcb9ct",
"outputId": "60a0cadb-c4e2-4be9-9f8b-97cb98647d3f"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 1/1 [00:04<00:00, 4.42s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 1\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 1/1 [00:04<00:00, 4.68s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 1\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 1/1 [00:04<00:00, 4.49s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 1\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 1/1 [00:03<00:00, 3.53s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 1\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 1/1 [00:04<00:00, 4.87s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 1\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 1/1 [00:04<00:00, 4.48s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 1\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 1/1 [00:04<00:00, 4.39s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 1\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 1/1 [00:04<00:00, 4.75s/it]"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 1\n",
"5.72 s ± 575 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"%%timeit\n",
"\n",
"get_players(10)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xkxkzrJQcHFn",
"outputId": "8864d2be-ef9a-4651-a250-80190297d045"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 10/10 [00:45<00:00, 4.59s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 10\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 10/10 [00:40<00:00, 4.05s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 10\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 10/10 [00:42<00:00, 4.25s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 10\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 10/10 [00:41<00:00, 4.18s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 10\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 10/10 [00:39<00:00, 3.95s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 10\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 10/10 [00:44<00:00, 4.41s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 10\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 10/10 [00:43<00:00, 4.35s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 10\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 10/10 [00:41<00:00, 4.16s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 10\n",
"43.1 s ± 1.69 s per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"%%timeit\n",
"\n",
"get_players(32)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "a9AXVdVRfx6t",
"outputId": "34945566-bda2-497f-cbaf-f8da1a266bdd"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 32/32 [02:17<00:00, 4.29s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 32\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 32/32 [02:18<00:00, 4.32s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 32\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 32/32 [02:14<00:00, 4.20s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 32\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 32/32 [02:11<00:00, 4.11s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 32\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 32/32 [02:18<00:00, 4.32s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 32\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 32/32 [02:18<00:00, 4.34s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 31\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 32/32 [02:17<00:00, 4.29s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 32\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 32/32 [02:14<00:00, 4.20s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 32\n",
"2min 17s ± 2.54 s per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"%%timeit\n",
"\n",
"get_players(64)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kHh1iXjdf2te",
"outputId": "37eee25b-a00b-4cf4-aba2-257863accfdc"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 64/64 [04:43<00:00, 4.43s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 64\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 64/64 [04:47<00:00, 4.49s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 64\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 64/64 [04:40<00:00, 4.38s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 64\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 64/64 [04:35<00:00, 4.30s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 64\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 64/64 [04:39<00:00, 4.36s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 64\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 64/64 [04:28<00:00, 4.19s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 64\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 64/64 [04:32<00:00, 4.25s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 64\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 64/64 [04:35<00:00, 4.30s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 64\n",
"4min 38s ± 5.9 s per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"%%timeit\n",
"\n",
"get_players(100)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kvPvCfDtf7a_",
"outputId": "99a81df9-3097-4da2-b631-7cbdfb8568d1"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 100/100 [07:04<00:00, 4.25s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 99\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 100/100 [06:52<00:00, 4.13s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 100\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 100/100 [07:18<00:00, 4.38s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 100\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 100/100 [07:17<00:00, 4.38s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 100\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 100/100 [07:14<00:00, 4.34s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 100\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 100/100 [07:21<00:00, 4.42s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 100\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 100/100 [07:05<00:00, 4.25s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 100\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Processed prompts: 100%|██████████| 100/100 [06:58<00:00, 4.19s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"The number of parsed players: 100\n",
"7min 11s ± 9.99 s per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
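"# Number of prompts per batch, and mean wall-clock time (s) reported by the %%timeit cells above\n",
"x = [1, 10, 32, 64, 100]\n",
"y = [5.72, 43.1, 137, 278, 431]\n",
"# Standard deviation of each timing (s), also from %%timeit\n",
"yerr = [.575, 1.69, 2.54, 5.9, 9.99]\n",
"\n",
"plt.xlabel(\"Number of prompts\")\n",
"plt.ylabel(\"Time per batch (s)\")\n",
"plt.plot(x, y)\n",
"plt.errorbar(x=x, y=y, yerr=yerr, fmt='o')"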
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 447
},
"id": "r250-nFwlXZj",
"outputId": "11a77fd7-41f9-46cd-815d-43a0220a15dc"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<ErrorbarContainer object of 3 artists>"
]
},
"metadata": {},
"execution_count": 18
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
},
"orig_nbformat": 4,
"colab": {
"provenance": [],
"gpuType": "T4",
"include_colab_link": true
},
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 0
}