gemma-TPU.ipynb
@eramax · Created February 23, 2024 15:27
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "TPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/eramax/cf7b3af1982c0fe2850e0c5989bc9563/gemma-tpu.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"!pip install --upgrade transformers"
],
"metadata": {
"id": "r5mij3MdmtSo",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "4b544c6d-4ec2-4d3b-bf32-e7d962bc4c39"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.38.1)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.13.1)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.20.3)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.23.5)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.2)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.12.25)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n",
"Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.15.2)\n",
"Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.2)\n",
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.2)\n",
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.19.3->transformers) (2023.6.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.19.3->transformers) (4.9.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.2.2)\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "LjQ4e57vmWCP",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "9dc207c0-c6c1-44c0-91bc-2169121057d1"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n",
" TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n",
" TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n",
" TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n",
" TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n",
" TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n",
" TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n",
" TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]"
]
},
"metadata": {},
"execution_count": 2
}
],
"source": [
"import time\n",
"import os\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"from flax import jax_utils\n",
"from flax.training.common_utils import shard\n",
"import jax.tools.colab_tpu\n",
"jax.tools.colab_tpu.setup_tpu()\n",
"from transformers import FlaxGemmaForCausalLM, AutoTokenizer\n",
"jax.devices()"
]
},
{
"cell_type": "code",
"source": [
"model_name = \"google/gemma-2b-it\"\n",
"model, params = FlaxGemmaForCausalLM.from_pretrained(model_name, revision=\"flax\", _do_init=False, dtype=jnp.bfloat16)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"params = jax_utils.replicate(params)\n",
"max_new_tokens = 1024"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "keuwma2emdRM",
"outputId": "3d45faa0-7aa7-463a-b2ed-afcbfd87ffe6"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Some of the weights of FlaxGemmaForCausalLM were initialized in bfloat16 precision from the model checkpoint at google/gemma-2b-it:\n",
"[('model', 'embed_tokens', 'embedding'), ('model', 'layers', '0', 'input_layernorm', 'weight'), ('model', 'layers', '0', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '0', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '0', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '0', 'post_attention_layernorm', 'weight'), ('model', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '0', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '1', 'input_layernorm', 'weight'), ('model', 'layers', '1', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '1', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '1', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '1', 'post_attention_layernorm', 'weight'), ('model', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '1', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '10', 'input_layernorm', 'weight'), ('model', 'layers', '10', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '10', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '10', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '10', 'post_attention_layernorm', 'weight'), ('model', 'layers', '10', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '10', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '10', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '10', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '11', 'input_layernorm', 'weight'), ('model', 'layers', '11', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '11', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '11', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '11', 'post_attention_layernorm', 'weight'), ('model', 'layers', '11', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '11', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '11', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '11', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '12', 'input_layernorm', 'weight'), ('model', 'layers', '12', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '12', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '12', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '12', 'post_attention_layernorm', 'weight'), ('model', 'layers', '12', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '12', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '12', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '12', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '13', 'input_layernorm', 'weight'), ('model', 'layers', '13', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '13', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '13', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '13', 'post_attention_layernorm', 'weight'), ('model', 'layers', '13', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '13', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '13', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '13', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '14', 'input_layernorm', 'weight'), ('model', 'layers', '14', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '14', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '14', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '14', 'post_attention_layernorm', 'weight'), ('model', 'layers', '14', 'self_attn', 'k_proj', 'kernel'), 
('model', 'layers', '14', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '14', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '14', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '15', 'input_layernorm', 'weight'), ('model', 'layers', '15', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '15', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '15', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '15', 'post_attention_layernorm', 'weight'), ('model', 'layers', '15', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '15', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '15', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '15', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '16', 'input_layernorm', 'weight'), ('model', 'layers', '16', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '16', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '16', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '16', 'post_attention_layernorm', 'weight'), ('model', 'layers', '16', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '16', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '16', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '16', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '17', 'input_layernorm', 'weight'), ('model', 'layers', '17', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '17', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '17', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '17', 'post_attention_layernorm', 'weight'), ('model', 'layers', '17', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '17', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '17', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '17', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '2', 'input_layernorm', 'weight'), ('model', 'layers', '2', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '2', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '2', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '2', 'post_attention_layernorm', 'weight'), ('model', 'layers', '2', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '2', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '2', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '2', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '3', 'input_layernorm', 'weight'), ('model', 'layers', '3', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '3', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '3', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '3', 'post_attention_layernorm', 'weight'), ('model', 'layers', '3', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '3', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '3', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '3', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '4', 'input_layernorm', 'weight'), ('model', 'layers', '4', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '4', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '4', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '4', 'post_attention_layernorm', 'weight'), ('model', 'layers', '4', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '4', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '4', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '4', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '5', 'input_layernorm', 'weight'), ('model', 'layers', '5', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '5', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '5', 'mlp', 'up_proj', 'kernel'), 
('model', 'layers', '5', 'post_attention_layernorm', 'weight'), ('model', 'layers', '5', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '5', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '5', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '5', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '6', 'input_layernorm', 'weight'), ('model', 'layers', '6', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '6', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '6', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '6', 'post_attention_layernorm', 'weight'), ('model', 'layers', '6', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '6', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '6', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '6', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '7', 'input_layernorm', 'weight'), ('model', 'layers', '7', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '7', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '7', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '7', 'post_attention_layernorm', 'weight'), ('model', 'layers', '7', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '7', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '7', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '7', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '8', 'input_layernorm', 'weight'), ('model', 'layers', '8', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '8', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '8', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '8', 'post_attention_layernorm', 'weight'), ('model', 'layers', '8', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '8', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '8', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '8', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '9', 'input_layernorm', 'weight'), ('model', 'layers', '9', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '9', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '9', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '9', 'post_attention_layernorm', 'weight'), ('model', 'layers', '9', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '9', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '9', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '9', 'self_attn', 'v_proj', 'kernel'), ('model', 'norm', 'weight')]\n",
"You should probably UPCAST the model weights to float32 if this was not intended. See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this.\n"
]
}
]
},
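{
"cell_type": "markdown",
"source": [
"`jax_utils.replicate` stacks a copy of every parameter along a new leading device axis, so each of the 8 TPU cores holds the full bfloat16 weights. As a quick sanity check (not part of the original run), the leaf shapes can be inspected with `jax.tree_util.tree_map`; every leaf should now have a leading dimension of 8:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Assumed sanity check: after replicate(), every leaf gains a leading axis of\n",
"# size jax.local_device_count() == 8. The exact embedding shape is illustrative.\n",
"shapes = jax.tree_util.tree_map(lambda x: x.shape, params)\n",
"print(shapes[\"model\"][\"embed_tokens\"][\"embedding\"])  # e.g. (8, 256000, 2048)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},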
{
"cell_type": "markdown",
"source": [
"The corresponding tokenizer can now be loaded using a similar API:"
],
"metadata": {
"id": "GQMjMGOLeHF_"
}
},
{
"cell_type": "code",
"source": [
"def generate(inputs, params, max_new_tokens):\n",
" generated_ids = model.generate(\n",
" inputs[\"input_ids\"],\n",
" attention_mask=inputs[\"attention_mask\"],\n",
" params=params,\n",
" max_new_tokens=max_new_tokens,\n",
" do_sample=True,\n",
" )\n",
" return generated_ids.sequences\n",
"\n",
"p_generate = jax.pmap(\n",
" generate, \"inputs\", in_axes=(0, 0, None,), out_axes=0, static_broadcasted_argnums=(2,)\n",
")\n",
"\n",
"def compute_tok_per_s(input_ids, generated_ids, runtime):\n",
" total_inputs = np.prod(input_ids.shape)\n",
" total_outputs = np.prod(generated_ids.shape)\n",
" tokens_generated = total_outputs - total_inputs\n",
" tokens_per_s = tokens_generated / runtime\n",
" return tokens_per_s\n",
"\n",
"def gpt(x):\n",
" input_text = 8 * [x]\n",
" inputs = tokenizer(input_text, padding=\"max_length\", max_length=max_new_tokens, return_attention_mask=True, return_tensors=\"np\")\n",
" inputs = shard(inputs.data)\n",
"\n",
" start = time.time()\n",
" generated_ids = p_generate(inputs, params, max_new_tokens)\n",
" runtime = time.time() - start\n",
"\n",
" generated_ids = jax.device_get(generated_ids.reshape(-1, generated_ids.shape[-1]))\n",
" pred_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n",
" tok_per_s = compute_tok_per_s(inputs[\"input_ids\"], generated_ids, runtime)\n",
"\n",
" print(f\"Runtime with pmap: {runtime}\")\n",
" print(f\"Tokens per second: {tok_per_s}\")\n",
" print(pred_text[0])"
],
"metadata": {
"id": "SOgY8-DroH4y"
},
"execution_count": 4,
"outputs": []
},
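{
"cell_type": "markdown",
"source": [
"`gpt` repeats one prompt 8 times and pads every row to a fixed length of 1024 so the compiled TPU program always sees the same shapes; `shard` then splits the batch one example per core by adding a leading device axis. A minimal sketch of the resulting shapes (assumed, not from the original run):"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Assumed illustration of what shard() does to a tokenised batch.\n",
"demo = tokenizer(8 * [\"Hi\"], padding=\"max_length\", max_length=max_new_tokens, return_attention_mask=True, return_tensors=\"np\")\n",
"print(demo[\"input_ids\"].shape)               # (8, 1024): one row per TPU core\n",
"print(shard(demo.data)[\"input_ids\"].shape)   # (8, 1, 1024): leading device axis added"
],
"metadata": {},
"execution_count": null,
"outputs": []
},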
{
"cell_type": "code",
"source": [
"gpt(\"Hi\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nKewN_fbZwSs",
"outputId": "61f38285-3ba3-492a-bc12-0260e0dbb7cb"
},
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/jax/_src/ops/scatter.py:87: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=bfloat16 to dtype=float32. In future JAX releases this will result in an error.\n",
" warnings.warn(\"scatter inputs have incompatible types: cannot safely cast \"\n",
"/usr/local/lib/python3.10/dist-packages/jax/_src/ops/scatter.py:87: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=bfloat16 to dtype=float32. In future JAX releases this will result in an error.\n",
" warnings.warn(\"scatter inputs have incompatible types: cannot safely cast \"\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Runtime with pmap: 124.53560709953308\n",
"Tokens per second: 65.78038354486581\n",
"Hi there,\n",
"\n",
"I'm looking to get a new phone and I'm considering the options on your website. Could you please provide me with some information about your pricing and plans?\n",
"\n",
"I'm particularly interested in your [product category or specific features].\n",
"\n",
"I would also appreciate it if you could answer any questions I have about your products or services.\n",
"\n",
"Thank you for your time and help!\n",
"\n",
"I'm looking to get a new phone and I'm considering the options on your website. Could you please provide me with some information about your pricing and plans?\n",
"\n",
"**Pricing and Plans**\n",
"\n",
"**Product Categories:**\n",
"\n",
"* Tablets\n",
"* Smartphones\n",
"* Laptops\n",
"* Headphones\n",
"\n",
"**Plans:**\n",
"\n",
"**Basic:**\n",
"\n",
"* Starts at $[price of the device] per month\n",
"* Includes basic features such as touchscreen, internet access, and a camera\n",
"\n",
"**Standard:**\n",
"\n",
"* Starts at $[price of the device] per month\n",
"* Includes standard features such as a larger display, more powerful processor, and a higher-resolution camera\n",
"\n",
"**Premium:**\n",
"\n",
"* Starts at $[price of the device] per month\n",
"* Includes premium features such as a premium display, a high-end processor, and a luxurious camera\n",
"\n",
"**Features:**\n",
"\n",
"* **Display:** LCD or OLED screen, with a resolution of 1280 x 1920 pixels\n",
"* **Processor:** Snapdragon, Apple A15 Bionic, or Qualcomm Snapdragon\n",
"* **Camera:** Rear camera with a resolution of 13MP, front camera with a resolution of 8MP\n",
"* **Storage:** 64GB, 128GB, or 256GB\n",
"* **Operating System:** Android or iOS\n",
"\n",
"**Questions:**\n",
"\n",
"* What is your return policy?\n",
"* What are your shipping options?\n",
"* How can I contact customer support?\n",
"\n",
"**Conclusion:**\n",
"\n",
"To find the best phone that fits your needs, browse our website and compare our pricing and plans. You can also read reviews and compare products side-by-side to make an informed decision.\n"
]
}
]
},
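{
"cell_type": "markdown",
"source": [
"This first call is dominated by XLA compilation of the pmapped generation loop, which is why it takes ~125 s and reports only ~66 tokens/s. Later calls with the same input shapes reuse the compiled program, as the ~970 tokens/s run below shows."
],
"metadata": {}
},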
{
"cell_type": "code",
"source": [
"gpt(\"explain quick sort by example\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "j2jeU0slkloU",
"outputId": "a753f042-f714-4909-dfbb-a4d17c265ada"
},
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Runtime with pmap: 8.44529914855957\n",
"Tokens per second: 970.0070839287234\n",
"explain quick sort by example.\n",
"\n",
"**Quick Sort Algorithm:**\n",
"\n",
"* **Choose a pivot:** Select the first element of the input list as the pivot.\n",
"* **Partition the list:** Split the list into two sublists: one containing elements less than the pivot and another containing elements greater than the pivot.\n",
"* **Repeat:** Recursively sort the sublists until they are sorted, and then merge the two sorted sublists into the original list.\n",
"\n",
"**Example:**\n",
"\n",
"```python\n",
"def quick_sort(arr):\n",
" if len(arr) <= 1:\n",
" return arr\n",
" pivot = arr[0]\n",
" left = []\n",
" right = []\n",
" for elem in arr:\n",
" if elem < pivot:\n",
" left.append(elem)\n",
" else:\n",
" right.append(elem)\n",
" return quick_sort(left) + [pivot] + quick_sort(right)\n",
"\n",
"\n",
"arr = [8, 7, 9, 1, 5]\n",
"print(quick_sort(arr))\n",
"```\n",
"\n",
"**Output:**\n",
"\n",
"```\n",
"[1, 5, 7, 8, 9]\n",
"```\n",
"\n",
"**Explanation:**\n",
"\n",
"1. The `quick_sort` function takes a list `arr` as input.\n",
"2. It starts by checking if the length of the list is 1 or less. If it is, it is already sorted, so it returns the list as is.\n",
"3. If the list has more than one element, it chooses the first element as the pivot.\n",
"4. It then creates two sublists: `left` for elements less than the pivot and `right` for elements greater than the pivot.\n",
"5. The function recursively sorts the `left` and `right` sublists.\n",
"6. Finally, it merges the two sorted sublists back into the original `arr` list.\n",
"\n",
"**Time Complexity:**\n",
"\n",
"The time complexity of quick sort is O(n log n), where n is the length of the input list. This is because the algorithm repeatedly divides the list into smaller and smaller sublists until each sublist contains only one element.\n"
]
}
]
}
]
}