Skip to content

Instantly share code, notes, and snippets.

@jtmuller5
Created September 5, 2024 12:27
Show Gist options
  • Save jtmuller5/dd417ec2a5a678f0d96b570dab9dc515 to your computer and use it in GitHub Desktop.
Save jtmuller5/dd417ec2a5a678f0d96b570dab9dc515 to your computer and use it in GitHub Desktop.
emotional-gemma.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"authorship_tag": "ABX9TyPphzL3RXvdzH8Lck6WnQKs",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/jtmuller5/dd417ec2a5a678f0d96b570dab9dc515/emotional-gemma.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ViIkGiSSNP_N",
"outputId": "1dda0df6-71de-466e-9024-f8d98fa4e9cc"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/drive\n"
]
}
],
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')"
]
},
{
"cell_type": "code",
"source": [
"import json\n",
"\n",
"# Adjust the path as necessary\n",
"with open('/content/drive/MyDrive/formatted_data.json', 'r') as f:\n",
" data = json.load(f)\n",
"\n",
"# If you want to use a subset of the data, you can do:\n",
"data = data[:1000] # Adjust this number based on your dataset size and available resources"
],
"metadata": {
"id": "GZqZgmh5N_9b"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import os\n",
"from google.colab import userdata\n",
"\n",
"# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
"# vars as appropriate for your system.\n",
"\n",
"os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
"os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')"
],
"metadata": {
"id": "dEXz-I4lOByN"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Install Keras 3 last. See https://keras.io/getting_started/ for more details.\n",
"!pip install -q -U keras-nlp\n",
"!pip install -q -U \"keras>=3\""
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "O6paASyyONo1",
"outputId": "6323be95-fc9b-4bfa-f520-fbca550743d0"
},
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/572.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m572.2/572.2 kB\u001b[0m \u001b[31m27.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.2/5.2 MB\u001b[0m \u001b[31m99.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m37.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h"
]
}
]
},
{
"cell_type": "code",
"source": [
"os.environ[\"KERAS_BACKEND\"] = \"jax\" # Or \"torch\" or \"tensorflow\".\n",
"# Avoid memory fragmentation on JAX backend.\n",
"os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"1.00\""
],
"metadata": {
"id": "DsjE4ME1OPDW"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import keras\n",
"import keras_nlp"
],
"metadata": {
"id": "pgVp86wAOQTd"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(\"gemma2_2b_en\")\n",
"gemma_lm.summary()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 397
},
"id": "m67Ft9qPOQwN",
"outputId": "4bcadb0c-6570-4e22-ab3c-8e6925d6939f"
},
"execution_count": 7,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Preprocessor: \"gemma_causal_lm_preprocessor\"</span>\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ \u001b[38;5;34m256,000\u001b[0m │\n",
"└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Tokenizer (type) </span>┃<span style=\"font-weight: bold\"> Vocab # </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ gemma_tokenizer (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GemmaTokenizer</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">256,000</span> │\n",
"└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"gemma_causal_lm\"</span>\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
"│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
"│ gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m) │ \u001b[38;5;34m2,614,341,888\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
"│ (\u001b[38;5;33mGemmaBackbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
"│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) │ \u001b[38;5;34m589,824,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n",
"└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃<span style=\"font-weight: bold\"> Connected to </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ padding_mask (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
"│ token_ids (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
"│ gemma_backbone │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">2304</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">2,614,341,888</span> │ padding_mask[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GemmaBackbone</span>) │ │ │ token_ids[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
"│ token_embedding │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256000</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">589,824,000</span> │ gemma_backbone[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">ReversibleEmbedding</span>) │ │ │ │\n",
"└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">2,614,341,888</span> (9.74 GB)\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">2,614,341,888</span> (9.74 GB)\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
]
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"def classify_text(text):\n",
" prompt = f\"Instruction: Classify the following text into one of these categories: sadness, joy, love, anger, fear, or surprise.\\n\\nText: {text}\\n\\nResponse: The category for this text is\"\n",
"\n",
" response = gemma_lm.generate(prompt, max_length=256)\n",
" return response\n",
"\n",
"# Example usage\n",
"test_text = \"I hope I hear some news about the company's future soon.\"\n",
"result = classify_text(test_text)\n",
"print(result)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LWRFGzUNOSZU",
"outputId": "4226ef4a-75bc-48bb-af13-60f9af8a5b1e"
},
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Instruction: Classify the following text into one of these categories: sadness, joy, love, anger, fear, or surprise.\n",
"\n",
"Text: I hope I hear some news about the company's future soon.\n",
"\n",
"Response: The category for this text is sadness\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Enable LoRA for the model and set the LoRA rank to 4.\n",
"gemma_lm.backbone.enable_lora(rank=4)\n",
"gemma_lm.summary()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 397
},
"id": "189-mSmTOznv",
"outputId": "9838458e-09ae-4436-e00b-ad2edc53d709"
},
"execution_count": 8,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Preprocessor: \"gemma_causal_lm_preprocessor\"</span>\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ \u001b[38;5;34m256,000\u001b[0m │\n",
"└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Tokenizer (type) </span>┃<span style=\"font-weight: bold\"> Vocab # </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ gemma_tokenizer (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GemmaTokenizer</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">256,000</span> │\n",
"└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"gemma_causal_lm\"</span>\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
"│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
"│ gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m) │ \u001b[38;5;34m2,617,270,528\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
"│ (\u001b[38;5;33mGemmaBackbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
"│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) │ \u001b[38;5;34m589,824,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n",
"└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃<span style=\"font-weight: bold\"> Connected to </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ padding_mask (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
"│ token_ids (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
"│ gemma_backbone │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">2304</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">2,617,270,528</span> │ padding_mask[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GemmaBackbone</span>) │ │ │ token_ids[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
"│ token_embedding │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256000</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">589,824,000</span> │ gemma_backbone[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">ReversibleEmbedding</span>) │ │ │ │\n",
"└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,617,270,528\u001b[0m (9.75 GB)\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">2,617,270,528</span> (9.75 GB)\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,928,640\u001b[0m (11.17 MB)\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">2,928,640</span> (11.17 MB)\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">2,614,341,888</span> (9.74 GB)\n",
"</pre>\n"
]
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"# Limit the input sequence length to 256 (to control memory usage).\n",
"gemma_lm.preprocessor.sequence_length = 256\n",
"# Use AdamW (a common optimizer for transformer models).\n",
"optimizer = keras.optimizers.AdamW(\n",
" learning_rate=5e-5,\n",
" weight_decay=0.01,\n",
")\n",
"# Exclude layernorm and bias terms from decay.\n",
"optimizer.exclude_from_weight_decay(var_names=[\"bias\", \"scale\"])\n",
"\n",
"gemma_lm.compile(\n",
" loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
" optimizer=optimizer,\n",
" weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n",
")\n",
"\n",
"def format_prompt(example):\n",
" return f\"Instruction: {example['instruction']}\\n\\nText: {example['input']}\\n\\nResponse: The category for this text is {example['output']}\"\n",
"\n",
"formatted_prompts = [format_prompt(example) for example in data]\n",
"\n",
"gemma_lm.fit(formatted_prompts, epochs=1, batch_size=1)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RSuxp46eP96h",
"outputId": "6f75f1ce-c1bf-4e05-c241-a92028bac126"
},
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[1m1000/1000\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m863s\u001b[0m 825ms/step - loss: 0.4823 - sparse_categorical_accuracy: 0.6489\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<keras.src.callbacks.history.History at 0x7a5d50119db0>"
]
},
"metadata": {},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"source": [
"# Example usage\n",
"test_text = \"I'm so excited about my upcoming vacation!\"\n",
"result = classify_text(test_text)\n",
"print(result)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 198
},
"id": "0VG630G2QGKI",
"outputId": "1b802f6f-905c-47c1-9f66-03d314c63ceb"
},
"execution_count": 17,
"outputs": [
{
"output_type": "error",
"ename": "NameError",
"evalue": "name 'classify_text' is not defined",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-17-3ecde10d3637>\u001b[0m in \u001b[0;36m<cell line: 3>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Example usage\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mtest_text\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"I'm so excited about my upcoming vacation!\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclassify_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_text\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mNameError\u001b[0m: name 'classify_text' is not defined"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Save the full model weights\n",
"# This takes forever...\n",
"# gemma_lm.save_weights('/content/gemma_finetuned.weights.h5')\n",
"\n",
"# Alternatively, if you only want to save the LoRA weights:\n",
"gemma_lm.backbone.save_lora_weights('/content/gemma_lora.lora.h5')"
],
"metadata": {
"id": "FyDe6bcqVq3d"
},
"execution_count": 14,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from google.colab import files\n",
"\n",
"# files.download('/content/gemma_finetuned.weights.h5')\n",
"# files.download('/content/gemma_model_config.json')\n",
"# If you saved only LoRA weights:\n",
"files.download('/content/gemma_lora.lora.h5')"
],
"metadata": {
"id": "fikokE6Fb-7w",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "47b16907-aaa6-4608-dd3c-246e0e4fdf8a"
},
"execution_count": 11,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.Javascript object>"
],
"application/javascript": [
"\n",
" async function download(id, filename, size) {\n",
" if (!google.colab.kernel.accessAllowed) {\n",
" return;\n",
" }\n",
" const div = document.createElement('div');\n",
" const label = document.createElement('label');\n",
" label.textContent = `Downloading \"${filename}\": `;\n",
" div.appendChild(label);\n",
" const progress = document.createElement('progress');\n",
" progress.max = size;\n",
" div.appendChild(progress);\n",
" document.body.appendChild(div);\n",
"\n",
" const buffers = [];\n",
" let downloaded = 0;\n",
"\n",
" const channel = await google.colab.kernel.comms.open(id);\n",
" // Send a message to notify the kernel that we're ready.\n",
" channel.send({})\n",
"\n",
" for await (const message of channel.messages) {\n",
" // Send a message to notify the kernel that we're ready.\n",
" channel.send({})\n",
" if (message.buffers) {\n",
" for (const buffer of message.buffers) {\n",
" buffers.push(buffer);\n",
" downloaded += buffer.byteLength;\n",
" progress.value = downloaded;\n",
" }\n",
" }\n",
" }\n",
" const blob = new Blob(buffers, {type: 'application/binary'});\n",
" const a = document.createElement('a');\n",
" a.href = window.URL.createObjectURL(blob);\n",
" a.download = filename;\n",
" div.appendChild(a);\n",
" a.click();\n",
" div.remove();\n",
" }\n",
" "
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.Javascript object>"
],
"application/javascript": [
"download(\"download_e4104eb1-5963-4de3-84d7-4556539abf88\", \"gemma_finetuned.weights.h5\", 10481808632)"
]
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')\n",
"\n",
"# !cp /content/gemma_finetuned.weights.h5 /content/drive/MyDrive/\n",
"# !cp /content/gemma_model_config.json /content/drive/MyDrive/\n",
"# If you saved only LoRA weights:\n",
"!cp /content/gemma_lora.lora.h5 /content/drive/MyDrive/"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "I9jmPuM_EBZ4",
"outputId": "319320ce-c96a-4d20-d3ba-be351a6bf0f3"
},
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
]
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment