Created
September 5, 2024 12:27
-
-
Save jtmuller5/dd417ec2a5a678f0d96b570dab9dc515 to your computer and use it in GitHub Desktop.
emotional-gemma.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"gpuType": "T4", | |
"authorship_tag": "ABX9TyPphzL3RXvdzH8Lck6WnQKs", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/jtmuller5/dd417ec2a5a678f0d96b570dab9dc515/emotional-gemma.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "ViIkGiSSNP_N", | |
"outputId": "1dda0df6-71de-466e-9024-f8d98fa4e9cc" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Mounted at /content/drive\n" | |
] | |
} | |
], | |
"source": [ | |
"from google.colab import drive\n", | |
"drive.mount('/content/drive')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import json\n", | |
"\n", | |
"# Adjust the path as necessary\n", | |
"with open('/content/drive/MyDrive/formatted_data.json', 'r') as f:\n", | |
" data = json.load(f)\n", | |
"\n", | |
"# Train on a subset of the data: keep only the first 1000 examples.\n", | |
"data = data[:1000] # Adjust this number based on your dataset size and available resources" | |
], | |
"metadata": { | |
"id": "GZqZgmh5N_9b" | |
}, | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import os\n", | |
"from google.colab import userdata\n", | |
"\n", | |
"# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n", | |
"# vars as appropriate for your system.\n", | |
"\n", | |
"os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n", | |
"os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')" | |
], | |
"metadata": { | |
"id": "dEXz-I4lOByN" | |
}, | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Install Keras 3 last. See https://keras.io/getting_started/ for more details.\n", | |
"!pip install -q -U keras-nlp\n", | |
"!pip install -q -U \"keras>=3\"" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "O6paASyyONo1", | |
"outputId": "6323be95-fc9b-4bfa-f520-fbca550743d0" | |
}, | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/572.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m572.2/572.2 kB\u001b[0m \u001b[31m27.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.2/5.2 MB\u001b[0m \u001b[31m99.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m37.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25h" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"os.environ[\"KERAS_BACKEND\"] = \"jax\" # Or \"torch\" or \"tensorflow\".\n", | |
"# Avoid memory fragmentation on JAX backend.\n", | |
"os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"1.00\"" | |
], | |
"metadata": { | |
"id": "DsjE4ME1OPDW" | |
}, | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import keras\n", | |
"import keras_nlp" | |
], | |
"metadata": { | |
"id": "pgVp86wAOQTd" | |
}, | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(\"gemma2_2b_en\")\n", | |
"gemma_lm.summary()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 397 | |
}, | |
"id": "m67Ft9qPOQwN", | |
"outputId": "4bcadb0c-6570-4e22-ab3c-8e6925d6939f" | |
}, | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Preprocessor: \"gemma_causal_lm_preprocessor\"</span>\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", | |
"┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n", | |
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", | |
"│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ \u001b[38;5;34m256,000\u001b[0m │\n", | |
"└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", | |
"┃<span style=\"font-weight: bold\"> Tokenizer (type) </span>┃<span style=\"font-weight: bold\"> Vocab # </span>┃\n", | |
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", | |
"│ gemma_tokenizer (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GemmaTokenizer</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">256,000</span> │\n", | |
"└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"gemma_causal_lm\"</span>\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", | |
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", | |
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", | |
"│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", | |
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", | |
"│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", | |
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", | |
"│ gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m) │ \u001b[38;5;34m2,614,341,888\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", | |
"│ (\u001b[38;5;33mGemmaBackbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", | |
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", | |
"│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) │ \u001b[38;5;34m589,824,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", | |
"│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n", | |
"└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", | |
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃<span style=\"font-weight: bold\"> Connected to </span>┃\n", | |
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", | |
"│ padding_mask (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n", | |
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", | |
"│ token_ids (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n", | |
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", | |
"│ gemma_backbone │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">2304</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">2,614,341,888</span> │ padding_mask[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n", | |
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GemmaBackbone</span>) │ │ │ token_ids[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n", | |
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", | |
"│ token_embedding │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256000</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">589,824,000</span> │ gemma_backbone[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n", | |
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">ReversibleEmbedding</span>) │ │ │ │\n", | |
"└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">2,614,341,888</span> (9.74 GB)\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">2,614,341,888</span> (9.74 GB)\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def classify_text(text):\n", | |
" prompt = f\"Instruction: Classify the following text into one of these categories: sadness, joy, love, anger, fear, or surprise.\\n\\nText: {text}\\n\\nResponse: The category for this text is\"\n", | |
"\n", | |
" response = gemma_lm.generate(prompt, max_length=256)\n", | |
" return response\n", | |
"\n", | |
"# Example usage\n", | |
"test_text = \"I hope I hear some news about the company's future soon.\"\n", | |
"result = classify_text(test_text)\n", | |
"print(result)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "LWRFGzUNOSZU", | |
"outputId": "4226ef4a-75bc-48bb-af13-60f9af8a5b1e" | |
}, | |
"execution_count": 19, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Instruction: Classify the following text into one of these categories: sadness, joy, love, anger, fear, or surprise.\n", | |
"\n", | |
"Text: I hope I hear some news about the company's future soon.\n", | |
"\n", | |
"Response: The category for this text is sadness\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Enable LoRA for the model and set the LoRA rank to 4.\n", | |
"gemma_lm.backbone.enable_lora(rank=4)\n", | |
"gemma_lm.summary()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 397 | |
}, | |
"id": "189-mSmTOznv", | |
"outputId": "9838458e-09ae-4436-e00b-ad2edc53d709" | |
}, | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Preprocessor: \"gemma_causal_lm_preprocessor\"</span>\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", | |
"┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n", | |
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", | |
"│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ \u001b[38;5;34m256,000\u001b[0m │\n", | |
"└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", | |
"┃<span style=\"font-weight: bold\"> Tokenizer (type) </span>┃<span style=\"font-weight: bold\"> Vocab # </span>┃\n", | |
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", | |
"│ gemma_tokenizer (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GemmaTokenizer</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">256,000</span> │\n", | |
"└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"gemma_causal_lm\"</span>\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", | |
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", | |
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", | |
"│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", | |
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", | |
"│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", | |
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", | |
"│ gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m) │ \u001b[38;5;34m2,617,270,528\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", | |
"│ (\u001b[38;5;33mGemmaBackbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", | |
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", | |
"│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) │ \u001b[38;5;34m589,824,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", | |
"│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n", | |
"└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", | |
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃<span style=\"font-weight: bold\"> Connected to </span>┃\n", | |
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", | |
"│ padding_mask (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n", | |
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", | |
"│ token_ids (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n", | |
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", | |
"│ gemma_backbone │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">2304</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">2,617,270,528</span> │ padding_mask[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n", | |
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GemmaBackbone</span>) │ │ │ token_ids[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n", | |
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", | |
"│ token_embedding │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256000</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">589,824,000</span> │ gemma_backbone[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n", | |
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">ReversibleEmbedding</span>) │ │ │ │\n", | |
"└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,617,270,528\u001b[0m (9.75 GB)\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">2,617,270,528</span> (9.75 GB)\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,928,640\u001b[0m (11.17 MB)\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">2,928,640</span> (11.17 MB)\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">2,614,341,888</span> (9.74 GB)\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Limit the input sequence length to 256 (to control memory usage).\n", | |
"gemma_lm.preprocessor.sequence_length = 256\n", | |
"# Use AdamW (a common optimizer for transformer models).\n", | |
"optimizer = keras.optimizers.AdamW(\n", | |
" learning_rate=5e-5,\n", | |
" weight_decay=0.01,\n", | |
")\n", | |
"# Exclude layernorm and bias terms from decay.\n", | |
"optimizer.exclude_from_weight_decay(var_names=[\"bias\", \"scale\"])\n", | |
"\n", | |
"gemma_lm.compile(\n", | |
" loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", | |
" optimizer=optimizer,\n", | |
" weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", | |
")\n", | |
"\n", | |
"def format_prompt(example):\n", | |
" return f\"Instruction: {example['instruction']}\\n\\nText: {example['input']}\\n\\nResponse: The category for this text is {example['output']}\"\n", | |
"\n", | |
"formatted_prompts = [format_prompt(example) for example in data]\n", | |
"\n", | |
"gemma_lm.fit(formatted_prompts, epochs=1, batch_size=1)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "RSuxp46eP96h", | |
"outputId": "6f75f1ce-c1bf-4e05-c241-a92028bac126" | |
}, | |
"execution_count": 9, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[1m1000/1000\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m863s\u001b[0m 825ms/step - loss: 0.4823 - sparse_categorical_accuracy: 0.6489\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"<keras.src.callbacks.history.History at 0x7a5d50119db0>" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 9 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Example usage after fine-tuning (run the cell that defines classify_text first).\n", | |
"test_text = \"I'm so excited about my upcoming vacation!\"\n", | |
"result = classify_text(test_text)\n", | |
"print(result)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 198 | |
}, | |
"id": "0VG630G2QGKI", | |
"outputId": "1b802f6f-905c-47c1-9f66-03d314c63ceb" | |
}, | |
"execution_count": 17, | |
"outputs": [ | |
{ | |
"output_type": "error", | |
"ename": "NameError", | |
"evalue": "name 'classify_text' is not defined", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-17-3ecde10d3637>\u001b[0m in \u001b[0;36m<cell line: 3>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Example usage\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mtest_text\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"I'm so excited about my upcoming vacation!\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclassify_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_text\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mNameError\u001b[0m: name 'classify_text' is not defined" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Saving the full model weights is left disabled: the file is roughly 10 GB\n", | |
"# and writing it is very slow.\n", | |
"# gemma_lm.save_weights('/content/gemma_finetuned.weights.h5')\n", | |
"\n", | |
"# Alternatively, if you only want to save the LoRA weights:\n", | |
"gemma_lm.backbone.save_lora_weights('/content/gemma_lora.lora.h5')" | |
], | |
"metadata": { | |
"id": "FyDe6bcqVq3d" | |
}, | |
"execution_count": 14, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from google.colab import files\n", | |
"\n", | |
"# files.download('/content/gemma_finetuned.weights.h5')\n", | |
"# files.download('/content/gemma_model_config.json')\n", | |
"# If you saved only LoRA weights:\n", | |
"files.download('/content/gemma_lora.lora.h5')" | |
], | |
"metadata": { | |
"id": "fikokE6Fb-7w", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "47b16907-aaa6-4608-dd3c-246e0e4fdf8a" | |
}, | |
"execution_count": 11, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.Javascript object>" | |
], | |
"application/javascript": [ | |
"\n", | |
" async function download(id, filename, size) {\n", | |
" if (!google.colab.kernel.accessAllowed) {\n", | |
" return;\n", | |
" }\n", | |
" const div = document.createElement('div');\n", | |
" const label = document.createElement('label');\n", | |
" label.textContent = `Downloading \"${filename}\": `;\n", | |
" div.appendChild(label);\n", | |
" const progress = document.createElement('progress');\n", | |
" progress.max = size;\n", | |
" div.appendChild(progress);\n", | |
" document.body.appendChild(div);\n", | |
"\n", | |
" const buffers = [];\n", | |
" let downloaded = 0;\n", | |
"\n", | |
" const channel = await google.colab.kernel.comms.open(id);\n", | |
" // Send a message to notify the kernel that we're ready.\n", | |
" channel.send({})\n", | |
"\n", | |
" for await (const message of channel.messages) {\n", | |
" // Send a message to notify the kernel that we're ready.\n", | |
" channel.send({})\n", | |
" if (message.buffers) {\n", | |
" for (const buffer of message.buffers) {\n", | |
" buffers.push(buffer);\n", | |
" downloaded += buffer.byteLength;\n", | |
" progress.value = downloaded;\n", | |
" }\n", | |
" }\n", | |
" }\n", | |
" const blob = new Blob(buffers, {type: 'application/binary'});\n", | |
" const a = document.createElement('a');\n", | |
" a.href = window.URL.createObjectURL(blob);\n", | |
" a.download = filename;\n", | |
" div.appendChild(a);\n", | |
" a.click();\n", | |
" div.remove();\n", | |
" }\n", | |
" " | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.Javascript object>" | |
], | |
"application/javascript": [ | |
"download(\"download_e4104eb1-5963-4de3-84d7-4556539abf88\", \"gemma_finetuned.weights.h5\", 10481808632)" | |
] | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from google.colab import drive\n", | |
"drive.mount('/content/drive')\n", | |
"\n", | |
"# !cp /content/gemma_finetuned.weights.h5 /content/drive/MyDrive/\n", | |
"# !cp /content/gemma_model_config.json /content/drive/MyDrive/\n", | |
"# If you saved only LoRA weights:\n", | |
"!cp /content/gemma_lora.lora.h5 /content/drive/MyDrive/" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "I9jmPuM_EBZ4", | |
"outputId": "319320ce-c96a-4d20-d3ba-be351a6bf0f3" | |
}, | |
"execution_count": 16, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n" | |
] | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment