@vukrosic
Last active September 9, 2025 16:43
qwen3_from_scratch.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/vukrosic/94dc965a22b0892042f44fed25918598/qwen3_from_scratch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"- 📺 YouTube video: https://youtu.be/wM-KP_wNAeY\n",
"\n",
"- 📹 Bilibili: https://www.bilibili.com/video/BV1P9tizcEKD/"
],
"metadata": {
"id": "Q8fwaTdY95pg"
}
},
{
"cell_type": "markdown",
"metadata": {
"id": "8_5I2up39nnO"
},
"source": [
" # Building Qwen3 from Scratch: A Complete Tutorial\n",
"\n",
" Welcome to this comprehensive tutorial, where we'll build a Qwen3-style language model from scratch!\n",
"\n",
" We'll implement Grouped-Query Attention (GQA), RMSNorm, and SwiGLU activations, and add our own spin: the new Muon optimizer, which can speed up training by 30% to 50%.\n",
"\n",
" **What you'll learn:**\n",
"\n",
"🧠 Modern Transformer architecture with Qwen3-style features\n",
"\n",
"⚙️ Grouped-Query Attention (GQA) for memory and compute efficiency\n",
"\n",
"💡 Rotary Positional Embeddings (RoPE) for better performance and context window extrapolation\n",
"\n",
"📐 QK-Norm with RMSNorm for improved numerical / training stability\n",
"\n",
"🚀 Muon optimizer using Newton-Schulz orthogonalization for better weight updates, faster learning with less data\n",
"\n",
"🔁 Hybrid optimization using Muon for matrices and AdamW for other parameters\n",
"\n",
"🔄 SwiGLU activation and deep residual learning in the feedforward layers\n",
"\n",
"🔢 Efficient dataset tokenization and caching with HuggingFace Datasets and Transformers\n",
"\n",
"🧪 Validation metrics including loss, accuracy, and perplexity\n",
"\n",
"🧵 Gradient accumulation + AMP (Automatic Mixed Precision) training for larger batch sizes\n",
"\n",
"🎛️ Cosine learning rate scheduling with warmup"
]
},
{
"cell_type": "markdown",
"source": [
"### 📌 Focus: **Qwen-Specific Architecture**\n",
"\n",
"Before diving deep into Qwen, it's **highly recommended** to build a solid foundation in the attention mechanism and tokenization.\n",
"\n",
"Here are two excellent resources to get you started:\n",
"\n",
"---\n",
"\n",
"🎓 **[🦙 LLaMA 4 From Scratch (first 2h 30min)](https://youtu.be/wcDV3l4CD14)**\n",
"\n",
"> In the first **2h 30min**, I give a **clear and intuitive explanation** of both:\n",
"\n",
"* 🧠 Attention Mechanism\n",
"* 🧩 Tokens & Tokenizers\n",
"\n",
"Highly recommended if you're just starting out or want a solid refresher.\n",
"\n",
"---\n",
"\n",
"🎥 **[📘 GPT From Scratch by Andrej Karpathy](https://youtu.be/kCc8FmEb1nY)**\n",
"\n",
"> A legendary tutorial from Karpathy that walks through building a GPT model from scratch. Great for understanding the fundamentals!"
],
"metadata": {
"id": "Drk1IHElC2H6"
}
},
{
"cell_type": "markdown",
"source": [
"💡 If this notebook asks you for a HF_TOKEN (Hugging Face token to download tokenizer or data), you can just hit Cancel — it's not actually needed.\n",
"\n",
"🤷‍♂️ Not sure how to disable it!"
],
"metadata": {
"id": "sfBVv2SYGGUr"
}
},
{
"cell_type": "markdown",
"metadata": {
"id": "J8e5ds9P9nnS"
},
"source": [
" ## 1. Setup and Imports\n",
"\n",
" First, let's import all the necessary libraries. We'll use PyTorch as the deep learning framework, Hugging Face `transformers` for tokenization, and various utilities for data handling and training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rQLuHlbT9nnT"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn # Neural network modules like Linear, Embedding, etc.\n",
"import torch.nn.functional as F # Functional interface for operations like cross_entropy, silu, etc.\n",
"from torch.utils.data import Dataset, DataLoader # Base class and utilities for loading datasets\n",
"from torch.cuda.amp import autocast, GradScaler # 🔄 Automatic Mixed Precision (AMP) tools for faster/lower-memory training\n",
"\n",
"import math # Standard math operations (e.g. sqrt, exp, cos)\n",
"import random # Python's random number utilities (used for seeding)\n",
"import numpy as np # Numerical computing library, used for random seeding and general array ops\n",
"\n",
"from datasets import load_dataset # 🧁 Hugging Face Datasets library for streaming large datasets\n",
"from tqdm import tqdm # ⏳ Progress bar visualization library, great for loops\n",
"\n",
"import time # ⌛ Timing utilities, measuring time\n",
"from transformers import AutoTokenizer # 🤗 Load pretrained tokenizers from HuggingFace with one line\n",
"\n",
"from dataclasses import dataclass # 🧱 Define simple classes for configs with less boilerplate\n",
"from typing import List, Optional # ✍️ Type hints for better readability and tooling\n",
"\n",
"import warnings # ⚠️ Suppress or handle warnings\n",
"import os # 🗂️ File system operations (creating folders, path checking, etc.)\n",
"import pickle # 💾 Python object serialization (used to save/load preprocessed datasets)\n",
"\n",
"warnings.filterwarnings('ignore') # Silences warnings for cleaner outputs during training\n"
]
},
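{
"cell_type": "markdown",
"metadata": {
"id": "VD8LriEQ9nnU"
},
"source": [
" ## 2. Utility Functions\n",
"\n",
" Let's start with some utility functions for reproducibility and configuration management.\n",
"\n",
" The `set_seed` function seeds Python's `random`, NumPy, and PyTorch (CPU and all GPUs) and switches cuDNN to deterministic mode, so repeated runs on the same setup produce the same results. Some GPU operations can remain nondeterministic, so small run-to-run differences are still possible."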
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3CczY5ms9nnU"
},
"outputs": [],
"source": [
"def set_seed(seed: int = 42):\n",
"    \"\"\"Set all random seeds for reproducibility\"\"\"\n",
"    random.seed(seed)\n",
"    np.random.seed(seed)\n",
"    torch.manual_seed(seed)\n",
"    torch.cuda.manual_seed_all(seed)\n",
"    torch.backends.cudnn.deterministic = True\n",
"    torch.backends.cudnn.benchmark = False\n",
"    print(f\"🌱 Set all seeds to {seed}\")\n"
]
},
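{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check of the mechanism that `set_seed` relies on: re-seeding a generator with the same value replays exactly the same random numbers. This sketch seeds PyTorch and NumPy directly instead of calling `set_seed`, so it runs on its own."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"\n",
"# Re-seeding PyTorch's generator replays the same random stream\n",
"torch.manual_seed(42)\n",
"first = torch.randn(3)\n",
"torch.manual_seed(42)\n",
"second = torch.randn(3)\n",
"print(torch.equal(first, second))  # True\n",
"\n",
"# The same holds for NumPy's global generator\n",
"np.random.seed(42)\n",
"a = np.random.rand(3)\n",
"np.random.seed(42)\n",
"b = np.random.rand(3)\n",
"print(np.array_equal(a, b))  # True"
]
},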
{
"cell_type": "markdown",
"metadata": {
"id": "j-5evpnH9nnU"
},
"source": [
" ## 3. Model Configuration\n",
"\n",
" Here we define our model configuration using a dataclass, which makes it easy to experiment with different model sizes and hyperparameters. Our model is a scaled-down version of Qwen3 with a hidden size of 384, 6 layers, 8 query heads, and 4 key-value heads."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Buy5jS939nnV"
},
"outputs": [],
"source": [
"@dataclass\n",
"class ModelConfig:\n",
"    # Model architecture\n",
"    d_model: int = 384\n",
"    n_heads: int = 8\n",
"    n_layers: int = 6\n",
"    d_ff: int = 1536\n",
"\n",
"    # Qwen3-like parameters\n",
"    n_kv_heads: int = 4  # For Grouped-Query Attention\n",
"    sliding_window: int = 4096  # Set a large default, effectively disabling it unless specified\n",
"    attention_bias: bool = False  # Qwen3 often sets this to False\n",
"    rms_norm_eps: float = 1e-6  # Epsilon for RMSNorm\n",
"\n",
"    # Training parameters\n",
"    batch_size: int = 24\n",
"    max_steps: int = 2000\n",
"    gradient_accumulation_steps: int = 4\n",
"    muon_lr: float = 0.01\n",
"\n",
"    # Data parameters\n",
"    max_seq_len: int = 512\n",
"    num_documents: int = 2000\n",
"    max_tokens: int = 500000\n",
"\n",
"    # Evaluation\n",
"    eval_every: int = 500\n",
"    eval_steps: int = 100\n",
"\n",
"    # Regularization\n",
"    weight_decay: float = 0.1\n",
"    dropout: float = 0.1\n",
"    grad_clip: float = 1.0\n",
"\n",
"    # Technical\n",
"    use_amp: bool = True\n",
"    vocab_size: Optional[int] = None\n",
"\n",
"    def __post_init__(self):\n",
"        # Validate head counts before deriving per-head sizes\n",
"        assert self.d_model % self.n_heads == 0, \"d_model must be divisible by n_heads\"\n",
"        assert self.n_heads % self.n_kv_heads == 0, \"n_heads must be divisible by n_kv_heads\"\n",
"        self.d_k = self.d_model // self.n_heads  # Dimension of each attention head\n",
"        self.n_kv_groups = self.n_heads // self.n_kv_heads  # Query heads that share one key/value head\n"
]
},
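{
"cell_type": "markdown",
"metadata": {},
"source": [
"A worked example of the derived fields that `__post_init__` computes for the default configuration. The numbers are copied from `ModelConfig` above. The head mapping on the last line assumes the head ordering produced by `repeat_kv` in the next section, where query head `h` uses key/value head `h // n_kv_groups`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Default values copied from ModelConfig above\n",
"d_model, n_heads, n_kv_heads = 384, 8, 4\n",
"\n",
"d_k = d_model // n_heads             # per-head dimension\n",
"n_kv_groups = n_heads // n_kv_heads  # query heads sharing one key/value head\n",
"print(f\"d_k = {d_k}, n_kv_groups = {n_kv_groups}\")  # d_k = 48, n_kv_groups = 2\n",
"\n",
"# Which key/value head each query head attends with\n",
"print([h // n_kv_groups for h in range(n_heads)])  # [0, 0, 1, 1, 2, 2, 3, 3]"
]
},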
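{
"cell_type": "markdown",
"metadata": {
"id": "Z7rRp5Ac9nnV"
},
"source": [
" ## 4. Grouped-Query Attention Helper\n",
"\n",
" This function implements the key component of GQA: repeating key and value heads. In GQA there are fewer key-value heads than query heads (here `n_kv_heads = 4` vs. `n_heads = 8`), and each key-value head is shared by a group of query heads. This shrinks the key/value projections and the KV cache while keeping quality close to standard multi-head attention. Before computing attention scores, `repeat_kv` copies each key-value head `n_rep = n_heads // n_kv_heads` times so that keys and values line up with the queries."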
]
},
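{
"cell_type": "markdown",
"metadata": {},
"source": [
"To put the memory savings into numbers, here is a back-of-the-envelope count of KV-cache entries for one full-length sequence with the default config. It assumes keys and values are cached for every layer and token, and it counts elements rather than bytes (multiply by 2 for bfloat16/float16). Note that `repeat_kv` still rebuilds full-size key/value tensors temporarily during the attention computation: `expand` makes no copy, but reshaping the expanded view does. The savings apply to what is stored."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Default values copied from ModelConfig\n",
"n_layers, max_seq_len, n_heads, n_kv_heads, d_k = 6, 512, 8, 4, 48\n",
"\n",
"# Keys + values (factor 2), for every layer and every token\n",
"mha_entries = 2 * n_layers * max_seq_len * n_heads * d_k     # one K/V head per query head\n",
"gqa_entries = 2 * n_layers * max_seq_len * n_kv_heads * d_k  # K/V heads shared across groups\n",
"print(mha_entries, gqa_entries)   # 2359296 1179648\n",
"print(gqa_entries / mha_entries)  # 0.5"
]
},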
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yucDeNey9nnV"
},
"outputs": [],
"source": [
"def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n",
"    \"\"\"\n",
"    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).\n",
"    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)\n",
"    to (batch, num_attention_heads, seqlen, head_dim)\n",
"    \"\"\"\n",
"    # Extract dimensions from input tensor\n",
"    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n",
"\n",
"    # Early return if no repetition is needed\n",
"    if n_rep == 1:\n",
"        return hidden_states\n",
"\n",
"    # Add a new dimension at index 2 (after num_key_value_heads) and expand\n",
"    # Shape transformation:\n",
"    #   (batch, num_key_value_heads, slen, head_dim)\n",
"    #   -> (batch, num_key_value_heads, 1, slen, head_dim)      [via None indexing]\n",
"    #   -> (batch, num_key_value_heads, n_rep, slen, head_dim)  [via expand]\n",
"    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n",
"\n",
"    # Flatten the num_key_value_heads and n_rep dimensions together\n",
"    # Final shape: (batch, num_key_value_heads * n_rep, slen, head_dim)\n",
"    # This effectively repeats each key/value head n_rep times\n",
"    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)"
]
},
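{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check of the docstring's claim that `repeat_kv` is equivalent to `torch.repeat_interleave(x, dim=1, repeats=n_rep)`. This cell reuses `repeat_kv` from the cell above; the tensor sizes are arbitrary."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"kv = torch.randn(2, 4, 16, 48)  # (batch, n_kv_heads, seq_len, head_dim)\n",
"for n_rep in (1, 2, 3):\n",
"    ours = repeat_kv(kv, n_rep)\n",
"    reference = torch.repeat_interleave(kv, repeats=n_rep, dim=1)\n",
"    assert ours.shape == (2, 4 * n_rep, 16, 48)\n",
"    assert torch.equal(ours, reference)\n",
"print(\"repeat_kv matches torch.repeat_interleave for n_rep = 1, 2, 3\")"
]
},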
{
"cell_type": "markdown",
"source": [
"#### Optional: Code exercises to understand `repeat_kv`"
],
"metadata": {
"id": "ZH1RhXtTSyvG"
}
},
{
"cell_type": "code",
"source": [
"# Tensor Expansion and Repetition Practice Exercises\n",
"\n",
"import torch\n",
"import numpy as np\n",
"\n",
"print(\"🚀 Tensor Expansion and Repetition Exercises\")\n",
"print(\"=\" * 50)\n",
"\n",
"# =============================================================================\n",
"# EXERCISE 1: Basic Tensor Creation and Shapes\n",
"# =============================================================================\n",
"print(\"\\n📝 Exercise 1: Understanding Tensor Shapes\")\n",
"print(\"-\" * 40)\n",
"\n",
"# Create simple tensors and understand their shapes\n",
"x = torch.tensor([1, 2, 3])\n",
"print(f\"1D tensor: {x}\")\n",
"print(f\"Shape: {x.shape}\")\n",
"\n",
"y = torch.tensor([[1, 2, 3], [4, 5, 6]])\n",
"print(f\"2D tensor:\\n{y}\")\n",
"print(f\"Shape: {y.shape}\")\n",
"\n",
"z = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])\n",
"print(f\"3D tensor:\\n{z}\")\n",
"print(f\"Shape: {z.shape}\")\n",
"\n",
"# TODO: Create a 4D tensor of shape (2, 3, 4, 5) filled with ones\n",
"# Your code here:\n",
"# tensor_4d = ?"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gsK-daLwS-NP",
"outputId": "3a141409-d2d5-4845-9130-e17e5470b730"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"🚀 Tensor Expansion and Repetition Exercises\n",
"==================================================\n",
"\n",
"📝 Exercise 1: Understanding Tensor Shapes\n",
"----------------------------------------\n",
"1D tensor: tensor([1, 2, 3])\n",
"Shape: torch.Size([3])\n",
"2D tensor:\n",
"tensor([[1, 2, 3],\n",
" [4, 5, 6]])\n",
"Shape: torch.Size([2, 3])\n",
"3D tensor:\n",
"tensor([[[1, 2],\n",
" [3, 4]],\n",
"\n",
" [[5, 6],\n",
" [7, 8]]])\n",
"Shape: torch.Size([2, 2, 2])\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# =============================================================================\n",
"# EXERCISE 2: Understanding None Indexing (Adding Dimensions)\n",
"# =============================================================================\n",
"print(\"\\n\\n📝 Exercise 2: Adding Dimensions with None\")\n",
"print(\"-\" * 40)\n",
"\n",
"# Start with a 1D tensor\n",
"a = torch.tensor([1, 2, 3, 4])\n",
"print(f\"Original: {a.shape} -> {a}\")\n",
"\n",
"# Add dimension at different positions\n",
"a_new_dim0 = a[None, :] # or a.unsqueeze(0)\n",
"print(f\"Add dim at 0: {a_new_dim0.shape} -> {a_new_dim0}\")\n",
"\n",
"a_new_dim1 = a[:, None] # or a.unsqueeze(1)\n",
"print(f\"Add dim at 1: {a_new_dim1.shape} -> {a_new_dim1}\")\n",
"\n",
"a_new_dim_end = a[..., None] # or a.unsqueeze(-1)\n",
"print(f\"Add dim at end: {a_new_dim_end.shape} -> {a_new_dim_end}\")\n",
"\n",
"# Multiple dimensions\n",
"a_multi = a[None, :, None, None]\n",
"print(f\"Multiple dims: {a_multi.shape}\")\n",
"\n",
"# TODO: Take tensor [1, 2, 3] and make it shape (1, 3, 1, 1)\n",
"# Your code here:\n",
"# result = ?"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "msnOOOQbS_4p",
"outputId": "684c5026-3648-4cf0-b136-be4cc53c089b"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"\n",
"📝 Exercise 2: Adding Dimensions with None\n",
"----------------------------------------\n",
"Original: torch.Size([4]) -> tensor([1, 2, 3, 4])\n",
"Add dim at 0: torch.Size([1, 4]) -> tensor([[1, 2, 3, 4]])\n",
"Add dim at 1: torch.Size([4, 1]) -> tensor([[1],\n",
" [2],\n",
" [3],\n",
" [4]])\n",
"Add dim at end: torch.Size([4, 1]) -> tensor([[1],\n",
" [2],\n",
" [3],\n",
" [4]])\n",
"Multiple dims: torch.Size([1, 4, 1, 1])\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# =============================================================================\n",
"# EXERCISE 3: Basic expand() Operation\n",
"# =============================================================================\n",
"print(\"\\n\\n📝 Exercise 3: Understanding expand()\")\n",
"print(\"-\" * 40)\n",
"\n",
"# expand() creates a view with repeated elements (no memory copy!)\n",
"b = torch.tensor([[1, 2, 3]]) # Shape: (1, 3)\n",
"print(f\"Original: {b.shape} -> {b}\")\n",
"\n",
"# Expand the first dimension\n",
"b_expanded = b.expand(4, 3) # Repeat the row 4 times\n",
"print(f\"Expanded: {b_expanded.shape}\")\n",
"print(b_expanded)\n",
"\n",
"# Expand with -1 (keep original size)\n",
"c = torch.tensor([[1], [2], [3]]) # Shape: (3, 1)\n",
"print(f\"\\nOriginal c: {c.shape}\")\n",
"print(c)\n",
"\n",
"c_expanded = c.expand(-1, 5) # Keep dim 0, expand dim 1 to 5\n",
"print(f\"Expanded c: {c_expanded.shape}\")\n",
"print(c_expanded)\n",
"\n",
"# TODO: Create tensor [[1, 2]] and expand it to shape (3, 4)\n",
"# Your code here:\n",
"# d = ?\n",
"# d_expanded = ?"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "sHjNM722TDJc",
"outputId": "a3006f8c-5cbc-4cf0-9c85-7e0f562a005f"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"\n",
"📝 Exercise 3: Understanding expand()\n",
"----------------------------------------\n",
"Original: torch.Size([1, 3]) -> tensor([[1, 2, 3]])\n",
"Expanded: torch.Size([4, 3])\n",
"tensor([[1, 2, 3],\n",
" [1, 2, 3],\n",
" [1, 2, 3],\n",
" [1, 2, 3]])\n",
"\n",
"Original c: torch.Size([3, 1])\n",
"tensor([[1],\n",
" [2],\n",
" [3]])\n",
"Expanded c: torch.Size([3, 5])\n",
"tensor([[1, 1, 1, 1, 1],\n",
" [2, 2, 2, 2, 2],\n",
" [3, 3, 3, 3, 3]])\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# =============================================================================\n",
"# EXERCISE 4: repeat() vs expand() vs repeat_interleave()\n",
"# =============================================================================\n",
"print(\"\\n\\n📝 Exercise 4: Different Repetition Methods\")\n",
"print(\"-\" * 40)\n",
"\n",
"original = torch.tensor([1, 2, 3])\n",
"print(f\"Original: {original}\")\n",
"\n",
"# Method 1: repeat() - actually copies data\n",
"repeated = original.repeat(2) # Repeat entire tensor 2 times\n",
"print(f\"repeat(2): {repeated}\")\n",
"\n",
"repeated_2d = original.repeat(2, 1) # 2D repetition\n",
"print(f\"repeat(2, 1) on [1,2,3]: shape {repeated_2d.shape}\")\n",
"print(repeated_2d)\n",
"\n",
"# Method 2: expand() - creates view (memory efficient)\n",
"original_2d = original.unsqueeze(0) # Make it (1, 3)\n",
"expanded = original_2d.expand(3, -1)\n",
"print(f\"expand(3, -1): shape {expanded.shape}\")\n",
"print(expanded)\n",
"\n",
"# Method 3: repeat_interleave() - repeats each element\n",
"interleaved = torch.repeat_interleave(original, 2)\n",
"print(f\"repeat_interleave(2): {interleaved}\")\n",
"\n",
"# TODO: What's the difference between these results?\n",
"# torch.tensor([1, 2]).repeat(3) vs torch.repeat_interleave(torch.tensor([1, 2]), 3)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nCmQym5DTFUw",
"outputId": "cd8fd779-5a33-4157-90a6-1f9e8b921056"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"\n",
"📝 Exercise 4: Different Repetition Methods\n",
"----------------------------------------\n",
"Original: tensor([1, 2, 3])\n",
"repeat(2): tensor([1, 2, 3, 1, 2, 3])\n",
"repeat(2, 1) on [1,2,3]: shape torch.Size([2, 3])\n",
"tensor([[1, 2, 3],\n",
" [1, 2, 3]])\n",
"expand(3, -1): shape torch.Size([3, 3])\n",
"tensor([[1, 2, 3],\n",
" [1, 2, 3],\n",
" [1, 2, 3]])\n",
"repeat_interleave(2): tensor([1, 1, 2, 2, 3, 3])\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# =============================================================================\n",
"# EXERCISE 5: Working with 3D Tensors\n",
"# =============================================================================\n",
"print(\"\\n\\n📝 Exercise 5: 3D Tensor Manipulations\")\n",
"print(\"-\" * 40)\n",
"\n",
"# Create a 3D tensor: (batch=2, heads=3, features=4)\n",
"tensor_3d = torch.arange(24).reshape(2, 3, 4)\n",
"print(f\"3D tensor shape: {tensor_3d.shape}\")\n",
"print(f\"3D tensor:\\n{tensor_3d}\")\n",
"\n",
"# Add a dimension in the middle\n",
"tensor_4d = tensor_3d[:, :, None, :] # Shape: (2, 3, 1, 4)\n",
"print(f\"\\nAfter adding dim: {tensor_4d.shape}\")\n",
"\n",
"# Expand the new dimension\n",
"tensor_expanded = tensor_4d.expand(2, 3, 5, 4) # Shape: (2, 3, 5, 4)\n",
"print(f\"After expand: {tensor_expanded.shape}\")\n",
"print(f\"First batch, first head:\\n{tensor_expanded[0, 0]}\")\n",
"\n",
"# TODO: Take the expanded tensor and reshape it to merge dimensions 1 and 2\n",
"# Target shape: (2, 15, 4) # 3 * 5 = 15\n",
"# Your code here:\n",
"# merged = ?"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "d_WBIO2NTGxt",
"outputId": "64430153-b0ec-478d-a01b-4ec78a3f53de"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"\n",
"📝 Exercise 5: 3D Tensor Manipulations\n",
"----------------------------------------\n",
"3D tensor shape: torch.Size([2, 3, 4])\n",
"3D tensor:\n",
"tensor([[[ 0, 1, 2, 3],\n",
" [ 4, 5, 6, 7],\n",
" [ 8, 9, 10, 11]],\n",
"\n",
" [[12, 13, 14, 15],\n",
" [16, 17, 18, 19],\n",
" [20, 21, 22, 23]]])\n",
"\n",
"After adding dim: torch.Size([2, 3, 1, 4])\n",
"After expand: torch.Size([2, 3, 5, 4])\n",
"First batch, first head:\n",
"tensor([[0, 1, 2, 3],\n",
" [0, 1, 2, 3],\n",
" [0, 1, 2, 3],\n",
" [0, 1, 2, 3],\n",
" [0, 1, 2, 3]])\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# =============================================================================\n",
"# EXERCISE 6: Simulating the repeat_kv Pattern\n",
"# =============================================================================\n",
"print(\"\\n\\n📝 Exercise 6: Building Up to repeat_kv\")\n",
"print(\"-\" * 40)\n",
"\n",
"# Simulate key/value heads that need to be repeated\n",
"# Shape: (batch, num_kv_heads, seq_len, head_dim)\n",
"kv_tensor = torch.arange(48).reshape(2, 3, 4, 2)\n",
"print(f\"KV tensor shape: {kv_tensor.shape}\")\n",
"print(f\"KV tensor (batch 0, head 0):\\n{kv_tensor[0, 0]}\")\n",
"\n",
"n_rep = 2 # Each KV head needs to be repeated 2 times\n",
"\n",
"# Step 1: Add dimension for repetition\n",
"step1 = kv_tensor[:, :, None, :, :] # Shape: (2, 3, 1, 4, 2)\n",
"print(f\"\\nStep 1 - Add dimension: {step1.shape}\")\n",
"\n",
"# Step 2: Expand the new dimension\n",
"step2 = step1.expand(2, 3, n_rep, 4, 2) # Shape: (2, 3, 2, 4, 2)\n",
"print(f\"Step 2 - Expand: {step2.shape}\")\n",
"\n",
"# Step 3: Reshape to merge heads\n",
"final = step2.reshape(2, 3 * n_rep, 4, 2) # Shape: (2, 6, 4, 2)\n",
"print(f\"Step 3 - Final: {final.shape}\")\n",
"\n",
"# Verify: each original head should appear n_rep times\n",
"print(f\"\\nOriginal head 0:\\n{kv_tensor[0, 0]}\")\n",
"print(f\"Repeated head 0 (position 0):\\n{final[0, 0]}\")\n",
"print(f\"Repeated head 0 (position 1):\\n{final[0, 1]}\")\n",
"print(f\"Original head 1:\\n{kv_tensor[0, 1]}\")\n",
"print(f\"Repeated head 1 (position 2):\\n{final[0, 2]}\")\n",
"\n",
"# TODO: Verify that final[0, 0] equals final[0, 1] (same repeated head)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mrdbTj81TIma",
"outputId": "a04923a7-e20b-43b9-c409-664d70e897a2"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"\n",
"📝 Exercise 6: Building Up to repeat_kv\n",
"----------------------------------------\n",
"KV tensor shape: torch.Size([2, 3, 4, 2])\n",
"KV tensor (batch 0, head 0):\n",
"tensor([[0, 1],\n",
" [2, 3],\n",
" [4, 5],\n",
" [6, 7]])\n",
"\n",
"Step 1 - Add dimension: torch.Size([2, 3, 1, 4, 2])\n",
"Step 2 - Expand: torch.Size([2, 3, 2, 4, 2])\n",
"Step 3 - Final: torch.Size([2, 6, 4, 2])\n",
"\n",
"Original head 0:\n",
"tensor([[0, 1],\n",
" [2, 3],\n",
" [4, 5],\n",
" [6, 7]])\n",
"Repeated head 0 (position 0):\n",
"tensor([[0, 1],\n",
" [2, 3],\n",
" [4, 5],\n",
" [6, 7]])\n",
"Repeated head 0 (position 1):\n",
"tensor([[0, 1],\n",
" [2, 3],\n",
" [4, 5],\n",
" [6, 7]])\n",
"Original head 1:\n",
"tensor([[ 8, 9],\n",
" [10, 11],\n",
" [12, 13],\n",
" [14, 15]])\n",
"Repeated head 1 (position 2):\n",
"tensor([[ 8, 9],\n",
" [10, 11],\n",
" [12, 13],\n",
" [14, 15]])\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# =============================================================================\n",
"# EXERCISE 7: Complete repeat_kv Implementation Practice\n",
"# =============================================================================\n",
"print(\"\\n\\n📝 Exercise 7: Implement repeat_kv Yourself\")\n",
"print(\"-\" * 40)\n",
"\n",
"def my_repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n",
"    \"\"\"\n",
"    TODO: Implement this function!\n",
"    Input: (batch, num_key_value_heads, seqlen, head_dim)\n",
"    Output: (batch, num_key_value_heads * n_rep, seqlen, head_dim)\n",
"    \"\"\"\n",
"    # Get dimensions\n",
"    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n",
"\n",
"    # Handle n_rep = 1 case\n",
"    if n_rep == 1:\n",
"        return hidden_states\n",
"\n",
"    # TODO: Your implementation here\n",
"    # Step 1: Add dimension\n",
"    # Step 2: Expand\n",
"    # Step 3: Reshape\n",
"    # return ?\n",
"    pass\n",
"\n",
"# Test your implementation\n",
"test_tensor = torch.arange(24).reshape(1, 3, 4, 2)\n",
"print(f\"Test input shape: {test_tensor.shape}\")\n",
"\n",
"# TODO: Uncomment when you implement the function\n",
"# result = my_repeat_kv(test_tensor, 2)\n",
"# print(f\"Result shape: {result.shape}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TZ-7PONITKgn",
"outputId": "60135c06-7bbe-4b5e-cec1-71b7974c56d9"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"\n",
"📝 Exercise 7: Implement repeat_kv Yourself\n",
"----------------------------------------\n",
"Test input shape: torch.Size([1, 3, 4, 2])\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# =============================================================================\n",
"# EXERCISE 8: Advanced Patterns\n",
"# =============================================================================\n",
"print(\"\\n\\n📝 Exercise 8: Advanced Repetition Patterns\")\n",
"print(\"-\" * 40)\n",
"\n",
"# Pattern 1: Repeat different amounts for different dimensions\n",
"base = torch.tensor([[1, 2], [3, 4]]) # Shape: (2, 2)\n",
"print(f\"Base tensor:\\n{base}\")\n",
"\n",
"# Repeat 3 times along dim 0, 2 times along dim 1\n",
"pattern1 = base.repeat(3, 2)\n",
"print(f\"Pattern 1 - repeat(3, 2):\\n{pattern1}\")\n",
"\n",
"# Pattern 2: Using repeat_interleave along specific dimensions\n",
"pattern2 = torch.repeat_interleave(base, 2, dim=0)\n",
"print(f\"Pattern 2 - repeat_interleave along dim 0:\\n{pattern2}\")\n",
"\n",
"pattern3 = torch.repeat_interleave(base, 2, dim=1)\n",
"print(f\"Pattern 3 - repeat_interleave along dim 1:\\n{pattern3}\")\n",
"\n",
"# TODO: Create a pattern where you repeat_interleave along dim 0 with repeats [1, 3]\n",
"# (first row once, second row three times)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4nZkQy5HTMgC",
"outputId": "5da1f928-8e2e-47fe-8a66-c69b2bcf45e0"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"\n",
"📝 Exercise 8: Advanced Repetition Patterns\n",
"----------------------------------------\n",
"Base tensor:\n",
"tensor([[1, 2],\n",
" [3, 4]])\n",
"Pattern 1 - repeat(3, 2):\n",
"tensor([[1, 2, 1, 2],\n",
" [3, 4, 3, 4],\n",
" [1, 2, 1, 2],\n",
" [3, 4, 3, 4],\n",
" [1, 2, 1, 2],\n",
" [3, 4, 3, 4]])\n",
"Pattern 2 - repeat_interleave along dim 0:\n",
"tensor([[1, 2],\n",
" [1, 2],\n",
" [3, 4],\n",
" [3, 4]])\n",
"Pattern 3 - repeat_interleave along dim 1:\n",
"tensor([[1, 1, 2, 2],\n",
" [3, 3, 4, 4]])\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# =============================================================================\n",
"# EXERCISE 9: Memory Efficiency Test\n",
"# =============================================================================\n",
"print(\"\\n\\n📝 Exercise 9: Memory Usage Comparison\")\n",
"print(\"-\" * 40)\n",
"\n",
"large_tensor = torch.randn(100, 50)\n",
"print(f\"Original tensor memory: {large_tensor.numel() * large_tensor.element_size()} bytes\")\n",
"\n",
"# Method 1: expand (memory efficient - creates view)\n",
"expanded = large_tensor.unsqueeze(0).expand(10, -1, -1)\n",
"print(f\"Expanded tensor shares memory: {expanded.storage().data_ptr() == large_tensor.storage().data_ptr()}\")\n",
"\n",
"# Method 2: repeat (creates copy)\n",
"repeated = large_tensor.repeat(10, 1)\n",
"print(f\"Repeated tensor shares memory: {repeated.storage().data_ptr() == large_tensor.storage().data_ptr()}\")\n",
"\n",
"# TODO: What happens if you modify the original tensor? Will the expanded version change too?"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "oo9lACoNTOCa",
"outputId": "3ae33886-d697-470d-9bf7-108217ffd4e4"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"\n",
"📝 Exercise 9: Memory Usage Comparison\n",
"----------------------------------------\n",
"Original tensor memory: 20000 bytes\n",
"Expanded tensor shares memory: True\n",
"Repeated tensor shares memory: False\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# =============================================================================\n",
"# EXERCISE 10: Real-World Application - Attention Mechanism\n",
"# =============================================================================\n",
"print(\"\\n\\n📝 Exercise 10: Attention Mechanism Context\")\n",
"print(\"-\" * 40)\n",
"\n",
"# Simulate a real attention scenario\n",
"batch_size = 2\n",
"seq_len = 8\n",
"num_query_heads = 12\n",
"num_kv_heads = 4  # Fewer KV heads than query heads (Grouped Query Attention)\n",
"head_dim = 64\n",
"\n",
"print(f\"Scenario: {num_query_heads} query heads, {num_kv_heads} KV heads\")\n",
"print(f\"Need to repeat each KV head {num_query_heads // num_kv_heads} times\")\n",
"\n",
"# Create mock key and value tensors\n",
"keys = torch.randn(batch_size, num_kv_heads, seq_len, head_dim)\n",
"values = torch.randn(batch_size, num_kv_heads, seq_len, head_dim)\n",
"\n",
"print(f\"Original keys shape: {keys.shape}\")\n",
"print(f\"Original values shape: {values.shape}\")\n",
"\n",
"# Apply repeat_kv to match query heads\n",
"n_rep = num_query_heads // num_kv_heads\n",
"\n",
"def repeat_kv_solution(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n",
"    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n",
"    if n_rep == 1:\n",
"        return hidden_states\n",
"    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n",
"    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n",
"\n",
"keys_repeated = repeat_kv_solution(keys, n_rep)\n",
"values_repeated = repeat_kv_solution(values, n_rep)\n",
"\n",
"print(f\"Repeated keys shape: {keys_repeated.shape}\")\n",
"print(f\"Repeated values shape: {values_repeated.shape}\")\n",
"\n",
"# Verify that keys and values now have as many heads as the queries\n",
"print(f\"Success! KV heads now match query heads: {keys_repeated.shape[1] == num_query_heads}\")\n",
"\n",
"print(\"\\n🎉 Exercises Complete!\")\n",
"print(\"Next steps:\")\n",
"print(\"1. Try modifying the dimensions and see how shapes change\")\n",
"print(\"2. Experiment with different n_rep values\")\n",
"print(\"3. Compare memory usage between expand() and repeat()\")\n",
"print(\"4. Implement your own version of repeat_kv from scratch!\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "giqEI1vBS26T",
"outputId": "a30c8518-4b2f-4393-9f79-76d8c0f854c9"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"\n",
"📝 Exercise 10: Attention Mechanism Context\n",
"----------------------------------------\n",
"Scenario: 12 query heads, 4 KV heads\n",
"Need to repeat each KV head 3 times\n",
"Original keys shape: torch.Size([2, 4, 8, 64])\n",
"Original values shape: torch.Size([2, 4, 8, 64])\n",
"Repeated keys shape: torch.Size([2, 12, 8, 64])\n",
"Repeated values shape: torch.Size([2, 12, 8, 64])\n",
"Success! KV heads now match query heads: True\n",
"\n",
"🎉 Exercises Complete!\n",
"Next steps:\n",
"1. Try modifying the dimensions and see how shapes change\n",
"2. Experiment with different n_rep values\n",
"3. Compare memory usage between expand() and repeat()\n",
"4. Implement your own version of repeat_kv from scratch!\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZHoK7wtV9nnV"
},
"source": [
" ## 5. Muon Optimizer - The Secret Sauce\n",
"\n",
" The Muon optimizer is a novel approach that takes the momentum-smoothed gradient of each weight matrix and approximately orthogonalizes it with a few Newton-Schulz iterations before applying the update. This is meant to improve training stability and convergence. The `zeropower_via_newtonschulz5` function implements this core mathematical operation."
]
},
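{
"cell_type": "markdown",
"metadata": {},
"source": [
"What one Newton-Schulz step does, written out. Each step in `zeropower_via_newtonschulz5` computes $A = XX^\\top$, $B = bA + cA^2$ and $X \\leftarrow aX + BX$. Writing the SVD as $X = U\\Sigma V^\\top$ gives $XX^\\top = U\\Sigma^2U^\\top$, so one step is\n",
"\n",
"$$X \\leftarrow aX + b\\,(XX^\\top)X + c\\,(XX^\\top)^2X = U\\left(a\\Sigma + b\\Sigma^3 + c\\Sigma^5\\right)V^\\top.$$\n",
"\n",
"The singular vectors $U, V$ stay fixed, and every singular value is mapped through the same quintic $\\sigma \\mapsto a\\sigma + b\\sigma^3 + c\\sigma^5$ with $(a, b, c) = (3.4445, -4.7750, 2.0315)$. Dividing by the Frobenius norm first guarantees $\\sigma \\le 1$, and the large slope $a$ near zero quickly inflates small singular values. After a few steps, all singular values land in a band around 1 (not exactly 1, since $a + b + c \\approx 0.70$), so the result approximates $UV^\\top$, the \"zeroth power\" of $G$."
]
},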
{
"cell_type": "markdown",
"source": [
"To really *understand* the **Muon optimizer** (🔥 one of the best new optimizers), I recommend checking out these tutorials:\n",
"\n",
"- 🔁 [Backpropagation From Scratch](https://youtu.be/W8g1hvW4Wic) — Understand gradients deeply\n",
"- 🧠 [Orthonormal Matrix Intuition](https://youtu.be/FbYRZpBgFz4) — Key concept behind Muon’s update step\n",
"\n",
"These will level up your optimizer intuition 💪✨"
],
"metadata": {
"id": "hH6FapirUzaV"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "P7jXXxGz9nnW"
},
"outputs": [],
"source": [
"@torch.compile\n",
"def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:\n",
"    \"\"\"Newton-Schulz iteration to compute the zeroth power / orthogonalization of G.\"\"\"\n",
"    assert G.ndim >= 2\n",
"    a, b, c = (3.4445, -4.7750, 2.0315)\n",
"    X = G.bfloat16()\n",
"\n",
"    if G.size(-2) > G.size(-1):\n",
"        X = X.mT\n",
"\n",
"    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)\n",
"\n",
"    for _ in range(steps):\n",
"        A = X @ X.mT\n",
"        B = b * A + c * A @ A\n",
"        X = a * X + B @ X\n",
"\n",
"    if G.size(-2) > G.size(-1):\n",
"        X = X.mT\n",
"\n",
"    return X\n",
"\n",
"class Muon(torch.optim.Optimizer):\n",
"    \"\"\"Muon - MomentUm Orthogonalized by Newton-Schulz\"\"\"\n",
"    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):\n",
"        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)\n",
"        super().__init__(params, defaults)\n",
"\n",
"    @torch.no_grad()\n",
"    def step(self):\n",
"        for group in self.param_groups:\n",
"            for p in group[\"params\"]:\n",
"                if p.grad is None:\n",
"                    continue\n",
"\n",
"                g = p.grad\n",
"                state = self.state[p]\n",
"\n",
"                # Initialize momentum buffer if first time\n",
"                if \"momentum_buffer\" not in state:\n",
"                    state[\"momentum_buffer\"] = torch.zeros_like(g)\n",
"\n",
"                buf = state[\"momentum_buffer\"]\n",
"                # Update momentum buffer: buf = momentum * buf + (1-momentum) * grad\n",
"                buf.lerp_(g, 1 - group[\"momentum\"])\n",
"                # Apply Nesterov momentum if enabled, otherwise use standard momentum\n",
"                g = g.lerp_(buf, group[\"momentum\"]) if group[\"nesterov\"] else buf\n",
"                # Orthogonalize the update with Newton-Schulz iterations (make it close to orthonormal)\n",
"                g = zeropower_via_newtonschulz5(g, steps=group[\"ns_steps\"])\n",
"                # Scale the step by sqrt(max(1, rows / cols)): for tall matrices (rows > cols) the effective\n",
"                # learning rate grows by sqrt(rows / cols); square and wide matrices use lr unchanged\n",
"                p.add_(g.view_as(p), alpha=-group[\"lr\"] * max(1, p.size(-2) / p.size(-1))**0.5)"
]
},
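{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, the sketch below copies the Newton-Schulz loop into a plain float32 function. `newton_schulz5_reference` is a hypothetical helper, not part of the model. It skips `torch.compile` and bfloat16 so that it runs anywhere on CPU.\n",
"\n",
"The sketch compares the result with the exact orthogonalization $UV^\\top$ from an SVD. The singular values should come out near 1 (roughly 0.5 to 1.5), and the result should point in almost the same direction as $UV^\\top$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"def newton_schulz5_reference(G: torch.Tensor, steps: int = 5) -> torch.Tensor:\n",
"    # Hypothetical float32, uncompiled copy of zeropower_via_newtonschulz5 for inspection\n",
"    a, b, c = (3.4445, -4.7750, 2.0315)\n",
"    X = G.float()\n",
"    transposed = G.size(-2) > G.size(-1)\n",
"    if transposed:\n",
"        X = X.mT\n",
"    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)\n",
"    for _ in range(steps):\n",
"        A = X @ X.mT\n",
"        B = b * A + c * A @ A\n",
"        X = a * X + B @ X\n",
"    return X.mT if transposed else X\n",
"\n",
"torch.manual_seed(0)\n",
"G = torch.randn(128, 64)  # a tall, gradient-like matrix\n",
"X = newton_schulz5_reference(G)\n",
"\n",
"# Exact orthogonalization for comparison: G = U S V^T  ->  U V^T\n",
"U, S, Vh = torch.linalg.svd(G, full_matrices=False)\n",
"exact = U @ Vh\n",
"\n",
"sv = torch.linalg.svdvals(X)\n",
"print('input singular values :', round(S.min().item(), 3), 'to', round(S.max().item(), 3))\n",
"print('NS5 singular values   :', round(sv.min().item(), 3), 'to', round(sv.max().item(), 3))\n",
"print('cosine(NS5, U V^T)    :', round(F.cosine_similarity(X.flatten(), exact.flatten(), dim=0).item(), 4))"
]
},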
{
"cell_type": "markdown",
"metadata": {
"id": "X0RdGidT9nnW"
},
"source": [
" ## 6. Data Loading and Caching\n",
"\n",
"\n",
"\n",
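" Loading and tokenizing data is slow, so we cache the result on disk and reuse it on later runs instead of reprocessing the same data. `load_and_cache_data` streams the first `config.num_documents` documents of the Cosmopedia-v2 subset of the SmolLM corpus and truncates each to 3,000 characters. It then tokenizes them with the SmolLM-135M tokenizer and keeps the first `config.max_tokens` tokens. The cache file name encodes `num_documents` and `max_tokens`, so changing either one triggers a fresh run."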
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ql6arQ7J9nnW"
},
"outputs": [],
"source": [
"def load_and_cache_data(config: ModelConfig, cache_dir: str = \"data_cache\"):\n",
"    \"\"\"Load and cache tokenized data to avoid reprocessing\"\"\"\n",
"    os.makedirs(cache_dir, exist_ok=True)\n",
"    cache_file = f\"{cache_dir}/tokenized_data_{config.num_documents}_{config.max_tokens}.pkl\"\n",
"\n",
"    # Check if cached data exists\n",
"    if os.path.exists(cache_file):\n",
"        print(f\"📦 Loading cached data from {cache_file}\")\n",
"        with open(cache_file, 'rb') as f:\n",
"            cached_data = pickle.load(f)\n",
"\n",
"        texts = cached_data['texts']\n",
"        tokenizer = cached_data['tokenizer']\n",
"        tokens = cached_data['tokens']\n",
"        config.vocab_size = tokenizer.vocab_size\n",
"\n",
"        print(f\"✅ Loaded {len(texts)} documents, {len(tokens):,} tokens from cache\")\n",
"        return texts, tokenizer, tokens\n",
"\n",
"    print(f\"🔄 Processing new data (will cache for future use)\")\n",
"\n",
"    # Load tokenizer\n",
"    tokenizer = AutoTokenizer.from_pretrained(\"HuggingFaceTB/SmolLM-135M\")\n",
"    if tokenizer.pad_token is None:\n",
"        tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
"    # Load dataset (streaming, so only the documents we need are downloaded)\n",
"    dataset = load_dataset(\"HuggingFaceTB/smollm-corpus\", \"cosmopedia-v2\", split=\"train\", streaming=True)\n",
"\n",
"    texts = []\n",
"    for i, item in enumerate(dataset):\n",
"        if i >= config.num_documents:\n",
"            break\n",
"        texts.append(item[\"text\"][:3000])  # keep the first 3,000 characters of each document\n",
"\n",
"    print(f\"Loaded {len(texts)} documents\")\n",
"\n",
"    # Tokenize\n",
"    print(\"Tokenizing texts...\")\n",
"    all_tokens = []\n",
"    for text in tqdm(texts, desc=\"Tokenizing\"):\n",
"        tokens = tokenizer.encode(text, add_special_tokens=False)\n",
"        all_tokens.extend(tokens)\n",
"\n",
"    tokens = all_tokens[:config.max_tokens]\n",
"    print(f\"Using {len(tokens):,} tokens\")\n",
"    config.vocab_size = tokenizer.vocab_size\n",
"\n",
"    # Cache the processed data\n",
"    cached_data = {'texts': texts, 'tokenizer': tokenizer, 'tokens': tokens}\n",
"    with open(cache_file, 'wb') as f:\n",
"        pickle.dump(cached_data, f)\n",
"\n",
"    print(f\"💾 Cached data to {cache_file}\")\n",
"    return texts, tokenizer, tokens\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rBsAbVgW9nnW"
},
"source": [
" ## 7. Dataset Class\n",
"\n",
"\n",
"\n",
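" We create a custom dataset class for language modeling. It slides a window of `seq_len` tokens over the token stream with a stride of one token. Each sample is an input sequence `x` and a target `y` taken from the window that starts one token later, so `y[i]` is the token that follows `x[i]`. Consecutive samples therefore overlap by `seq_len - 1` tokens."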
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ccgG3UnR9nnW"
},
"outputs": [],
"source": [
"class TextTokenDataset(Dataset):\n",
"    def __init__(self, tokens: List[int], seq_len: int = 512):\n",
"        self.tokens = tokens\n",
"        self.seq_len = seq_len\n",
"\n",
"    def __len__(self):\n",
"        # One window per start position (stride 1); each window needs seq_len + 1 tokens\n",
"        return max(0, len(self.tokens) - self.seq_len)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        # The target y is the input x shifted left by one token\n",
"        x = torch.tensor(self.tokens[idx:idx + self.seq_len], dtype=torch.long)\n",
"        y = torch.tensor(self.tokens[idx + 1:idx + self.seq_len + 1], dtype=torch.long)\n",
"        return x, y\n"
]
},
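{
"cell_type": "markdown",
"metadata": {},
"source": [
"A tiny worked example makes the windowing concrete. It assumes the `TextTokenDataset` cell above has been run, and the toy token list is made up for illustration. With 10 tokens and `seq_len=4` there are 10 - 4 = 6 windows, and each target is its input shifted by one token."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"# Assumes TextTokenDataset from the cell above is defined\n",
"toy_tokens = list(range(10))  # stand-in token IDs 0..9\n",
"ds = TextTokenDataset(toy_tokens, seq_len=4)\n",
"\n",
"print(len(ds))                 # 6\n",
"x, y = ds[0]\n",
"print(x.tolist(), y.tolist())  # [0, 1, 2, 3] [1, 2, 3, 4]\n",
"x, y = ds[5]                   # last valid window\n",
"print(x.tolist(), y.tolist())  # [5, 6, 7, 8] [6, 7, 8, 9]\n",
"\n",
"xb, yb = next(iter(DataLoader(ds, batch_size=2)))\n",
"print(xb.shape, yb.shape)      # torch.Size([2, 4]) torch.Size([2, 4])"
]
},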
{
"cell_type": "markdown",
"source": [
"💬 [ChatGPT explanation and exercise](https://chatgpt.com/share/689455f2-4598-8002-9ac2-f8c3913ecad7)"
],
"metadata": {
"id": "CkBX_TX4b-Ra"
}
},
{
"cell_type": "markdown",
"metadata": {
"id": "9hXTnH-z9nnX"
},
"source": [
" ## 8. Rotary Position Embeddings (RoPE)\n",
"\n",
"\n",
"\n",
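" RoPE replaces absolute positional embeddings that are added to the token embeddings. Instead, it rotates pairs of features in every query and key vector by angles proportional to the token's position. As a result, the query-key dot product depends only on the relative distance between the two tokens. This relative formulation tends to extend to longer sequences more gracefully than learned absolute positions, although going far beyond the training length usually still needs extra tricks such as frequency scaling. In this implementation, half of the rotation pairs use frequency 0 and are left unrotated."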
]
},
{
"cell_type": "markdown",
"source": [
"My videos explaining RoPE (first is most important)\n",
"\n",
"- 📌 **[Rotary Positional Embeddings & Rotation Matrix + Python LLM Code](https://youtu.be/wiJ-OU-URYg)**\n",
"- 🧠 **[Get SMARTER Than 99% of AI Researchers](https://youtu.be/X0JryI85hL0)** - Beginning part\n",
"- 🛠️ **[RoPE In DeepSeek V3 – Code Step by Step](https://youtu.be/Rs9tLDSMUkM)**\n"
],
"metadata": {
"id": "2Rm57VsNeH-q"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WuuNzqYb9nnX"
},
"outputs": [],
"source": [
"class Rotary(nn.Module):\n",
"    def __init__(self, dim: int, max_seq_len: int):\n",
"        super().__init__()\n",
"        # dim // 4 geometrically spaced frequencies from 1 down to 1/10000; the remaining dim // 4\n",
"        # rotation pairs get frequency 0 (no rotation), for dim // 2 frequencies in total\n",
"        angular_freq = (1 / 10000) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)\n",
"        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)])\n",
"        t = torch.arange(max_seq_len, dtype=torch.float32)\n",
"        theta = torch.einsum(\"i,j -> ij\", t, angular_freq)  # (max_seq_len, dim // 2)\n",
"        self.register_buffer('cos', theta.cos(), persistent=False)\n",
"        self.register_buffer('sin', theta.sin(), persistent=False)\n",
"\n",
"    def forward(self, x_BTHD: torch.Tensor):\n",
"        # Expects (batch, seq_len, n_heads, head_dim): dimension -3 must be the sequence axis\n",
"        assert self.cos.size(0) >= x_BTHD.size(-3)\n",
"        cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]\n",
"        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)\n",
"        y1 = x1 * cos + x2 * sin\n",
"        y2 = x1 * (-sin) + x2 * cos\n",
"        return torch.cat((y1, y2), 3).type_as(x_BTHD)\n"
]
},
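{
"cell_type": "markdown",
"metadata": {},
"source": [
"The key property of RoPE is that, after rotation, a query-key dot product depends only on the offset between the two positions. The check below assumes the `Rotary` cell above has been run; the sizes are arbitrary.\n",
"\n",
"It places the same query and key vectors at every position, then compares two pairs that are both 3 tokens apart. Note that `Rotary` expects input shaped `(batch, seq_len, n_heads, head_dim)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Assumes the Rotary class from the cell above is defined\n",
"torch.manual_seed(0)\n",
"d_k, T = 48, 16\n",
"rope = Rotary(d_k, max_seq_len=T)\n",
"\n",
"q = torch.randn(d_k)\n",
"k = torch.randn(d_k)\n",
"# Place the same vector at every position: shape (batch=1, T, heads=1, d_k)\n",
"q_rot = rope(q.expand(1, T, 1, d_k))[0, :, 0]  # (T, d_k)\n",
"k_rot = rope(k.expand(1, T, 1, d_k))[0, :, 0]\n",
"\n",
"s1 = q_rot[5] @ k_rot[2]   # positions 5 and 2: offset 3\n",
"s2 = q_rot[12] @ k_rot[9]  # positions 12 and 9: offset 3\n",
"s3 = q_rot[12] @ k_rot[2]  # offset 10, generally a different score\n",
"print(s1.item(), s2.item(), s3.item())\n",
"print(torch.allclose(s1, s2, atol=1e-4))  # True"
]
},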
{
"cell_type": "markdown",
"source": [
"[Exercises with ChatGPT](https://chatgpt.com/share/68945a01-8d48-8002-8cf0-04b7f6db744b)"
],
"metadata": {
"id": "zZKOKGxpgEKl"
}
},
{
"cell_type": "markdown",
"metadata": {
"id": "uCJaNEH49nnX"
},
"source": [
" ## 9. Grouped-Query Attention Implementation\n",
"\n",
"\n",
"\n",
" This is the heart of our model - the attention mechanism with GQA. Notice how we:\n",
"\n",
" 1. Project Q, K, V separately\n",
"\n",
" 2. Apply QK normalization (QK-Norm, which Qwen3 adopts for training stability)\n",
"\n",
" 3. Use RoPE for positional information\n",
"\n",
" 4. Implement GQA by repeating K and V heads\n",
"\n",
" 5. Use scaled dot-product attention"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7gASpg2l9nnX"
},
"outputs": [],
"source": [
"class Qwen3Attention(nn.Module):\n",
"    def __init__(self, config: ModelConfig):\n",
"        super().__init__()\n",
"        self.d_model = config.d_model\n",
"        self.n_heads = config.n_heads\n",
"        self.n_kv_heads = config.n_kv_heads\n",
"        self.n_kv_groups = config.n_kv_groups\n",
"        self.d_k = config.d_k\n",
"\n",
"        # Separate linear layers for Q, K, V\n",
"        self.q_proj = nn.Linear(self.d_model, self.n_heads * self.d_k, bias=config.attention_bias)\n",
"        self.k_proj = nn.Linear(self.d_model, self.n_kv_heads * self.d_k, bias=config.attention_bias)\n",
"        self.v_proj = nn.Linear(self.d_model, self.n_kv_heads * self.d_k, bias=config.attention_bias)\n",
"        self.w_o = nn.Linear(self.d_model, self.d_model, bias=False)\n",
"\n",
"        # QK-Normalization layers\n",
"        # Practice RMSNorm one-on-one with ChatGPT - https://chatgpt.com/share/68945c86-2dd4-8002-b017-725caab0c107\n",
"        self.q_norm = nn.RMSNorm(self.d_k, eps=config.rms_norm_eps)\n",
"        self.k_norm = nn.RMSNorm(self.d_k, eps=config.rms_norm_eps)\n",
"\n",
"        self.rotary = Rotary(self.d_k, config.max_seq_len)\n",
"        self.dropout = config.dropout\n",
"\n",
"    def forward(self, x):\n",
"        batch_size, seq_len = x.size(0), x.size(1)\n",
"\n",
"        # 1. Project Q, K, V separately\n",
"        q = self.q_proj(x)\n",
"        k = self.k_proj(x)\n",
"        v = self.v_proj(x)\n",
"\n",
"        # 2. Reshape into heads: (batch, seq_len, n_heads, d_k)\n",
"        q = q.view(batch_size, seq_len, self.n_heads, self.d_k)\n",
"        k = k.view(batch_size, seq_len, self.n_kv_heads, self.d_k)\n",
"        v = v.view(batch_size, seq_len, self.n_kv_heads, self.d_k)\n",
"\n",
"        # 3. Apply QK-Norm (RMSNorm over each head's d_k features)\n",
"        q = self.q_norm(q)\n",
"        k = self.k_norm(k)\n",
"\n",
"        # 4. Apply RoPE. Rotary expects (batch, seq_len, n_heads, d_k), which is already the layout here.\n",
"        #    Permuting to (batch, n_heads, seq_len, d_k) first would rotate by head index instead of position.\n",
"        q = self.rotary(q)\n",
"        k = self.rotary(k)\n",
"\n",
"        # Transpose for attention: (batch, seq_len, n_heads, d_k) -> (batch, n_heads, seq_len, d_k)\n",
"        Q = q.transpose(1, 2)\n",
"        K = k.transpose(1, 2)\n",
"        V = v.transpose(1, 2)\n",
"\n",
"        # 5. Repeat K and V heads for GQA\n",
"        K = repeat_kv(K, self.n_kv_groups)\n",
"        V = repeat_kv(V, self.n_kv_groups)\n",
"\n",
"        # 6. Scaled Dot-Product Attention\n",
"        attn_output = F.scaled_dot_product_attention(\n",
"            Q, K, V, is_causal=True, dropout_p=self.dropout if self.training else 0.0\n",
"        )\n",
"\n",
"        # 7. Reshape and final projection\n",
"        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)\n",
"        return self.w_o(attn_output)\n"
]
},
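{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a quick shape and causality check for `Qwen3Attention`. The config below is a hypothetical small stand-in built with `SimpleNamespace`, holding only the fields the class reads. The check assumes the `repeat_kv`, `Rotary`, and `Qwen3Attention` cells have been run.\n",
"\n",
"Changing only the last token must leave the outputs at all earlier positions unchanged, because the causal mask stops earlier queries from attending to later keys."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from types import SimpleNamespace\n",
"\n",
"# Hypothetical small config: 8 query heads share 2 KV heads (4 query heads per KV head)\n",
"cfg = SimpleNamespace(d_model=64, n_heads=8, n_kv_heads=2, n_kv_groups=4, d_k=8,\n",
"                      attention_bias=False, rms_norm_eps=1e-6, max_seq_len=32, dropout=0.0)\n",
"attn = Qwen3Attention(cfg).eval()\n",
"\n",
"torch.manual_seed(0)\n",
"x = torch.randn(2, 10, cfg.d_model)\n",
"with torch.no_grad():\n",
"    out = attn(x)\n",
"    x_edit = x.clone()\n",
"    x_edit[:, -1] += 1.0  # perturb only the last token\n",
"    out_edit = attn(x_edit)\n",
"\n",
"print(out.shape)  # torch.Size([2, 10, 64])\n",
"# Earlier positions cannot see the last token, so their outputs are unchanged\n",
"print(torch.allclose(out[:, :-1], out_edit[:, :-1], atol=1e-5))  # True\n",
"print(torch.allclose(out[:, -1], out_edit[:, -1], atol=1e-5))    # False: the last position did change"
]
},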
{
"cell_type": "markdown",
"metadata": {
"id": "SKYp-oKf9nnX"
},
"source": [
" ## 10. SwiGLU Feed-Forward Network\n",
"\n",
"\n",
"\n",
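" SwiGLU is a gated feed-forward layer. One projection of the input is passed through Swish (SiLU) and multiplied element-wise with a second projection, then the result is projected back down to `d_model`. It generally performs better than a plain ReLU feed-forward layer of similar size and is used in many modern models, including Qwen3."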
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "05g5swnW9nnX"
},
"outputs": [],
"source": [
"class SwiGLUFeedForward(nn.Module):\n",
"    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):\n",
"        super().__init__()\n",
"        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)\n",
"        self.down_proj = nn.Linear(d_ff, d_model, bias=False)\n",
"        self.up_proj = nn.Linear(d_model, d_ff, bias=False)\n",
"        self.dropout = nn.Dropout(dropout)\n",
"\n",
"    def forward(self, x):\n",
"        # Implementation of the SwiGLU activation function\n",
"        # F.silu is the Swish activation function\n",
"        activated_x = F.silu(self.gate_proj(x)) * self.up_proj(x)\n",
"        return self.down_proj(self.dropout(activated_x))\n"
]
},
{
"cell_type": "markdown",
"source": [
"Think of:\n",
"\n",
"`output = gate(x) * value(x)`\n",
"\n",
"like:\n",
"\n",
"`light = brightness_control × light_source`\n"
],
"metadata": {
"id": "Fut7QDnDpT1H"
}
},
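{
"cell_type": "markdown",
"metadata": {},
"source": [
"The analogy maps directly onto the code: `F.silu(gate_proj(x))` is the brightness control and `up_proj(x)` is the light source. The sketch below uses random stand-in weights and made-up sizes. It checks that `F.silu(z)` equals the Swish formula `z * sigmoid(z)` and counts the weights.\n",
"\n",
"Because SwiGLU uses three weight matrices instead of two, a SwiGLU layer with hidden size `d_ff` has `3 * d_model * d_ff` weights."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"d_model, d_ff = 8, 32  # made-up sizes for illustration\n",
"x = torch.randn(3, d_model)\n",
"W_gate = torch.randn(d_ff, d_model) * 0.1\n",
"W_up = torch.randn(d_ff, d_model) * 0.1\n",
"W_down = torch.randn(d_model, d_ff) * 0.1\n",
"\n",
"gate = x @ W_gate.T  # 'brightness control'\n",
"up = x @ W_up.T      # 'light source'\n",
"manual = (gate * torch.sigmoid(gate)) * up  # Swish(z) = z * sigmoid(z)\n",
"builtin = F.silu(gate) * up\n",
"out = builtin @ W_down.T\n",
"\n",
"print(torch.allclose(manual, builtin, atol=1e-6))  # True\n",
"print(out.shape)                                   # torch.Size([3, 8])\n",
"print(W_gate.numel() + W_up.numel() + W_down.numel(), 3 * d_model * d_ff)  # 768 768"
]
},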
{
"cell_type": "markdown",
"metadata": {
"id": "cJ0YpmSk9nnX"
},
"source": [
" ## 11. Transformer Block\n",
"\n",
"\n",
"\n",
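" Each transformer block combines attention and feed-forward layers with residual connections and pre-normalization. RMSNorm is applied to the input of each sublayer, and the sublayer's output is added back to the residual stream. RMSNorm is cheaper than LayerNorm because it skips mean subtraction and has no bias, and the pre-norm layout helps keep training stable as the network gets deeper."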
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dxaNJjGQ9nnX"
},
"outputs": [],
"source": [
"class TransformerBlock(nn.Module):\n",
"    def __init__(self, config: ModelConfig):  # Pass the entire config object\n",
"        super().__init__()\n",
"        self.attention = Qwen3Attention(config)\n",
"        self.feed_forward = SwiGLUFeedForward(config.d_model, config.d_ff, config.dropout)\n",
"        self.norm1 = nn.RMSNorm(config.d_model, eps=config.rms_norm_eps)\n",
"        self.norm2 = nn.RMSNorm(config.d_model, eps=config.rms_norm_eps)\n",
"        self.dropout = nn.Dropout(config.dropout)\n",
"\n",
"    def forward(self, x):\n",
"        # Pre-norm residual: normalize each sublayer's input, then add its output back to x\n",
"        attn_out = self.attention(self.norm1(x))\n",
"        x = x + self.dropout(attn_out)\n",
"        ff_out = self.feed_forward(self.norm2(x))\n",
"        x = x + self.dropout(ff_out)\n",
"        return x\n"
]
},
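{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to see the residual design: if both sublayers output zero, the block is exactly the identity, and the input flows through the residual stream untouched. The check below builds a block from a hypothetical small `SimpleNamespace` config and zeroes the last projection of each sublayer. It assumes the earlier model cells have been run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from types import SimpleNamespace\n",
"\n",
"# Hypothetical small config with the fields TransformerBlock and Qwen3Attention read\n",
"cfg = SimpleNamespace(d_model=64, n_heads=8, n_kv_heads=2, n_kv_groups=4, d_k=8, d_ff=256,\n",
"                      attention_bias=False, rms_norm_eps=1e-6, max_seq_len=32, dropout=0.0)\n",
"block = TransformerBlock(cfg).eval()\n",
"\n",
"with torch.no_grad():\n",
"    # Zero the last projection of each sublayer so both sublayers output 0\n",
"    block.attention.w_o.weight.zero_()\n",
"    block.feed_forward.down_proj.weight.zero_()\n",
"    x = torch.randn(2, 10, cfg.d_model)\n",
"    y = block(x)\n",
"\n",
"print(torch.equal(x, y))  # True: x + 0 + 0 == x"
]
},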
{
"cell_type": "markdown",
"metadata": {
"id": "Hg9Sk2mV9nnX"
},
"source": [
" ## 12. Complete Language Model\n",
"\n",
"\n",
"\n",
" Now we assemble everything into our complete language model. This includes:\n",
"\n",
" - Token embeddings\n",
"\n",
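" - Dropout on the scaled token embeddings (there is no positional embedding table; positions are encoded by RoPE inside attention)\n",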
"\n",
" - Stack of transformer blocks\n",
"\n",
" - Final normalization and output projection\n",
"\n",
" - Weight tying between input embeddings and output layer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TacWbKLz9nnX"
},
"outputs": [],
"source": [
"class MinimalLLM(nn.Module):\n",
"    def __init__(self, config: ModelConfig):\n",
"        super().__init__()\n",
"        self.config = config\n",
"\n",
"        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)\n",
"        self.position_dropout = nn.Dropout(config.dropout)\n",
"\n",
"        self.transformer_blocks = nn.ModuleList([\n",
"            TransformerBlock(config) for _ in range(config.n_layers)\n",
"        ])\n",
"\n",
"        self.norm = nn.RMSNorm(config.d_model, eps=config.rms_norm_eps)\n",
"        self.output_dropout = nn.Dropout(config.dropout)\n",
"\n",
"        # Tie weights: the output layer (`lm_head`) reuses the input token embedding matrix, so input and\n",
"        # output share parameters. This saves vocab_size * d_model parameters and often helps generalization.\n",
"        # https://chatgpt.com/share/6894683e-ba44-8002-ae82-e42b4afc9d98\n",
"        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)\n",
"        self.lm_head.weight = self.token_embedding.weight\n",
"\n",
"        self.apply(self._init_weights)\n",
"\n",
"    def _init_weights(self, module):\n",
"        if isinstance(module, nn.Linear):\n",
"            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
"            if module.bias is not None:\n",
"                torch.nn.init.zeros_(module.bias)\n",
"        elif isinstance(module, nn.Embedding):\n",
"            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
"\n",
"    def forward(self, x):\n",
"        x = self.token_embedding(x) * math.sqrt(self.config.d_model)\n",
"        x = self.position_dropout(x)\n",
"\n",
"        for block in self.transformer_blocks:\n",
"            x = block(x)\n",
"\n",
"        x = self.norm(x)\n",
"        x = self.output_dropout(x)\n",
"        logits = self.lm_head(x)\n",
"        return logits\n"
]
},
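{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a small check of weight tying, using a hypothetical tiny config (it assumes the model cells above have been run). `lm_head.weight` is the same `Parameter` object as `token_embedding.weight`. As a result, `model.parameters()` yields it only once, so the tied matrix is counted, and optimized, a single time."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from types import SimpleNamespace\n",
"\n",
"# Hypothetical tiny config; field names match those MinimalLLM and its submodules read\n",
"cfg = SimpleNamespace(vocab_size=1000, d_model=64, n_layers=2, n_heads=8, n_kv_heads=2, n_kv_groups=4,\n",
"                      d_k=8, d_ff=256, attention_bias=False, rms_norm_eps=1e-6, max_seq_len=32, dropout=0.0)\n",
"model = MinimalLLM(cfg)\n",
"\n",
"print(model.lm_head.weight is model.token_embedding.weight)  # True\n",
"n_params = sum(p.numel() for p in model.parameters())\n",
"print(n_params)  # 183136: the 1000 x 64 embedding is counted once, not twice\n",
"\n",
"logits = model(torch.randint(0, cfg.vocab_size, (2, 10)))\n",
"print(logits.shape)  # torch.Size([2, 10, 1000])"
]
},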
{
"cell_type": "markdown",
"metadata": {
"id": "_w5nsjPZ9nnX"
},
"source": [
" ## 13. Evaluation Function\n",
"\n",
"\n",
"\n",
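" During training, we need to evaluate our model's performance. This function runs the model on up to `config.eval_steps` validation batches. It reports the token-weighted average cross-entropy loss, next-token accuracy, and perplexity (the exponential of the loss, with the loss capped at 20 to avoid overflow)."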
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "p6Ov0uO29nnY"
},
"outputs": [],
"source": [
"def evaluate_model(model: nn.Module, val_loader: DataLoader, config: ModelConfig):\n",
"    \"\"\"Evaluate model performance\"\"\"\n",
"    model.eval()\n",
"    total_loss = 0\n",
"    total_tokens = 0\n",
"    total_correct = 0\n",
"\n",
"    device = next(model.parameters()).device\n",
"\n",
"    with torch.no_grad():  # Disable gradient computation for evaluation (saves memory and computation)\n",
"        for i, (x, y) in enumerate(val_loader):\n",
"            # Stop evaluation after specified number of steps to limit eval time\n",
"            if i >= config.eval_steps:\n",
"                break\n",
"\n",
"            # Move input sequences (x) and target sequences (y) to GPU/device\n",
"            x, y = x.to(device), y.to(device)\n",
"\n",
"            # Use automatic mixed precision if enabled (faster evaluation with minimal accuracy loss)\n",
"            with autocast(enabled=config.use_amp):\n",
"                # Forward pass: get model predictions (logits) for input sequence\n",
"                logits = model(x)\n",
"\n",
"                # Calculate cross-entropy loss between predictions and targets\n",
"                # Reshape to (batch_size * seq_len, vocab_size) and (batch_size * seq_len,)\n",
"                # for proper cross-entropy computation across all token positions\n",
"                loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))\n",
"\n",
"            # Accumulate total loss weighted by number of tokens in this batch\n",
"            total_loss += loss.item() * y.numel()\n",
"            # Keep track of total number of tokens processed\n",
"            total_tokens += y.numel()\n",
"\n",
"            # Get predicted token IDs by taking argmax over vocabulary dimension\n",
"            predictions = logits.argmax(dim=-1)\n",
"            # Count correct predictions for accuracy calculation\n",
"            total_correct += (predictions == y).sum().item()\n",
"\n",
"    avg_loss = total_loss / total_tokens\n",
"    accuracy = total_correct / total_tokens\n",
"    perplexity = math.exp(min(avg_loss, 20))  # cap the loss to avoid overflow\n",
"\n",
"    model.train()\n",
"    return {'val_loss': avg_loss, 'val_accuracy': accuracy, 'val_perplexity': perplexity}\n"
]
},
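{
"cell_type": "markdown",
"metadata": {},
"source": [
"A useful reference point for these metrics: a model that predicts a uniform distribution over a vocabulary of size V has cross-entropy ln V and perplexity exactly V. The sketch below assumes V = 49,152 as the SmolLM-135M tokenizer's vocabulary size. Then ln V is about 10.80, in line with the loss of about 10.80 logged at the very start of training below.\n",
"\n",
"The second half of the sketch uses made-up numbers to show why `evaluate_model` weights each batch's loss by its token count."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"vocab_size = 49152  # assumed tokenizer vocabulary size\n",
"logits = torch.zeros(4, vocab_size)  # uniform prediction for 4 tokens\n",
"targets = torch.randint(0, vocab_size, (4,))\n",
"loss = F.cross_entropy(logits, targets).item()\n",
"print(round(loss, 4), round(math.log(vocab_size), 4))  # 10.8027 10.8027\n",
"print(round(math.exp(min(loss, 20))))                   # 49152\n",
"\n",
"# Token-weighted average, as in evaluate_model: batches with more tokens count more\n",
"batch_losses, batch_tokens = [2.0, 3.0], [100, 300]  # made-up numbers\n",
"avg = sum(loss_i * n_i for loss_i, n_i in zip(batch_losses, batch_tokens)) / sum(batch_tokens)\n",
"print(avg, round(math.exp(avg), 2))  # 2.75 15.64"
]
},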
{
"cell_type": "markdown",
"metadata": {
"id": "4qjZRIpd9nnY"
},
"source": [
" ## 14. Optimizer Setup\n",
"\n",
"\n",
"\n",
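" We use a hybrid approach: the Muon optimizer handles the 2D weight matrices (the attention and feed-forward projections), and AdamW handles everything else. That means the token embedding (tied to the output layer), the 1D RMSNorm weights, and any biases. Muon is designed for hidden-layer weight matrices, while embeddings and normalization parameters are conventionally left to AdamW, which here runs at one tenth of the Muon learning rate."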
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "o3gfmjW29nnY"
},
"outputs": [],
"source": [
"def setup_muon_optimizer(model: nn.Module, config: ModelConfig):\n",
"    \"\"\"Setup Muon optimizer with hybrid approach\"\"\"\n",
"    muon_params = []\n",
"    adamw_params = []\n",
"\n",
"    for name, param in model.named_parameters():\n",
"        if (param.ndim == 2 and\n",
"            'token_embedding' not in name and\n",
"            'norm' not in name and\n",
"            param.requires_grad):\n",
"            muon_params.append(param)\n",
"        else:\n",
"            adamw_params.append(param)\n",
"\n",
"    print(f\"  Muon parameters: {sum(p.numel() for p in muon_params):,}\")\n",
"    print(f\"  AdamW parameters: {sum(p.numel() for p in adamw_params):,}\")\n",
"\n",
"    muon_optimizer = Muon(muon_params, lr=config.muon_lr, momentum=0.95)\n",
"    adamw_optimizer = torch.optim.AdamW(adamw_params, lr=config.muon_lr*0.1, weight_decay=config.weight_decay)\n",
"\n",
"    return [muon_optimizer, adamw_optimizer]\n"
]
},
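{
"cell_type": "markdown",
"metadata": {},
"source": [
"The name-based split is easy to check on a toy module. The `Toy` module and `split_params` helper below are hypothetical and not part of the model, but they apply the same filter as `setup_muon_optimizer`.\n",
"\n",
"One subtlety: with tied weights, `named_parameters()` reports the shared matrix only under the name where it was first registered. `MinimalLLM` creates `token_embedding` before `lm_head`, so the tied matrix is listed as `token_embedding.weight` and goes to AdamW. If the order were reversed, it would appear as `lm_head.weight` and land in the Muon group."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"\n",
"class Toy(nn.Module):\n",
"    # Hypothetical stand-in for MinimalLLM's parameter layout\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.token_embedding = nn.Embedding(100, 16)\n",
"        self.proj = nn.Linear(16, 16, bias=True)\n",
"        self.norm = nn.RMSNorm(16)\n",
"        self.lm_head = nn.Linear(16, 100, bias=False)\n",
"        self.lm_head.weight = self.token_embedding.weight  # tied, registered after the embedding\n",
"\n",
"def split_params(model):\n",
"    # Same rule as setup_muon_optimizer, but returns names for inspection\n",
"    muon, adamw = [], []\n",
"    for name, p in model.named_parameters():\n",
"        use_muon = (p.ndim == 2 and 'token_embedding' not in name\n",
"                    and 'norm' not in name and p.requires_grad)\n",
"        (muon if use_muon else adamw).append(name)\n",
"    return muon, adamw\n",
"\n",
"muon_names, adamw_names = split_params(Toy())\n",
"print(muon_names)   # ['proj.weight']\n",
"print(adamw_names)  # ['token_embedding.weight', 'proj.bias', 'norm.weight']"
]
},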
{
"cell_type": "markdown",
"metadata": {
"id": "0Y_5AhJM9nnY"
},
"source": [
" ## 15. Training Loop\n",
"\n",
"\n",
"\n",
" This is where the magic happens! Our training loop includes:\n",
"\n",
" - Gradient accumulation for larger effective batch sizes\n",
"\n",
" - Mixed precision training for speed\n",
"\n",
" - Learning rate scheduling with warmup and cosine decay\n",
"\n",
" - Regular evaluation and model checkpointing\n",
"\n",
" - Progress tracking with detailed metrics"
]
},
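{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why divide the loss by `gradient_accumulation_steps`? Gradients add up across `backward()` calls. Scaling each micro-batch loss by 1/N therefore makes the accumulated gradient equal the gradient of the mean loss over the combined batch, provided the micro-batches are the same size. Here is a minimal check on a made-up least-squares problem."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"w = torch.randn(5, requires_grad=True)\n",
"x, y = torch.randn(8, 5), torch.randn(8)\n",
"\n",
"# Gradient of the mean loss over the full batch of 8\n",
"((x @ w - y) ** 2).mean().backward()\n",
"full_grad = w.grad.clone()\n",
"w.grad = None\n",
"\n",
"# Same batch as 4 micro-batches of 2, each loss divided by 4; backward() accumulates into w.grad\n",
"accum = 4\n",
"for xb, yb in zip(x.chunk(accum), y.chunk(accum)):\n",
"    (((xb @ w - yb) ** 2).mean() / accum).backward()\n",
"\n",
"print(torch.allclose(full_grad, w.grad, atol=1e-6))  # True"
]
},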
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RT4qQBd_9nnY"
},
"outputs": [],
"source": [
"def train_model(config: ModelConfig, train_loader: DataLoader, val_loader: DataLoader):\n",
"    \"\"\"Train the model with Muon optimizer\"\"\"\n",
"    print(f\"\\n🚀 Training Small model with Muon optimizer\")\n",
"\n",
"    # Initialize model\n",
"    set_seed(42)\n",
"    model = MinimalLLM(config)\n",
"    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"    model = model.to(device)\n",
"\n",
"    total_params = sum(p.numel() for p in model.parameters())\n",
"    print(f\"  📊 Total parameters: {total_params:,}\")\n",
"\n",
"    # Setup optimizers\n",
"    optimizers = setup_muon_optimizer(model, config)\n",
"\n",
"    # Learning rate schedule: linear warmup, then cosine decay to 10% of the peak LR.\n",
"    # Schedulers step once per optimizer update (every gradient_accumulation_steps micro-batches),\n",
"    # so the schedule length is measured in optimizer updates, not micro-batches.\n",
"    total_updates = max(1, config.max_steps // config.gradient_accumulation_steps)\n",
"    warmup_steps = max(1, total_updates // 20)\n",
"\n",
"    def lr_lambda(update):\n",
"        if update < warmup_steps:\n",
"            return update / warmup_steps\n",
"        progress = min(1.0, (update - warmup_steps) / max(1, total_updates - warmup_steps))\n",
"        return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress))\n",
"\n",
"    schedulers = [torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) for optimizer in optimizers]\n",
"\n",
"    scaler = GradScaler() if config.use_amp else None\n",
"\n",
"    # Training loop\n",
"    model.train()\n",
"    step = 0\n",
"    start_time = time.time()\n",
"    best_val_loss = float('inf')\n",
"\n",
"    pbar = tqdm(total=config.max_steps, desc=\"Training\")\n",
"\n",
"    while step < config.max_steps:\n",
"        for batch_idx, (x, y) in enumerate(train_loader):\n",
"            if step >= config.max_steps:\n",
"                break\n",
"\n",
"            x, y = x.to(device), y.to(device)\n",
"\n",
"            # Forward pass with gradient accumulation\n",
"            if config.use_amp:\n",
"                with autocast():\n",
"                    logits = model(x)\n",
"                    loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))\n",
"                    loss = loss / config.gradient_accumulation_steps\n",
"                scaler.scale(loss).backward()\n",
"            else:\n",
"                logits = model(x)\n",
"                loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))\n",
"                loss = loss / config.gradient_accumulation_steps\n",
"                loss.backward()\n",
"\n",
"            # Optimizer step after accumulation\n",
"            if (step + 1) % config.gradient_accumulation_steps == 0:\n",
"                if config.use_amp:\n",
"                    for optimizer in optimizers:\n",
"                        scaler.unscale_(optimizer)\n",
"                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)\n",
"\n",
"                    for optimizer in optimizers:\n",
"                        scaler.step(optimizer)\n",
"                        optimizer.zero_grad()\n",
"                    for scheduler in schedulers:\n",
"                        scheduler.step()\n",
"                    scaler.update()\n",
"                else:\n",
"                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)\n",
"                    for optimizer in optimizers:\n",
"                        optimizer.step()\n",
"                        optimizer.zero_grad()\n",
"                    for scheduler in schedulers:\n",
"                        scheduler.step()\n",
"\n",
"            # Logging\n",
"            if step % 10 == 0:\n",
"                with torch.no_grad():\n",
"                    predictions = logits.argmax(dim=-1)\n",
"                    accuracy = (predictions == y).float().mean().item()\n",
"                    current_loss = loss.item() * config.gradient_accumulation_steps\n",
"                    perplexity = math.exp(min(current_loss, 20))\n",
"\n",
"                pbar.set_postfix({\n",
"                    'loss': f'{current_loss:.4f}',\n",
"                    'acc': f'{accuracy:.3f}',\n",
"                    'ppl': f'{perplexity:.1f}',\n",
"                    'lr': f'{optimizers[0].param_groups[0][\"lr\"]:.2e}'\n",
"                })\n",
"\n",
"            # Evaluation\n",
"            if step % config.eval_every == 0 and step > 0:\n",
"                eval_metrics = evaluate_model(model, val_loader, config)\n",
"                print(f\"\\nStep {step}: Val Loss: {eval_metrics['val_loss']:.4f}, \"\n",
"                      f\"Val Acc: {eval_metrics['val_accuracy']:.4f}, \"\n",
"                      f\"Val PPL: {eval_metrics['val_perplexity']:.2f}\")\n",
"\n",
"                if eval_metrics['val_loss'] < best_val_loss:\n",
"                    best_val_loss = eval_metrics['val_loss']\n",
"                    # Save best model\n",
"                    torch.save({\n",
"                        'model_state_dict': model.state_dict(),\n",
"                        'config': config,\n",
"                        'step': step,\n",
"                        'best_val_loss': best_val_loss,\n",
"                        'final_metrics': eval_metrics\n",
"                    }, 'best_model.pt')\n",
"                    print(f\"💾 Saved best model with val_loss: {best_val_loss:.4f}\")\n",
"\n",
"            step += 1\n",
"            if step % 10 == 0:\n",
"                pbar.update(10)\n",
"\n",
"    pbar.close()\n",
"\n",
"    training_time = time.time() - start_time\n",
"    print(f\"  ⏱️ Training completed in {training_time:.1f} seconds\")\n",
"\n",
"    # Final evaluation\n",
"    final_eval = evaluate_model(model, val_loader, config)\n",
"    print(f\"  📊 Final - Loss: {final_eval['val_loss']:.4f}, \"\n",
"          f\"Acc: {final_eval['val_accuracy']:.4f}, PPL: {final_eval['val_perplexity']:.2f}\")\n",
"\n",
"    # Save final model\n",
"    torch.save({\n",
"        'model_state_dict': model.state_dict(),\n",
"        'config': config,\n",
"        'step': step,\n",
"        'final_metrics': final_eval\n",
"    }, 'final_model.pt')\n",
"    print(f\"💾 Saved final model to final_model.pt\")\n",
"\n",
"    return model, final_eval\n"
]
},
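{
"cell_type": "markdown",
"metadata": {},
"source": [
"The warmup-plus-cosine schedule is easiest to understand by printing it. The sketch below uses made-up values (`max_steps=2000`, `gradient_accumulation_steps=4`) and a throwaway SGD optimizer. It counts the schedule in optimizer updates, i.e. one scheduler step per accumulation cycle.\n",
"\n",
"The multiplier rises linearly from 0 to 1 during the first 5% of updates. It then follows a cosine down to a floor of 0.1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"\n",
"max_steps, accum = 2000, 4                  # made-up values for illustration\n",
"total_updates = max_steps // accum          # 500 optimizer updates\n",
"warmup_steps = max(1, total_updates // 20)  # 25\n",
"\n",
"def lr_lambda(update):\n",
"    if update < warmup_steps:\n",
"        return update / warmup_steps\n",
"    progress = min(1.0, (update - warmup_steps) / max(1, total_updates - warmup_steps))\n",
"    return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress))\n",
"\n",
"opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.01)\n",
"sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)\n",
"lrs = []\n",
"for _ in range(total_updates):\n",
"    lrs.append(opt.param_groups[0]['lr'])\n",
"    opt.step()\n",
"    sched.step()\n",
"\n",
"print(lrs[0], lrs[warmup_steps], round(lrs[-1], 5))  # 0.0 0.01 0.001"
]
},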
{
"cell_type": "markdown",
"metadata": {
"id": "YpRsvalr9nnY"
},
"source": [
" ## 16. Main Training Script\n",
"\n",
"\n",
"\n",
" Finally, let's put everything together! This section:\n",
"\n",
" 1. Checks system resources\n",
"\n",
" 2. Sets up configuration\n",
"\n",
" 3. Loads and prepares data\n",
"\n",
" 4. Trains the model\n",
"\n",
" 5. Reports final results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2ufOMN_19nnY",
"outputId": "be3d4e0c-3874-4ff6-bae2-1657d49cf514"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"🔍 Device: CUDA\n",
"GPU: Tesla T4\n",
"Memory: 15.8 GB\n",
"🌱 Set all seeds to 42\n",
"\n",
"📋 Model Configuration:\n",
" Architecture: 384d, 6L, 8H, 1536ff\n",
" Training: 2000 steps, batch size 24\n",
" Data: 500,000 tokens, seq_len 512\n",
"📦 Loading cached data from data_cache/tokenized_data_2000_500000.pkl\n",
"✅ Loaded 2000 documents, 500,000 tokens from cache\n",
"📊 Dataset: 449540 train, 49948 val samples\n",
"\n",
"🚀 Training Small model with Muon optimizer\n",
"🌱 Set all seeds to 42\n",
" 📊 Total parameters: 32,150,976\n",
" Muon parameters: 13,271,040\n",
" AdamW parameters: 18,879,936\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\n",
"Training: 0%| | 0/2000 [00:00<?, ?it/s]\u001b[A\n",
"Training: 0%| | 0/2000 [00:00<?, ?it/s, loss=10.8028, acc=0.015, ppl=49156.2, lr=0.00e+00]\u001b[AW0807 09:01:46.185000 1133 torch/_inductor/utils.py:1137] [0/0] Not enough SMs to use max_autotune_gemm mode\n",
"\n",
"Training: 0%| | 10/2000 [00:06<22:57, 1.44it/s, loss=10.8028, acc=0.015, ppl=49156.2, lr=0.00e+00]\u001b[A\n",
"Training: 0%| | 10/2000 [00:07<22:57, 1.44it/s, loss=10.8007, acc=0.014, ppl=49055.4, lr=2.00e-04]\u001b[A\n",
"Training: 1%| | 20/2000 [00:10<16:20, 2.02it/s, loss=10.8007, acc=0.014, ppl=49055.4, lr=2.00e-04]\u001b[A\n",
"Training: 1%| | 20/2000 [00:10<16:20, 2.02it/s, loss=10.7808, acc=0.016, ppl=48087.5, lr=5.00e-04]\u001b[A\n",
"Training: 2%|▏ | 30/2000 [00:13<13:31, 2.43it/s, loss=10.7808, acc=0.016, ppl=48087.5, lr=5.00e-04]\u001b[A\n",
"Training: 2%|▏ | 30/2000 [00:14<13:31, 2.43it/s, loss=10.7480, acc=0.016, ppl=46535.0, lr=7.00e-04]\u001b[A\n",
"Training: 2%|▏ | 40/2000 [00:17<12:47, 2.55it/s, loss=10.7480, acc=0.016, ppl=46535.0, lr=7.00e-04]\u001b[A\n",
"Training: 2%|▏ | 40/2000 [00:17<12:47, 2.55it/s, loss=10.7048, acc=0.016, ppl=44568.7, lr=1.00e-03]\u001b[A\n",
"Training: 2%|▎ | 50/2000 [00:20<11:51, 2.74it/s, loss=10.7048, acc=0.016, ppl=44568.7, lr=1.00e-03]\u001b[A\n",
"Training: 2%|▎ | 50/2000 [00:20<11:51, 2.74it/s, loss=10.6530, acc=0.016, ppl=42319.6, lr=1.20e-03]\u001b[A\n",
"Training: 3%|▎ | 60/2000 [00:24<11:47, 2.74it/s, loss=10.6530, acc=0.016, ppl=42319.6, lr=1.20e-03]\u001b[A\n",
"Training: 3%|▎ | 60/2000 [00:24<11:47, 2.74it/s, loss=10.5811, acc=0.015, ppl=39383.3, lr=1.50e-03]\u001b[A\n",
"Training: 4%|▎ | 70/2000 [00:27<11:15, 2.86it/s, loss=10.5811, acc=0.015, ppl=39383.3, lr=1.50e-03]\u001b[A\n",
"Training: 4%|▎ | 70/2000 [00:27<11:15, 2.86it/s, loss=10.4804, acc=0.015, ppl=35609.5, lr=1.70e-03]\u001b[A\n",
"Training: 4%|▍ | 80/2000 [00:30<11:22, 2.81it/s, loss=10.4804, acc=0.015, ppl=35609.5, lr=1.70e-03]\u001b[A\n",
"Training: 4%|▍ | 80/2000 [00:31<11:22, 2.81it/s, loss=10.3496, acc=0.015, ppl=31245.1, lr=2.00e-03]\u001b[A\n",
"Training: 4%|▍ | 90/2000 [00:34<10:58, 2.90it/s, loss=10.3496, acc=0.015, ppl=31245.1, lr=2.00e-03]\u001b[A\n",
"Training: 4%|▍ | 90/2000 [00:34<10:58, 2.90it/s, loss=10.1831, acc=0.015, ppl=26452.4, lr=2.20e-03]\u001b[A\n",
"Training: 5%|▌ | 100/2000 [00:37<11:10, 2.83it/s, loss=10.1831, acc=0.015, ppl=26452.4, lr=2.20e-03]\u001b[A\n",
"Training: 5%|▌ | 100/2000 [00:38<11:10, 2.83it/s, loss=9.9458, acc=0.014, ppl=20865.1, lr=2.50e-03] \u001b[A\n",
"Training: 6%|▌ | 110/2000 [00:41<10:59, 2.87it/s, loss=9.9458, acc=0.014, ppl=20865.1, lr=2.50e-03]\u001b[A\n",
"Training: 6%|▌ | 110/2000 [00:41<10:59, 2.87it/s, loss=9.7371, acc=0.015, ppl=16934.1, lr=2.70e-03]\u001b[A\n",
"Training: 6%|▌ | 120/2000 [00:45<11:42, 2.68it/s, loss=9.7371, acc=0.015, ppl=16934.1, lr=2.70e-03]\u001b[A\n",
"Training: 6%|▌ | 120/2000 [00:45<11:42, 2.68it/s, loss=9.3791, acc=0.015, ppl=11838.1, lr=3.00e-03]\u001b[A\n",
"Training: 6%|▋ | 130/2000 [00:49<11:24, 2.73it/s, loss=9.3791, acc=0.015, ppl=11838.1, lr=3.00e-03]\u001b[A\n",
"Training: 6%|▋ | 130/2000 [00:49<11:24, 2.73it/s, loss=9.1484, acc=0.017, ppl=9399.6, lr=3.20e-03] \u001b[A\n",
"Training: 7%|▋ | 140/2000 [00:52<11:30, 2.69it/s, loss=9.1484, acc=0.017, ppl=9399.6, lr=3.20e-03]\u001b[A\n",
"Training: 7%|▋ | 140/2000 [00:53<11:30, 2.69it/s, loss=8.7885, acc=0.022, ppl=6558.4, lr=3.50e-03]\u001b[A\n",
"Training: 8%|▊ | 150/2000 [00:56<11:05, 2.78it/s, loss=8.7885, acc=0.022, ppl=6558.4, lr=3.50e-03]\u001b[A\n",
"Training: 8%|▊ | 150/2000 [00:56<11:05, 2.78it/s, loss=8.5138, acc=0.039, ppl=4983.0, lr=3.70e-03]\u001b[A\n",
"Training: 8%|▊ | 160/2000 [01:00<11:16, 2.72it/s, loss=8.5138, acc=0.039, ppl=4983.0, lr=3.70e-03]\u001b[A\n",
"Training: 8%|▊ | 160/2000 [01:00<11:16, 2.72it/s, loss=8.1704, acc=0.072, ppl=3534.7, lr=4.00e-03]\u001b[A\n",
"Training: 8%|▊ | 170/2000 [01:03<10:54, 2.79it/s, loss=8.1704, acc=0.072, ppl=3534.7, lr=4.00e-03]\u001b[A\n",
"Training: 8%|▊ | 170/2000 [01:03<10:54, 2.79it/s, loss=7.9453, acc=0.086, ppl=2822.2, lr=4.20e-03]\u001b[A\n",
"Training: 9%|▉ | 180/2000 [01:07<11:18, 2.68it/s, loss=7.9453, acc=0.086, ppl=2822.2, lr=4.20e-03]\u001b[A\n",
"Training: 9%|▉ | 180/2000 [01:07<11:18, 2.68it/s, loss=7.7007, acc=0.079, ppl=2209.8, lr=4.50e-03]\u001b[A\n",
"Training: 10%|▉ | 190/2000 [01:10<10:55, 2.76it/s, loss=7.7007, acc=0.079, ppl=2209.8, lr=4.50e-03]\u001b[A\n",
"Training: 10%|▉ | 190/2000 [01:11<10:55, 2.76it/s, loss=7.5115, acc=0.085, ppl=1829.0, lr=4.70e-03]\u001b[A\n",
"Training: 10%|█ | 200/2000 [01:14<11:06, 2.70it/s, loss=7.5115, acc=0.085, ppl=1829.0, lr=4.70e-03]\u001b[A\n",
"Training: 10%|█ | 200/2000 [01:15<11:06, 2.70it/s, loss=7.3088, acc=0.087, ppl=1493.3, lr=5.00e-03]\u001b[A\n",
"Training: 10%|█ | 210/2000 [01:18<10:42, 2.79it/s, loss=7.3088, acc=0.087, ppl=1493.3, lr=5.00e-03]\u001b[A\n",
"Training: 10%|█ | 210/2000 [01:18<10:42, 2.79it/s, loss=7.2170, acc=0.101, ppl=1362.4, lr=5.20e-03]\u001b[A\n",
"Training: 11%|█ | 220/2000 [01:21<10:50, 2.74it/s, loss=7.2170, acc=0.101, ppl=1362.4, lr=5.20e-03]\u001b[A\n",
"Training: 11%|█ | 220/2000 [01:22<10:50, 2.74it/s, loss=7.1611, acc=0.114, ppl=1288.4, lr=5.50e-03]\u001b[A\n",
"Training: 12%|█▏ | 230/2000 [01:25<10:27, 2.82it/s, loss=7.1611, acc=0.114, ppl=1288.4, lr=5.50e-03]\u001b[A\n",
"Training: 12%|█▏ | 230/2000 [01:25<10:27, 2.82it/s, loss=7.1061, acc=0.117, ppl=1219.4, lr=5.70e-03]\u001b[A\n",
"Training: 12%|█▏ | 240/2000 [01:28<10:34, 2.77it/s, loss=7.1061, acc=0.117, ppl=1219.4, lr=5.70e-03]\u001b[A\n",
"Training: 12%|█▏ | 240/2000 [01:29<10:34, 2.77it/s, loss=7.0066, acc=0.129, ppl=1103.9, lr=6.00e-03]\u001b[A\n",
"Training: 12%|█▎ | 250/2000 [01:32<10:13, 2.85it/s, loss=7.0066, acc=0.129, ppl=1103.9, lr=6.00e-03]\u001b[A\n",
"Training: 12%|█▎ | 250/2000 [01:32<10:13, 2.85it/s, loss=7.0200, acc=0.135, ppl=1118.8, lr=6.20e-03]\u001b[A\n",
"Training: 13%|█▎ | 260/2000 [01:35<10:21, 2.80it/s, loss=7.0200, acc=0.135, ppl=1118.8, lr=6.20e-03]\u001b[A\n",
"Training: 13%|█▎ | 260/2000 [01:36<10:21, 2.80it/s, loss=6.8509, acc=0.141, ppl=944.7, lr=6.50e-03] \u001b[A\n",
"Training: 14%|█▎ | 270/2000 [01:39<10:00, 2.88it/s, loss=6.8509, acc=0.141, ppl=944.7, lr=6.50e-03]\u001b[A\n",
"Training: 14%|█▎ | 270/2000 [01:39<10:00, 2.88it/s, loss=6.7640, acc=0.150, ppl=866.1, lr=6.70e-03]\u001b[A\n",
"Training: 14%|█▍ | 280/2000 [01:42<10:10, 2.82it/s, loss=6.7640, acc=0.150, ppl=866.1, lr=6.70e-03]\u001b[A\n",
"Training: 14%|█▍ | 280/2000 [01:43<10:10, 2.82it/s, loss=6.7645, acc=0.152, ppl=866.5, lr=7.00e-03]\u001b[A\n",
"Training: 14%|█▍ | 290/2000 [01:46<09:50, 2.89it/s, loss=6.7645, acc=0.152, ppl=866.5, lr=7.00e-03]\u001b[A\n",
"Training: 14%|█▍ | 290/2000 [01:46<09:50, 2.89it/s, loss=6.5613, acc=0.154, ppl=707.2, lr=7.20e-03]\u001b[A\n",
"Training: 15%|█▌ | 300/2000 [01:49<10:00, 2.83it/s, loss=6.5613, acc=0.154, ppl=707.2, lr=7.20e-03]\u001b[A\n",
"Training: 15%|█▌ | 300/2000 [01:50<10:00, 2.83it/s, loss=6.3311, acc=0.170, ppl=561.8, lr=7.50e-03]\u001b[A\n",
"Training: 16%|█▌ | 310/2000 [01:53<09:41, 2.90it/s, loss=6.3311, acc=0.170, ppl=561.8, lr=7.50e-03]\u001b[A\n",
"Training: 16%|█▌ | 310/2000 [01:53<09:41, 2.90it/s, loss=6.4285, acc=0.175, ppl=619.2, lr=7.70e-03]\u001b[A\n",
"Training: 16%|█▌ | 320/2000 [01:56<09:52, 2.84it/s, loss=6.4285, acc=0.175, ppl=619.2, lr=7.70e-03]\u001b[A\n",
"Training: 16%|█▌ | 320/2000 [01:57<09:52, 2.84it/s, loss=6.3875, acc=0.175, ppl=594.4, lr=8.00e-03]\u001b[A\n",
"Training: 16%|█▋ | 330/2000 [02:00<09:34, 2.91it/s, loss=6.3875, acc=0.175, ppl=594.4, lr=8.00e-03]\u001b[A\n",
"Training: 16%|█▋ | 330/2000 [02:00<09:34, 2.91it/s, loss=6.1573, acc=0.171, ppl=472.1, lr=8.20e-03]\u001b[A\n",
"Training: 17%|█▋ | 340/2000 [02:03<09:45, 2.83it/s, loss=6.1573, acc=0.171, ppl=472.1, lr=8.20e-03]\u001b[A\n",
"Training: 17%|█▋ | 340/2000 [02:04<09:45, 2.83it/s, loss=6.0802, acc=0.182, ppl=437.1, lr=8.50e-03]\u001b[A\n",
"Training: 18%|█▊ | 350/2000 [02:07<09:28, 2.90it/s, loss=6.0802, acc=0.182, ppl=437.1, lr=8.50e-03]\u001b[A\n",
"Training: 18%|█▊ | 350/2000 [02:07<09:28, 2.90it/s, loss=5.9842, acc=0.184, ppl=397.1, lr=8.70e-03]\u001b[A\n",
"Training: 18%|█▊ | 360/2000 [02:10<09:40, 2.83it/s, loss=5.9842, acc=0.184, ppl=397.1, lr=8.70e-03]\u001b[A\n",
"Training: 18%|█▊ | 360/2000 [02:11<09:40, 2.83it/s, loss=5.9918, acc=0.192, ppl=400.1, lr=9.00e-03]\u001b[A\n",
"Training: 18%|█▊ | 370/2000 [02:14<09:23, 2.89it/s, loss=5.9918, acc=0.192, ppl=400.1, lr=9.00e-03]\u001b[A\n",
"Training: 18%|█▊ | 370/2000 [02:14<09:23, 2.89it/s, loss=5.8167, acc=0.192, ppl=335.9, lr=9.20e-03]\u001b[A\n",
"Training: 19%|█▉ | 380/2000 [02:17<09:34, 2.82it/s, loss=5.8167, acc=0.192, ppl=335.9, lr=9.20e-03]\u001b[A\n",
"Training: 19%|█▉ | 380/2000 [02:18<09:34, 2.82it/s, loss=5.6587, acc=0.200, ppl=286.8, lr=9.50e-03]\u001b[A\n",
"Training: 20%|█▉ | 390/2000 [02:21<09:18, 2.88it/s, loss=5.6587, acc=0.200, ppl=286.8, lr=9.50e-03]\u001b[A\n",
"Training: 20%|█▉ | 390/2000 [02:21<09:18, 2.88it/s, loss=5.6871, acc=0.201, ppl=295.0, lr=9.70e-03]\u001b[A\n",
"Training: 20%|██ | 400/2000 [02:24<09:30, 2.81it/s, loss=5.6871, acc=0.201, ppl=295.0, lr=9.70e-03]\u001b[A\n",
"Training: 20%|██ | 400/2000 [02:25<09:30, 2.81it/s, loss=5.6464, acc=0.200, ppl=283.3, lr=1.00e-02]\u001b[A\n",
"Training: 20%|██ | 410/2000 [02:28<09:13, 2.87it/s, loss=5.6464, acc=0.200, ppl=283.3, lr=1.00e-02]\u001b[A\n",
"Training: 20%|██ | 410/2000 [02:28<09:13, 2.87it/s, loss=5.6009, acc=0.205, ppl=270.7, lr=1.00e-02]\u001b[A\n",
"Training: 21%|██ | 420/2000 [02:31<09:24, 2.80it/s, loss=5.6009, acc=0.205, ppl=270.7, lr=1.00e-02]\u001b[A\n",
"Training: 21%|██ | 420/2000 [02:32<09:24, 2.80it/s, loss=5.6110, acc=0.205, ppl=273.4, lr=1.00e-02]\u001b[A\n",
"Training: 22%|██▏ | 430/2000 [02:35<09:08, 2.86it/s, loss=5.6110, acc=0.205, ppl=273.4, lr=1.00e-02]\u001b[A\n",
"Training: 22%|██▏ | 430/2000 [02:35<09:08, 2.86it/s, loss=5.5572, acc=0.207, ppl=259.1, lr=1.00e-02]\u001b[A\n",
"Training: 22%|██▏ | 440/2000 [02:39<09:18, 2.79it/s, loss=5.5572, acc=0.207, ppl=259.1, lr=1.00e-02]\u001b[A\n",
"Training: 22%|██▏ | 440/2000 [02:39<09:18, 2.79it/s, loss=5.4647, acc=0.206, ppl=236.2, lr=1.00e-02]\u001b[A\n",
"Training: 22%|██▎ | 450/2000 [02:42<09:02, 2.86it/s, loss=5.4647, acc=0.206, ppl=236.2, lr=1.00e-02]\u001b[A\n",
"Training: 22%|██▎ | 450/2000 [02:42<09:02, 2.86it/s, loss=5.3177, acc=0.211, ppl=203.9, lr=1.00e-02]\u001b[A\n",
"Training: 23%|██▎ | 460/2000 [02:46<09:12, 2.79it/s, loss=5.3177, acc=0.211, ppl=203.9, lr=1.00e-02]\u001b[A\n",
"Training: 23%|██▎ | 460/2000 [02:46<09:12, 2.79it/s, loss=5.3506, acc=0.212, ppl=210.7, lr=1.00e-02]\u001b[A\n",
"Training: 24%|██▎ | 470/2000 [02:49<08:54, 2.86it/s, loss=5.3506, acc=0.212, ppl=210.7, lr=1.00e-02]\u001b[A\n",
"Training: 24%|██▎ | 470/2000 [02:49<08:54, 2.86it/s, loss=5.1576, acc=0.222, ppl=173.7, lr=1.00e-02]\u001b[A\n",
"Training: 24%|██▍ | 480/2000 [02:53<09:08, 2.77it/s, loss=5.1576, acc=0.222, ppl=173.7, lr=1.00e-02]\u001b[A\n",
"Training: 24%|██▍ | 480/2000 [02:53<09:08, 2.77it/s, loss=5.1177, acc=0.226, ppl=166.9, lr=1.00e-02]\u001b[A\n",
"Training: 24%|██▍ | 490/2000 [02:56<08:53, 2.83it/s, loss=5.1177, acc=0.226, ppl=166.9, lr=1.00e-02]\u001b[A\n",
"Training: 24%|██▍ | 490/2000 [02:57<08:53, 2.83it/s, loss=5.0815, acc=0.241, ppl=161.0, lr=1.00e-02]\u001b[A\n",
"Training: 25%|██▌ | 500/2000 [03:00<08:59, 2.78it/s, loss=5.0815, acc=0.241, ppl=161.0, lr=1.00e-02]\u001b[A\n",
"Training: 25%|██▌ | 500/2000 [03:00<08:59, 2.78it/s, loss=5.2135, acc=0.221, ppl=183.7, lr=1.00e-02]\u001b[A"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"Step 500: Val Loss: 4.9942, Val Acc: 0.2329, Val PPL: 147.55\n",
"💾 Saved best model with val_loss: 4.9942\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\n",
"Training: 26%|██▌ | 510/2000 [03:16<18:01, 1.38it/s, loss=5.2135, acc=0.221, ppl=183.7, lr=1.00e-02]\u001b[A\n",
"Training: 26%|██▌ | 510/2000 [03:16<18:01, 1.38it/s, loss=5.1443, acc=0.217, ppl=171.5, lr=1.00e-02]\u001b[A\n",
"Training: 26%|██▌ | 520/2000 [03:19<15:18, 1.61it/s, loss=5.1443, acc=0.217, ppl=171.5, lr=1.00e-02]\u001b[A\n",
"Training: 26%|██▌ | 520/2000 [03:20<15:18, 1.61it/s, loss=4.9569, acc=0.231, ppl=142.2, lr=9.99e-03]\u001b[A\n",
"Training: 26%|██▋ | 530/2000 [03:23<13:02, 1.88it/s, loss=4.9569, acc=0.231, ppl=142.2, lr=9.99e-03]\u001b[A\n",
"Training: 26%|██▋ | 530/2000 [03:23<13:02, 1.88it/s, loss=4.9853, acc=0.238, ppl=146.3, lr=9.99e-03]\u001b[A\n",
"Training: 27%|██▋ | 540/2000 [03:26<11:47, 2.06it/s, loss=4.9853, acc=0.238, ppl=146.3, lr=9.99e-03]\u001b[A\n",
"Training: 27%|██▋ | 540/2000 [03:27<11:47, 2.06it/s, loss=4.9170, acc=0.242, ppl=136.6, lr=9.99e-03]\u001b[A\n",
"Training: 28%|██▊ | 550/2000 [03:30<10:33, 2.29it/s, loss=4.9170, acc=0.242, ppl=136.6, lr=9.99e-03]\u001b[A\n",
"Training: 28%|██▊ | 550/2000 [03:30<10:33, 2.29it/s, loss=4.8744, acc=0.235, ppl=130.9, lr=9.99e-03]\u001b[A\n",
"Training: 28%|██▊ | 560/2000 [03:33<10:02, 2.39it/s, loss=4.8744, acc=0.235, ppl=130.9, lr=9.99e-03]\u001b[A\n",
"Training: 28%|██▊ | 560/2000 [03:34<10:02, 2.39it/s, loss=4.8439, acc=0.236, ppl=127.0, lr=9.99e-03]\u001b[A\n",
"Training: 28%|██▊ | 570/2000 [03:37<09:19, 2.55it/s, loss=4.8439, acc=0.236, ppl=127.0, lr=9.99e-03]\u001b[A\n",
"Training: 28%|██▊ | 570/2000 [03:37<09:19, 2.55it/s, loss=4.8594, acc=0.228, ppl=128.9, lr=9.99e-03]\u001b[A\n",
"Training: 29%|██▉ | 580/2000 [03:40<09:08, 2.59it/s, loss=4.8594, acc=0.228, ppl=128.9, lr=9.99e-03]\u001b[A\n",
"Training: 29%|██▉ | 580/2000 [03:41<09:08, 2.59it/s, loss=4.7714, acc=0.237, ppl=118.1, lr=9.99e-03]\u001b[A\n",
"Training: 30%|██▉ | 590/2000 [03:44<08:39, 2.71it/s, loss=4.7714, acc=0.237, ppl=118.1, lr=9.99e-03]\u001b[A\n",
"Training: 30%|██▉ | 590/2000 [03:44<08:39, 2.71it/s, loss=4.8172, acc=0.241, ppl=123.6, lr=9.99e-03]\u001b[A\n",
"Training: 30%|███ | 600/2000 [03:48<08:39, 2.70it/s, loss=4.8172, acc=0.241, ppl=123.6, lr=9.99e-03]\u001b[A\n",
"Training: 30%|███ | 600/2000 [03:48<08:39, 2.70it/s, loss=4.7483, acc=0.253, ppl=115.4, lr=9.98e-03]\u001b[A\n",
"Training: 30%|███ | 610/2000 [03:51<08:17, 2.80it/s, loss=4.7483, acc=0.253, ppl=115.4, lr=9.98e-03]\u001b[A\n",
"Training: 30%|███ | 610/2000 [03:51<08:17, 2.80it/s, loss=4.6863, acc=0.251, ppl=108.4, lr=9.98e-03]\u001b[A\n",
"Training: 31%|███ | 620/2000 [03:55<08:21, 2.75it/s, loss=4.6863, acc=0.251, ppl=108.4, lr=9.98e-03]\u001b[A\n",
"Training: 31%|███ | 620/2000 [03:55<08:21, 2.75it/s, loss=4.7513, acc=0.249, ppl=115.7, lr=9.98e-03]\u001b[A\n",
"Training: 32%|███▏ | 630/2000 [03:58<08:03, 2.83it/s, loss=4.7513, acc=0.249, ppl=115.7, lr=9.98e-03]\u001b[A\n",
"Training: 32%|███▏ | 630/2000 [03:58<08:03, 2.83it/s, loss=4.7481, acc=0.240, ppl=115.4, lr=9.98e-03]\u001b[A\n",
"Training: 32%|███▏ | 640/2000 [04:02<08:09, 2.78it/s, loss=4.7481, acc=0.240, ppl=115.4, lr=9.98e-03]\u001b[A\n",
"Training: 32%|███▏ | 640/2000 [04:02<08:09, 2.78it/s, loss=4.6185, acc=0.248, ppl=101.3, lr=9.98e-03]\u001b[A\n",
"Training: 32%|███▎ | 650/2000 [04:05<07:53, 2.85it/s, loss=4.6185, acc=0.248, ppl=101.3, lr=9.98e-03]\u001b[A\n",
"Training: 32%|███▎ | 650/2000 [04:05<07:53, 2.85it/s, loss=4.6697, acc=0.245, ppl=106.7, lr=9.98e-03]\u001b[A\n",
"Training: 33%|███▎ | 660/2000 [04:09<07:59, 2.79it/s, loss=4.6697, acc=0.245, ppl=106.7, lr=9.98e-03]\u001b[A\n",
"Training: 33%|███▎ | 660/2000 [04:09<07:59, 2.79it/s, loss=4.6468, acc=0.257, ppl=104.3, lr=9.97e-03]\u001b[A\n",
"Training: 34%|███▎ | 670/2000 [04:12<07:44, 2.87it/s, loss=4.6468, acc=0.257, ppl=104.3, lr=9.97e-03]\u001b[A\n",
"Training: 34%|███▎ | 670/2000 [04:12<07:44, 2.87it/s, loss=4.4939, acc=0.265, ppl=89.5, lr=9.97e-03] \u001b[A\n",
"Training: 34%|███▍ | 680/2000 [04:16<07:51, 2.80it/s, loss=4.4939, acc=0.265, ppl=89.5, lr=9.97e-03]\u001b[A\n",
"Training: 34%|███▍ | 680/2000 [04:16<07:51, 2.80it/s, loss=4.4323, acc=0.262, ppl=84.1, lr=9.97e-03]\u001b[A\n",
"Training: 34%|███▍ | 690/2000 [04:19<07:35, 2.87it/s, loss=4.4323, acc=0.262, ppl=84.1, lr=9.97e-03]\u001b[A\n",
"Training: 34%|███▍ | 690/2000 [04:19<07:35, 2.87it/s, loss=4.5176, acc=0.260, ppl=91.6, lr=9.97e-03]\u001b[A\n",
"Training: 35%|███▌ | 700/2000 [04:23<07:43, 2.81it/s, loss=4.5176, acc=0.260, ppl=91.6, lr=9.97e-03]\u001b[A\n",
"Training: 35%|███▌ | 700/2000 [04:23<07:43, 2.81it/s, loss=4.3477, acc=0.275, ppl=77.3, lr=9.97e-03]\u001b[A\n",
"Training: 36%|███▌ | 710/2000 [04:26<07:28, 2.88it/s, loss=4.3477, acc=0.275, ppl=77.3, lr=9.97e-03]\u001b[A\n",
"Training: 36%|███▌ | 710/2000 [04:27<07:28, 2.88it/s, loss=4.4320, acc=0.274, ppl=84.1, lr=9.96e-03]\u001b[A\n",
"Training: 36%|███▌ | 720/2000 [04:30<07:35, 2.81it/s, loss=4.4320, acc=0.274, ppl=84.1, lr=9.96e-03]\u001b[A\n",
"Training: 36%|███▌ | 720/2000 [04:30<07:35, 2.81it/s, loss=4.2559, acc=0.279, ppl=70.5, lr=9.96e-03]\u001b[A\n",
"Training: 36%|███▋ | 730/2000 [04:33<07:21, 2.88it/s, loss=4.2559, acc=0.279, ppl=70.5, lr=9.96e-03]\u001b[A\n",
"Training: 36%|███▋ | 730/2000 [04:34<07:21, 2.88it/s, loss=4.4661, acc=0.258, ppl=87.0, lr=9.96e-03]\u001b[A\n",
"Training: 37%|███▋ | 740/2000 [04:37<07:28, 2.81it/s, loss=4.4661, acc=0.258, ppl=87.0, lr=9.96e-03]\u001b[A\n",
"Training: 37%|███▋ | 740/2000 [04:37<07:28, 2.81it/s, loss=4.3179, acc=0.269, ppl=75.0, lr=9.96e-03]\u001b[A\n",
"Training: 38%|███▊ | 750/2000 [04:40<07:13, 2.88it/s, loss=4.3179, acc=0.269, ppl=75.0, lr=9.96e-03]\u001b[A\n",
"Training: 38%|███▊ | 750/2000 [04:41<07:13, 2.88it/s, loss=4.1885, acc=0.295, ppl=65.9, lr=9.95e-03]\u001b[A\n",
"Training: 38%|███▊ | 760/2000 [04:44<07:20, 2.81it/s, loss=4.1885, acc=0.295, ppl=65.9, lr=9.95e-03]\u001b[A\n",
"Training: 38%|███▊ | 760/2000 [04:44<07:20, 2.81it/s, loss=4.3201, acc=0.272, ppl=75.2, lr=9.95e-03]\u001b[A\n",
"Training: 38%|███▊ | 770/2000 [04:47<07:06, 2.89it/s, loss=4.3201, acc=0.272, ppl=75.2, lr=9.95e-03]\u001b[A\n",
"Training: 38%|███▊ | 770/2000 [04:48<07:06, 2.89it/s, loss=4.1724, acc=0.289, ppl=64.9, lr=9.95e-03]\u001b[A\n",
"Training: 39%|███▉ | 780/2000 [04:51<07:13, 2.81it/s, loss=4.1724, acc=0.289, ppl=64.9, lr=9.95e-03]\u001b[A\n",
"Training: 39%|███▉ | 780/2000 [04:51<07:13, 2.81it/s, loss=4.1457, acc=0.281, ppl=63.2, lr=9.94e-03]\u001b[A\n",
"Training: 40%|███▉ | 790/2000 [04:54<06:59, 2.88it/s, loss=4.1457, acc=0.281, ppl=63.2, lr=9.94e-03]\u001b[A\n",
"Training: 40%|███▉ | 790/2000 [04:55<06:59, 2.88it/s, loss=4.1339, acc=0.294, ppl=62.4, lr=9.94e-03]\u001b[A\n",
"Training: 40%|████ | 800/2000 [04:58<07:06, 2.81it/s, loss=4.1339, acc=0.294, ppl=62.4, lr=9.94e-03]\u001b[A\n",
"Training: 40%|████ | 800/2000 [04:58<07:06, 2.81it/s, loss=4.2408, acc=0.274, ppl=69.5, lr=9.94e-03]\u001b[A\n",
"Training: 40%|████ | 810/2000 [05:01<06:52, 2.88it/s, loss=4.2408, acc=0.274, ppl=69.5, lr=9.94e-03]\u001b[A\n",
"Training: 40%|████ | 810/2000 [05:02<06:52, 2.88it/s, loss=4.1348, acc=0.293, ppl=62.5, lr=9.94e-03]\u001b[A\n",
"Training: 41%|████ | 820/2000 [05:05<06:59, 2.81it/s, loss=4.1348, acc=0.293, ppl=62.5, lr=9.94e-03]\u001b[A\n",
"Training: 41%|████ | 820/2000 [05:05<06:59, 2.81it/s, loss=4.0363, acc=0.305, ppl=56.6, lr=9.93e-03]\u001b[A\n",
"Training: 42%|████▏ | 830/2000 [05:08<06:45, 2.88it/s, loss=4.0363, acc=0.305, ppl=56.6, lr=9.93e-03]\u001b[A\n",
"Training: 42%|████▏ | 830/2000 [05:09<06:45, 2.88it/s, loss=4.0560, acc=0.299, ppl=57.7, lr=9.93e-03]\u001b[A\n",
"Training: 42%|████▏ | 840/2000 [05:12<06:51, 2.82it/s, loss=4.0560, acc=0.299, ppl=57.7, lr=9.93e-03]\u001b[A\n",
"Training: 42%|████▏ | 840/2000 [05:12<06:51, 2.82it/s, loss=3.9714, acc=0.309, ppl=53.1, lr=9.93e-03]\u001b[A\n",
"Training: 42%|████▎ | 850/2000 [05:15<06:38, 2.88it/s, loss=3.9714, acc=0.309, ppl=53.1, lr=9.93e-03]\u001b[A\n",
"Training: 42%|████▎ | 850/2000 [05:16<06:38, 2.88it/s, loss=4.0104, acc=0.299, ppl=55.2, lr=9.92e-03]\u001b[A\n",
"Training: 43%|████▎ | 860/2000 [05:19<06:45, 2.81it/s, loss=4.0104, acc=0.299, ppl=55.2, lr=9.92e-03]\u001b[A\n",
"Training: 43%|████▎ | 860/2000 [05:19<06:45, 2.81it/s, loss=3.9202, acc=0.309, ppl=50.4, lr=9.92e-03]\u001b[A\n",
"Training: 44%|████▎ | 870/2000 [05:22<06:32, 2.88it/s, loss=3.9202, acc=0.309, ppl=50.4, lr=9.92e-03]\u001b[A\n",
"Training: 44%|████▎ | 870/2000 [05:23<06:32, 2.88it/s, loss=3.9117, acc=0.306, ppl=50.0, lr=9.92e-03]\u001b[A\n",
"Training: 44%|████▍ | 880/2000 [05:26<06:38, 2.81it/s, loss=3.9117, acc=0.306, ppl=50.0, lr=9.92e-03]\u001b[A\n",
"Training: 44%|████▍ | 880/2000 [05:26<06:38, 2.81it/s, loss=3.7061, acc=0.340, ppl=40.7, lr=9.91e-03]\u001b[A\n",
"Training: 44%|████▍ | 890/2000 [05:29<06:24, 2.89it/s, loss=3.7061, acc=0.340, ppl=40.7, lr=9.91e-03]\u001b[A\n",
"Training: 44%|████▍ | 890/2000 [05:30<06:24, 2.89it/s, loss=3.8237, acc=0.323, ppl=45.8, lr=9.91e-03]\u001b[A\n",
"Training: 45%|████▌ | 900/2000 [05:33<06:30, 2.82it/s, loss=3.8237, acc=0.323, ppl=45.8, lr=9.91e-03]\u001b[A\n",
"Training: 45%|████▌ | 900/2000 [05:33<06:30, 2.82it/s, loss=3.7239, acc=0.327, ppl=41.4, lr=9.90e-03]\u001b[A\n",
"Training: 46%|████▌ | 910/2000 [05:36<06:17, 2.88it/s, loss=3.7239, acc=0.327, ppl=41.4, lr=9.90e-03]\u001b[A\n",
"Training: 46%|████▌ | 910/2000 [05:37<06:17, 2.88it/s, loss=3.9225, acc=0.315, ppl=50.5, lr=9.90e-03]\u001b[A\n",
"Training: 46%|████▌ | 920/2000 [05:40<06:23, 2.81it/s, loss=3.9225, acc=0.315, ppl=50.5, lr=9.90e-03]\u001b[A\n",
"Training: 46%|████▌ | 920/2000 [05:40<06:23, 2.81it/s, loss=3.6266, acc=0.347, ppl=37.6, lr=9.90e-03]\u001b[A\n",
"Training: 46%|████▋ | 930/2000 [05:43<06:11, 2.88it/s, loss=3.6266, acc=0.347, ppl=37.6, lr=9.90e-03]\u001b[A\n",
"Training: 46%|████▋ | 930/2000 [05:44<06:11, 2.88it/s, loss=3.8791, acc=0.306, ppl=48.4, lr=9.89e-03]\u001b[A\n",
"Training: 47%|████▋ | 940/2000 [05:47<06:16, 2.82it/s, loss=3.8791, acc=0.306, ppl=48.4, lr=9.89e-03]\u001b[A\n",
"Training: 47%|████▋ | 940/2000 [05:47<06:16, 2.82it/s, loss=3.7618, acc=0.326, ppl=43.0, lr=9.89e-03]\u001b[A\n",
"Training: 48%|████▊ | 950/2000 [05:50<06:04, 2.88it/s, loss=3.7618, acc=0.326, ppl=43.0, lr=9.89e-03]\u001b[A\n",
"Training: 48%|████▊ | 950/2000 [05:51<06:04, 2.88it/s, loss=3.7082, acc=0.335, ppl=40.8, lr=9.89e-03]\u001b[A\n",
"Training: 48%|████▊ | 960/2000 [05:54<06:09, 2.82it/s, loss=3.7082, acc=0.335, ppl=40.8, lr=9.89e-03]\u001b[A\n",
"Training: 48%|████▊ | 960/2000 [05:54<06:09, 2.82it/s, loss=3.6027, acc=0.345, ppl=36.7, lr=9.88e-03]\u001b[A\n",
"Training: 48%|████▊ | 970/2000 [05:57<05:56, 2.89it/s, loss=3.6027, acc=0.345, ppl=36.7, lr=9.88e-03]\u001b[A\n",
"Training: 48%|████▊ | 970/2000 [05:58<05:56, 2.89it/s, loss=3.4623, acc=0.358, ppl=31.9, lr=9.88e-03]\u001b[A\n",
"Training: 49%|████▉ | 980/2000 [06:01<06:01, 2.82it/s, loss=3.4623, acc=0.358, ppl=31.9, lr=9.88e-03]\u001b[A\n",
"Training: 49%|████▉ | 980/2000 [06:01<06:01, 2.82it/s, loss=3.5815, acc=0.357, ppl=35.9, lr=9.87e-03]\u001b[A\n",
"Training: 50%|████▉ | 990/2000 [06:04<05:50, 2.89it/s, loss=3.5815, acc=0.357, ppl=35.9, lr=9.87e-03]\u001b[A\n",
"Training: 50%|████▉ | 990/2000 [06:05<05:50, 2.89it/s, loss=3.4841, acc=0.363, ppl=32.6, lr=9.87e-03]\u001b[A\n",
"Training: 50%|█████ | 1000/2000 [06:08<05:55, 2.82it/s, loss=3.4841, acc=0.363, ppl=32.6, lr=9.87e-03]\u001b[A\n",
"Training: 50%|█████ | 1000/2000 [06:08<05:55, 2.82it/s, loss=3.5973, acc=0.356, ppl=36.5, lr=9.86e-03]\u001b[A"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"Step 1000: Val Loss: 3.1844, Val Acc: 0.4082, Val PPL: 24.15\n",
"💾 Saved best model with val_loss: 3.1844\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\n",
"Training: 50%|█████ | 1010/2000 [06:24<11:57, 1.38it/s, loss=3.5973, acc=0.356, ppl=36.5, lr=9.86e-03]\u001b[A\n",
"Training: 50%|█████ | 1010/2000 [06:24<11:57, 1.38it/s, loss=3.4115, acc=0.375, ppl=30.3, lr=9.86e-03]\u001b[A\n",
"Training: 51%|█████ | 1020/2000 [06:28<10:07, 1.61it/s, loss=3.4115, acc=0.375, ppl=30.3, lr=9.86e-03]\u001b[A\n",
"Training: 51%|█████ | 1020/2000 [06:28<10:07, 1.61it/s, loss=3.2510, acc=0.399, ppl=25.8, lr=9.85e-03]\u001b[A\n",
"Training: 52%|█████▏ | 1030/2000 [06:31<08:35, 1.88it/s, loss=3.2510, acc=0.399, ppl=25.8, lr=9.85e-03]\u001b[A\n",
"Training: 52%|█████▏ | 1030/2000 [06:31<08:35, 1.88it/s, loss=3.4950, acc=0.364, ppl=33.0, lr=9.85e-03]\u001b[A\n",
"Training: 52%|█████▏ | 1040/2000 [06:35<07:45, 2.06it/s, loss=3.4950, acc=0.364, ppl=33.0, lr=9.85e-03]\u001b[A\n",
"Training: 52%|█████▏ | 1040/2000 [06:35<07:45, 2.06it/s, loss=3.5003, acc=0.365, ppl=33.1, lr=9.84e-03]\u001b[A\n",
"Training: 52%|█████▎ | 1050/2000 [06:38<06:55, 2.29it/s, loss=3.5003, acc=0.365, ppl=33.1, lr=9.84e-03]\u001b[A\n",
"Training: 52%|█████▎ | 1050/2000 [06:38<06:55, 2.29it/s, loss=3.3385, acc=0.383, ppl=28.2, lr=9.84e-03]\u001b[A\n",
"Training: 53%|█████▎ | 1060/2000 [06:42<06:32, 2.39it/s, loss=3.3385, acc=0.383, ppl=28.2, lr=9.84e-03]\u001b[A\n",
"Training: 53%|█████▎ | 1060/2000 [06:42<06:32, 2.39it/s, loss=3.4826, acc=0.363, ppl=32.5, lr=9.83e-03]\u001b[A\n",
"Training: 54%|█████▎ | 1070/2000 [06:45<06:03, 2.56it/s, loss=3.4826, acc=0.363, ppl=32.5, lr=9.83e-03]\u001b[A\n",
"Training: 54%|█████▎ | 1070/2000 [06:45<06:03, 2.56it/s, loss=3.4525, acc=0.361, ppl=31.6, lr=9.83e-03]\u001b[A\n",
"Training: 54%|█████▍ | 1080/2000 [06:49<05:54, 2.59it/s, loss=3.4525, acc=0.361, ppl=31.6, lr=9.83e-03]\u001b[A\n",
"Training: 54%|█████▍ | 1080/2000 [06:49<05:54, 2.59it/s, loss=3.3291, acc=0.374, ppl=27.9, lr=9.82e-03]\u001b[A\n",
"Training: 55%|█████▍ | 1090/2000 [06:52<05:35, 2.72it/s, loss=3.3291, acc=0.374, ppl=27.9, lr=9.82e-03]\u001b[A\n",
"Training: 55%|█████▍ | 1090/2000 [06:52<05:35, 2.72it/s, loss=3.2726, acc=0.390, ppl=26.4, lr=9.82e-03]\u001b[A\n",
"Training: 55%|█████▌ | 1100/2000 [06:56<05:33, 2.70it/s, loss=3.2726, acc=0.390, ppl=26.4, lr=9.82e-03]\u001b[A\n",
"Training: 55%|█████▌ | 1100/2000 [06:56<05:33, 2.70it/s, loss=3.1414, acc=0.409, ppl=23.1, lr=9.81e-03]\u001b[A\n",
"Training: 56%|█████▌ | 1110/2000 [06:59<05:17, 2.80it/s, loss=3.1414, acc=0.409, ppl=23.1, lr=9.81e-03]\u001b[A\n",
"Training: 56%|█████▌ | 1110/2000 [06:59<05:17, 2.80it/s, loss=3.2345, acc=0.389, ppl=25.4, lr=9.81e-03]\u001b[A\n",
"Training: 56%|█████▌ | 1120/2000 [07:03<05:18, 2.76it/s, loss=3.2345, acc=0.389, ppl=25.4, lr=9.81e-03]\u001b[A\n",
"Training: 56%|█████▌ | 1120/2000 [07:03<05:18, 2.76it/s, loss=3.4019, acc=0.366, ppl=30.0, lr=9.80e-03]\u001b[A\n",
"Training: 56%|█████▋ | 1130/2000 [07:06<05:06, 2.84it/s, loss=3.4019, acc=0.366, ppl=30.0, lr=9.80e-03]\u001b[A\n",
"Training: 56%|█████▋ | 1130/2000 [07:06<05:06, 2.84it/s, loss=3.1875, acc=0.404, ppl=24.2, lr=9.80e-03]\u001b[A\n",
"Training: 57%|█████▋ | 1140/2000 [07:10<05:08, 2.79it/s, loss=3.1875, acc=0.404, ppl=24.2, lr=9.80e-03]\u001b[A\n",
"Training: 57%|█████▋ | 1140/2000 [07:10<05:08, 2.79it/s, loss=3.2513, acc=0.395, ppl=25.8, lr=9.79e-03]\u001b[A\n",
"Training: 57%|█████▊ | 1150/2000 [07:13<04:56, 2.86it/s, loss=3.2513, acc=0.395, ppl=25.8, lr=9.79e-03]\u001b[A\n",
"Training: 57%|█████▊ | 1150/2000 [07:13<04:56, 2.86it/s, loss=3.3684, acc=0.367, ppl=29.0, lr=9.79e-03]\u001b[A\n",
"Training: 58%|█████▊ | 1160/2000 [07:17<04:59, 2.80it/s, loss=3.3684, acc=0.367, ppl=29.0, lr=9.79e-03]\u001b[A\n",
"Training: 58%|█████▊ | 1160/2000 [07:17<04:59, 2.80it/s, loss=3.0060, acc=0.432, ppl=20.2, lr=9.78e-03]\u001b[A\n",
"Training: 58%|█████▊ | 1170/2000 [07:20<04:48, 2.88it/s, loss=3.0060, acc=0.432, ppl=20.2, lr=9.78e-03]\u001b[A\n",
"Training: 58%|█████▊ | 1170/2000 [07:20<04:48, 2.88it/s, loss=3.0221, acc=0.419, ppl=20.5, lr=9.78e-03]\u001b[A\n",
"Training: 59%|█████▉ | 1180/2000 [07:24<04:51, 2.81it/s, loss=3.0221, acc=0.419, ppl=20.5, lr=9.78e-03]\u001b[A\n",
"Training: 59%|█████▉ | 1180/2000 [07:24<04:51, 2.81it/s, loss=2.9939, acc=0.422, ppl=20.0, lr=9.77e-03]\u001b[A\n",
"Training: 60%|█████▉ | 1190/2000 [07:27<04:41, 2.88it/s, loss=2.9939, acc=0.422, ppl=20.0, lr=9.77e-03]\u001b[A\n",
"Training: 60%|█████▉ | 1190/2000 [07:27<04:41, 2.88it/s, loss=3.1847, acc=0.403, ppl=24.2, lr=9.76e-03]\u001b[A\n",
"Training: 60%|██████ | 1200/2000 [07:31<04:44, 2.81it/s, loss=3.1847, acc=0.403, ppl=24.2, lr=9.76e-03]\u001b[A\n",
"Training: 60%|██████ | 1200/2000 [07:31<04:44, 2.81it/s, loss=3.0391, acc=0.417, ppl=20.9, lr=9.76e-03]\u001b[A\n",
"Training: 60%|██████ | 1210/2000 [07:34<04:33, 2.89it/s, loss=3.0391, acc=0.417, ppl=20.9, lr=9.76e-03]\u001b[A\n",
"Training: 60%|██████ | 1210/2000 [07:35<04:33, 2.89it/s, loss=2.9365, acc=0.431, ppl=18.8, lr=9.75e-03]\u001b[A\n",
"Training: 61%|██████ | 1220/2000 [07:38<04:36, 2.82it/s, loss=2.9365, acc=0.431, ppl=18.8, lr=9.75e-03]\u001b[A\n",
"Training: 61%|██████ | 1220/2000 [07:38<04:36, 2.82it/s, loss=3.0035, acc=0.427, ppl=20.2, lr=9.74e-03]\u001b[A\n",
"Training: 62%|██████▏ | 1230/2000 [07:41<04:27, 2.88it/s, loss=3.0035, acc=0.427, ppl=20.2, lr=9.74e-03]\u001b[A\n",
"Training: 62%|██████▏ | 1230/2000 [07:42<04:27, 2.88it/s, loss=3.1786, acc=0.401, ppl=24.0, lr=9.74e-03]\u001b[A\n",
"Training: 62%|██████▏ | 1240/2000 [07:45<04:29, 2.82it/s, loss=3.1786, acc=0.401, ppl=24.0, lr=9.74e-03]\u001b[A\n",
"Training: 62%|██████▏ | 1240/2000 [07:45<04:29, 2.82it/s, loss=3.0778, acc=0.412, ppl=21.7, lr=9.73e-03]\u001b[A\n",
"Training: 62%|██████▎ | 1250/2000 [07:48<04:19, 2.89it/s, loss=3.0778, acc=0.412, ppl=21.7, lr=9.73e-03]\u001b[A\n",
"Training: 62%|██████▎ | 1250/2000 [07:49<04:19, 2.89it/s, loss=2.9591, acc=0.421, ppl=19.3, lr=9.73e-03]\u001b[A\n",
"Training: 63%|██████▎ | 1260/2000 [07:52<04:22, 2.82it/s, loss=2.9591, acc=0.421, ppl=19.3, lr=9.73e-03]\u001b[A\n",
"Training: 63%|██████▎ | 1260/2000 [07:52<04:22, 2.82it/s, loss=2.9366, acc=0.440, ppl=18.9, lr=9.72e-03]\u001b[A\n",
"Training: 64%|██████▎ | 1270/2000 [07:55<04:12, 2.89it/s, loss=2.9366, acc=0.440, ppl=18.9, lr=9.72e-03]\u001b[A\n",
"Training: 64%|██████▎ | 1270/2000 [07:56<04:12, 2.89it/s, loss=2.8597, acc=0.447, ppl=17.5, lr=9.71e-03]\u001b[A\n",
"Training: 64%|██████▍ | 1280/2000 [07:59<04:15, 2.82it/s, loss=2.8597, acc=0.447, ppl=17.5, lr=9.71e-03]\u001b[A\n",
"Training: 64%|██████▍ | 1280/2000 [07:59<04:15, 2.82it/s, loss=2.9449, acc=0.426, ppl=19.0, lr=9.71e-03]\u001b[A\n",
"Training: 64%|██████▍ | 1290/2000 [08:02<04:05, 2.89it/s, loss=2.9449, acc=0.426, ppl=19.0, lr=9.71e-03]\u001b[A\n",
"Training: 64%|██████▍ | 1290/2000 [08:03<04:05, 2.89it/s, loss=2.8927, acc=0.441, ppl=18.0, lr=9.70e-03]\u001b[A\n",
"Training: 65%|██████▌ | 1300/2000 [08:06<04:08, 2.82it/s, loss=2.8927, acc=0.441, ppl=18.0, lr=9.70e-03]\u001b[A\n",
"Training: 65%|██████▌ | 1300/2000 [08:06<04:08, 2.82it/s, loss=2.8267, acc=0.447, ppl=16.9, lr=9.69e-03]\u001b[A\n",
"Training: 66%|██████▌ | 1310/2000 [08:09<03:58, 2.89it/s, loss=2.8267, acc=0.447, ppl=16.9, lr=9.69e-03]\u001b[A\n",
"Training: 66%|██████▌ | 1310/2000 [08:10<03:58, 2.89it/s, loss=2.7897, acc=0.454, ppl=16.3, lr=9.69e-03]\u001b[A\n",
"Training: 66%|██████▌ | 1320/2000 [08:13<04:00, 2.82it/s, loss=2.7897, acc=0.454, ppl=16.3, lr=9.69e-03]\u001b[A\n",
"Training: 66%|██████▌ | 1320/2000 [08:13<04:00, 2.82it/s, loss=2.8305, acc=0.445, ppl=17.0, lr=9.68e-03]\u001b[A\n",
"Training: 66%|██████▋ | 1330/2000 [08:16<03:51, 2.89it/s, loss=2.8305, acc=0.445, ppl=17.0, lr=9.68e-03]\u001b[A\n",
"Training: 66%|██████▋ | 1330/2000 [08:17<03:51, 2.89it/s, loss=2.7801, acc=0.455, ppl=16.1, lr=9.67e-03]\u001b[A\n",
"Training: 67%|██████▋ | 1340/2000 [08:20<03:53, 2.82it/s, loss=2.7801, acc=0.455, ppl=16.1, lr=9.67e-03]\u001b[A\n",
"Training: 67%|██████▋ | 1340/2000 [08:20<03:53, 2.82it/s, loss=2.7623, acc=0.457, ppl=15.8, lr=9.66e-03]\u001b[A\n",
"Training: 68%|██████▊ | 1350/2000 [08:23<03:44, 2.89it/s, loss=2.7623, acc=0.457, ppl=15.8, lr=9.66e-03]\u001b[A\n",
"Training: 68%|██████▊ | 1350/2000 [08:24<03:44, 2.89it/s, loss=2.6817, acc=0.471, ppl=14.6, lr=9.66e-03]\u001b[A\n",
"Training: 68%|██████▊ | 1360/2000 [08:27<03:46, 2.82it/s, loss=2.6817, acc=0.471, ppl=14.6, lr=9.66e-03]\u001b[A\n",
"Training: 68%|██████▊ | 1360/2000 [08:27<03:46, 2.82it/s, loss=2.7636, acc=0.464, ppl=15.9, lr=9.65e-03]\u001b[A\n",
"Training: 68%|██████▊ | 1370/2000 [08:30<03:37, 2.89it/s, loss=2.7636, acc=0.464, ppl=15.9, lr=9.65e-03]\u001b[A\n",
"Training: 68%|██████▊ | 1370/2000 [08:31<03:37, 2.89it/s, loss=2.6714, acc=0.472, ppl=14.5, lr=9.64e-03]\u001b[A\n",
"Training: 69%|██████▉ | 1380/2000 [08:34<03:39, 2.82it/s, loss=2.6714, acc=0.472, ppl=14.5, lr=9.64e-03]\u001b[A\n",
"Training: 69%|██████▉ | 1380/2000 [08:34<03:39, 2.82it/s, loss=2.6967, acc=0.468, ppl=14.8, lr=9.64e-03]\u001b[A\n",
"Training: 70%|██████▉ | 1390/2000 [08:37<03:30, 2.90it/s, loss=2.6967, acc=0.468, ppl=14.8, lr=9.64e-03]\u001b[A\n",
"Training: 70%|██████▉ | 1390/2000 [08:38<03:30, 2.90it/s, loss=2.5589, acc=0.488, ppl=12.9, lr=9.63e-03]\u001b[A\n",
"Training: 70%|███████ | 1400/2000 [08:41<03:32, 2.83it/s, loss=2.5589, acc=0.488, ppl=12.9, lr=9.63e-03]\u001b[A\n",
"Training: 70%|███████ | 1400/2000 [08:41<03:32, 2.83it/s, loss=2.6372, acc=0.490, ppl=14.0, lr=9.62e-03]\u001b[A\n",
"Training: 70%|███████ | 1410/2000 [08:44<03:23, 2.90it/s, loss=2.6372, acc=0.490, ppl=14.0, lr=9.62e-03]\u001b[A\n",
"Training: 70%|███████ | 1410/2000 [08:44<03:23, 2.90it/s, loss=2.4977, acc=0.499, ppl=12.2, lr=9.61e-03]\u001b[A\n",
"Training: 71%|███████ | 1420/2000 [08:48<03:25, 2.83it/s, loss=2.4977, acc=0.499, ppl=12.2, lr=9.61e-03]\u001b[A\n",
"Training: 71%|███████ | 1420/2000 [08:48<03:25, 2.83it/s, loss=2.6104, acc=0.480, ppl=13.6, lr=9.61e-03]\u001b[A\n",
"Training: 72%|███████▏ | 1430/2000 [08:51<03:16, 2.90it/s, loss=2.6104, acc=0.480, ppl=13.6, lr=9.61e-03]\u001b[A\n",
"Training: 72%|███████▏ | 1430/2000 [08:51<03:16, 2.90it/s, loss=2.5438, acc=0.489, ppl=12.7, lr=9.60e-03]\u001b[A\n",
"Training: 72%|███████▏ | 1440/2000 [08:55<03:18, 2.83it/s, loss=2.5438, acc=0.489, ppl=12.7, lr=9.60e-03]\u001b[A\n",
"Training: 72%|███████▏ | 1440/2000 [08:55<03:18, 2.83it/s, loss=2.5454, acc=0.493, ppl=12.7, lr=9.59e-03]\u001b[A\n",
"Training: 72%|███████▎ | 1450/2000 [08:58<03:09, 2.90it/s, loss=2.5454, acc=0.493, ppl=12.7, lr=9.59e-03]\u001b[A\n",
"Training: 72%|███████▎ | 1450/2000 [08:58<03:09, 2.90it/s, loss=2.5774, acc=0.483, ppl=13.2, lr=9.58e-03]\u001b[A\n",
"Training: 73%|███████▎ | 1460/2000 [09:02<03:10, 2.83it/s, loss=2.5774, acc=0.483, ppl=13.2, lr=9.58e-03]\u001b[A\n",
"Training: 73%|███████▎ | 1460/2000 [09:02<03:10, 2.83it/s, loss=2.5923, acc=0.485, ppl=13.4, lr=9.57e-03]\u001b[A\n",
"Training: 74%|███████▎ | 1470/2000 [09:05<03:03, 2.90it/s, loss=2.5923, acc=0.485, ppl=13.4, lr=9.57e-03]\u001b[A\n",
"Training: 74%|███████▎ | 1470/2000 [09:05<03:03, 2.90it/s, loss=2.6171, acc=0.467, ppl=13.7, lr=9.57e-03]\u001b[A\n",
"Training: 74%|███████▍ | 1480/2000 [09:09<03:03, 2.83it/s, loss=2.6171, acc=0.467, ppl=13.7, lr=9.57e-03]\u001b[A\n",
"Training: 74%|███████▍ | 1480/2000 [09:09<03:03, 2.83it/s, loss=2.4034, acc=0.513, ppl=11.1, lr=9.56e-03]\u001b[A\n",
"Training: 74%|███████▍ | 1490/2000 [09:12<02:56, 2.90it/s, loss=2.4034, acc=0.513, ppl=11.1, lr=9.56e-03]\u001b[A\n",
"Training: 74%|███████▍ | 1490/2000 [09:12<02:56, 2.90it/s, loss=2.5670, acc=0.490, ppl=13.0, lr=9.55e-03]\u001b[A\n",
"Training: 75%|███████▌ | 1500/2000 [09:16<02:56, 2.83it/s, loss=2.5670, acc=0.490, ppl=13.0, lr=9.55e-03]\u001b[A\n",
"Training: 75%|███████▌ | 1500/2000 [09:16<02:56, 2.83it/s, loss=2.5133, acc=0.495, ppl=12.3, lr=9.54e-03]\u001b[A"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"Step 1500: Val Loss: 1.8981, Val Acc: 0.6077, Val PPL: 6.67\n",
"💾 Saved best model with val_loss: 1.8981\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\n",
"Training: 76%|███████▌ | 1510/2000 [09:32<05:57, 1.37it/s, loss=2.5133, acc=0.495, ppl=12.3, lr=9.54e-03]\u001b[A\n",
"Training: 76%|███████▌ | 1510/2000 [09:32<05:57, 1.37it/s, loss=2.2680, acc=0.530, ppl=9.7, lr=9.54e-03] \u001b[A\n",
"Training: 76%|███████▌ | 1520/2000 [09:35<04:58, 1.61it/s, loss=2.2680, acc=0.530, ppl=9.7, lr=9.54e-03]\u001b[A\n",
"Training: 76%|███████▌ | 1520/2000 [09:36<04:58, 1.61it/s, loss=2.3578, acc=0.517, ppl=10.6, lr=9.53e-03]\u001b[A\n",
"Training: 76%|███████▋ | 1530/2000 [09:39<04:10, 1.88it/s, loss=2.3578, acc=0.517, ppl=10.6, lr=9.53e-03]\u001b[A\n",
"Training: 76%|███████▋ | 1530/2000 [09:39<04:10, 1.88it/s, loss=2.4861, acc=0.496, ppl=12.0, lr=9.52e-03]\u001b[A\n",
"Training: 77%|███████▋ | 1540/2000 [09:42<03:43, 2.06it/s, loss=2.4861, acc=0.496, ppl=12.0, lr=9.52e-03]\u001b[A\n",
"Training: 77%|███████▋ | 1540/2000 [09:43<03:43, 2.06it/s, loss=2.3362, acc=0.515, ppl=10.3, lr=9.51e-03]\u001b[A\n",
"Training: 78%|███████▊ | 1550/2000 [09:46<03:16, 2.29it/s, loss=2.3362, acc=0.515, ppl=10.3, lr=9.51e-03]\u001b[A\n",
"Training: 78%|███████▊ | 1550/2000 [09:46<03:16, 2.29it/s, loss=2.4001, acc=0.508, ppl=11.0, lr=9.50e-03]\u001b[A\n",
"Training: 78%|███████▊ | 1560/2000 [09:49<03:04, 2.39it/s, loss=2.4001, acc=0.508, ppl=11.0, lr=9.50e-03]\u001b[A\n",
"Training: 78%|███████▊ | 1560/2000 [09:50<03:04, 2.39it/s, loss=2.2475, acc=0.536, ppl=9.5, lr=9.49e-03] \u001b[A\n",
"Training: 78%|███████▊ | 1570/2000 [09:53<02:47, 2.56it/s, loss=2.2475, acc=0.536, ppl=9.5, lr=9.49e-03]\u001b[A\n",
"Training: 78%|███████▊ | 1570/2000 [09:53<02:47, 2.56it/s, loss=2.3208, acc=0.520, ppl=10.2, lr=9.49e-03]\u001b[A\n",
"Training: 79%|███████▉ | 1580/2000 [09:56<02:41, 2.60it/s, loss=2.3208, acc=0.520, ppl=10.2, lr=9.49e-03]\u001b[A\n",
"Training: 79%|███████▉ | 1580/2000 [09:57<02:41, 2.60it/s, loss=2.3816, acc=0.510, ppl=10.8, lr=9.48e-03]\u001b[A\n",
"Training: 80%|███████▉ | 1590/2000 [10:00<02:30, 2.72it/s, loss=2.3816, acc=0.510, ppl=10.8, lr=9.48e-03]\u001b[A\n",
"Training: 80%|███████▉ | 1590/2000 [10:00<02:30, 2.72it/s, loss=2.2568, acc=0.540, ppl=9.6, lr=9.47e-03] \u001b[A\n",
"Training: 80%|████████ | 1600/2000 [10:03<02:27, 2.71it/s, loss=2.2568, acc=0.540, ppl=9.6, lr=9.47e-03]\u001b[A\n",
"Training: 80%|████████ | 1600/2000 [10:04<02:27, 2.71it/s, loss=2.3356, acc=0.519, ppl=10.3, lr=9.46e-03]\u001b[A\n",
"Training: 80%|████████ | 1610/2000 [10:07<02:18, 2.81it/s, loss=2.3356, acc=0.519, ppl=10.3, lr=9.46e-03]\u001b[A\n",
"Training: 80%|████████ | 1610/2000 [10:07<02:18, 2.81it/s, loss=2.2585, acc=0.532, ppl=9.6, lr=9.45e-03] \u001b[A\n",
"Training: 81%|████████ | 1620/2000 [10:10<02:17, 2.77it/s, loss=2.2585, acc=0.532, ppl=9.6, lr=9.45e-03]\u001b[A\n",
"Training: 81%|████████ | 1620/2000 [10:11<02:17, 2.77it/s, loss=2.4129, acc=0.504, ppl=11.2, lr=9.44e-03]\u001b[A\n",
"Training: 82%|████████▏ | 1630/2000 [10:14<02:09, 2.85it/s, loss=2.4129, acc=0.504, ppl=11.2, lr=9.44e-03]\u001b[A\n",
"Training: 82%|████████▏ | 1630/2000 [10:14<02:09, 2.85it/s, loss=2.2562, acc=0.524, ppl=9.5, lr=9.43e-03] \u001b[A\n",
"Training: 82%|████████▏ | 1640/2000 [10:17<02:08, 2.80it/s, loss=2.2562, acc=0.524, ppl=9.5, lr=9.43e-03]\u001b[A\n",
"Training: 82%|████████▏ | 1640/2000 [10:18<02:08, 2.80it/s, loss=2.2399, acc=0.536, ppl=9.4, lr=9.42e-03]\u001b[A\n",
"Training: 82%|████████▎ | 1650/2000 [10:21<02:01, 2.87it/s, loss=2.2399, acc=0.536, ppl=9.4, lr=9.42e-03]\u001b[A\n",
"Training: 82%|████████▎ | 1650/2000 [10:21<02:01, 2.87it/s, loss=2.1606, acc=0.552, ppl=8.7, lr=9.41e-03]\u001b[A\n",
"Training: 83%|████████▎ | 1660/2000 [10:24<02:00, 2.81it/s, loss=2.1606, acc=0.552, ppl=8.7, lr=9.41e-03]\u001b[A\n",
"Training: 83%|████████▎ | 1660/2000 [10:25<02:00, 2.81it/s, loss=2.1323, acc=0.554, ppl=8.4, lr=9.40e-03]\u001b[A\n",
"Training: 84%|████████▎ | 1670/2000 [10:28<01:54, 2.88it/s, loss=2.1323, acc=0.554, ppl=8.4, lr=9.40e-03]\u001b[A\n",
"Training: 84%|████████▎ | 1670/2000 [10:28<01:54, 2.88it/s, loss=2.2089, acc=0.534, ppl=9.1, lr=9.40e-03]\u001b[A\n",
"Training: 84%|████████▍ | 1680/2000 [10:31<01:53, 2.82it/s, loss=2.2089, acc=0.534, ppl=9.1, lr=9.40e-03]\u001b[A\n",
"Training: 84%|████████▍ | 1680/2000 [10:32<01:53, 2.82it/s, loss=2.2999, acc=0.521, ppl=10.0, lr=9.38e-03]\u001b[A\n",
"Training: 84%|████████▍ | 1690/2000 [10:35<01:47, 2.89it/s, loss=2.2999, acc=0.521, ppl=10.0, lr=9.38e-03]\u001b[A\n",
"Training: 84%|████████▍ | 1690/2000 [10:35<01:47, 2.89it/s, loss=2.2390, acc=0.537, ppl=9.4, lr=9.38e-03] \u001b[A\n",
"Training: 85%|████████▌ | 1700/2000 [10:38<01:46, 2.82it/s, loss=2.2390, acc=0.537, ppl=9.4, lr=9.38e-03]\u001b[A\n",
"Training: 85%|████████▌ | 1700/2000 [10:39<01:46, 2.82it/s, loss=2.2410, acc=0.531, ppl=9.4, lr=9.37e-03]\u001b[A\n",
"Training: 86%|████████▌ | 1710/2000 [10:42<01:40, 2.89it/s, loss=2.2410, acc=0.531, ppl=9.4, lr=9.37e-03]\u001b[A\n",
"Training: 86%|████████▌ | 1710/2000 [10:42<01:40, 2.89it/s, loss=2.1824, acc=0.546, ppl=8.9, lr=9.36e-03]\u001b[A\n",
"Training: 86%|████████▌ | 1720/2000 [10:45<01:39, 2.82it/s, loss=2.1824, acc=0.546, ppl=8.9, lr=9.36e-03]\u001b[A\n",
"Training: 86%|████████▌ | 1720/2000 [10:46<01:39, 2.82it/s, loss=2.1889, acc=0.546, ppl=8.9, lr=9.35e-03]\u001b[A\n",
"Training: 86%|████████▋ | 1730/2000 [10:49<01:33, 2.90it/s, loss=2.1889, acc=0.546, ppl=8.9, lr=9.35e-03]\u001b[A\n",
"Training: 86%|████████▋ | 1730/2000 [10:49<01:33, 2.90it/s, loss=2.0367, acc=0.575, ppl=7.7, lr=9.34e-03]\u001b[A\n",
"Training: 87%|████████▋ | 1740/2000 [10:52<01:31, 2.83it/s, loss=2.0367, acc=0.575, ppl=7.7, lr=9.34e-03]\u001b[A\n",
"Training: 87%|████████▋ | 1740/2000 [10:53<01:31, 2.83it/s, loss=2.1979, acc=0.535, ppl=9.0, lr=9.33e-03]\u001b[A\n",
"Training: 88%|████████▊ | 1750/2000 [10:56<01:26, 2.90it/s, loss=2.1979, acc=0.535, ppl=9.0, lr=9.33e-03]\u001b[A\n",
"Training: 88%|████████▊ | 1750/2000 [10:56<01:26, 2.90it/s, loss=1.9579, acc=0.585, ppl=7.1, lr=9.32e-03]\u001b[A\n",
"Training: 88%|████████▊ | 1760/2000 [10:59<01:24, 2.83it/s, loss=1.9579, acc=0.585, ppl=7.1, lr=9.32e-03]\u001b[A\n",
"Training: 88%|████████▊ | 1760/2000 [11:00<01:24, 2.83it/s, loss=2.1100, acc=0.549, ppl=8.2, lr=9.31e-03]\u001b[A\n",
"Training: 88%|████████▊ | 1770/2000 [11:03<01:19, 2.90it/s, loss=2.1100, acc=0.549, ppl=8.2, lr=9.31e-03]\u001b[A\n",
"Training: 88%|████████▊ | 1770/2000 [11:03<01:19, 2.90it/s, loss=1.9703, acc=0.582, ppl=7.2, lr=9.30e-03]\u001b[A\n",
"Training: 89%|████████▉ | 1780/2000 [11:06<01:17, 2.83it/s, loss=1.9703, acc=0.582, ppl=7.2, lr=9.30e-03]\u001b[A\n",
"Training: 89%|████████▉ | 1780/2000 [11:07<01:17, 2.83it/s, loss=1.9572, acc=0.583, ppl=7.1, lr=9.29e-03]\u001b[A\n",
"Training: 90%|████████▉ | 1790/2000 [11:10<01:12, 2.90it/s, loss=1.9572, acc=0.583, ppl=7.1, lr=9.29e-03]\u001b[A\n",
"Training: 90%|████████▉ | 1790/2000 [11:10<01:12, 2.90it/s, loss=2.0601, acc=0.559, ppl=7.8, lr=9.28e-03]\u001b[A\n",
"Training: 90%|█████████ | 1800/2000 [11:13<01:10, 2.83it/s, loss=2.0601, acc=0.559, ppl=7.8, lr=9.28e-03]\u001b[A\n",
"Training: 90%|█████████ | 1800/2000 [11:14<01:10, 2.83it/s, loss=1.8826, acc=0.596, ppl=6.6, lr=9.27e-03]\u001b[A\n",
"Training: 90%|█████████ | 1810/2000 [11:17<01:05, 2.90it/s, loss=1.8826, acc=0.596, ppl=6.6, lr=9.27e-03]\u001b[A\n",
"Training: 90%|█████████ | 1810/2000 [11:17<01:05, 2.90it/s, loss=2.0799, acc=0.556, ppl=8.0, lr=9.26e-03]\u001b[A\n",
"Training: 91%|█████████ | 1820/2000 [11:20<01:03, 2.83it/s, loss=2.0799, acc=0.556, ppl=8.0, lr=9.26e-03]\u001b[A\n",
"Training: 91%|█████████ | 1820/2000 [11:21<01:03, 2.83it/s, loss=1.9951, acc=0.570, ppl=7.4, lr=9.25e-03]\u001b[A\n",
"Training: 92%|█████████▏| 1830/2000 [11:23<00:58, 2.90it/s, loss=1.9951, acc=0.570, ppl=7.4, lr=9.25e-03]\u001b[A\n",
"Training: 92%|█████████▏| 1830/2000 [11:24<00:58, 2.90it/s, loss=1.8708, acc=0.596, ppl=6.5, lr=9.24e-03]\u001b[A\n",
"Training: 92%|█████████▏| 1840/2000 [11:27<00:56, 2.83it/s, loss=1.8708, acc=0.596, ppl=6.5, lr=9.24e-03]\u001b[A\n",
"Training: 92%|█████████▏| 1840/2000 [11:28<00:56, 2.83it/s, loss=1.9393, acc=0.585, ppl=7.0, lr=9.23e-03]\u001b[A\n",
"Training: 92%|█████████▎| 1850/2000 [11:30<00:51, 2.90it/s, loss=1.9393, acc=0.585, ppl=7.0, lr=9.23e-03]\u001b[A\n",
"Training: 92%|█████████▎| 1850/2000 [11:31<00:51, 2.90it/s, loss=2.1097, acc=0.556, ppl=8.2, lr=9.22e-03]\u001b[A\n",
"Training: 93%|█████████▎| 1860/2000 [11:34<00:49, 2.83it/s, loss=2.1097, acc=0.556, ppl=8.2, lr=9.22e-03]\u001b[A\n",
"Training: 93%|█████████▎| 1860/2000 [11:35<00:49, 2.83it/s, loss=1.9478, acc=0.588, ppl=7.0, lr=9.21e-03]\u001b[A\n",
"Training: 94%|█████████▎| 1870/2000 [11:37<00:44, 2.90it/s, loss=1.9478, acc=0.588, ppl=7.0, lr=9.21e-03]\u001b[A\n",
"Training: 94%|█████████▎| 1870/2000 [11:38<00:44, 2.90it/s, loss=2.0532, acc=0.566, ppl=7.8, lr=9.20e-03]\u001b[A\n",
"Training: 94%|█████████▍| 1880/2000 [11:41<00:42, 2.83it/s, loss=2.0532, acc=0.566, ppl=7.8, lr=9.20e-03]\u001b[A\n",
"Training: 94%|█████████▍| 1880/2000 [11:42<00:42, 2.83it/s, loss=1.8404, acc=0.597, ppl=6.3, lr=9.18e-03]\u001b[A\n",
"Training: 94%|█████████▍| 1890/2000 [11:45<00:38, 2.87it/s, loss=1.8404, acc=0.597, ppl=6.3, lr=9.18e-03]\u001b[A\n",
"Training: 94%|█████████▍| 1890/2000 [11:45<00:38, 2.87it/s, loss=1.9833, acc=0.567, ppl=7.3, lr=9.18e-03]\u001b[A\n",
"Training: 95%|█████████▌| 1900/2000 [11:48<00:35, 2.81it/s, loss=1.9833, acc=0.567, ppl=7.3, lr=9.18e-03]\u001b[A\n",
"Training: 95%|█████████▌| 1900/2000 [11:49<00:35, 2.81it/s, loss=1.9771, acc=0.579, ppl=7.2, lr=9.16e-03]\u001b[A\n",
"Training: 96%|█████████▌| 1910/2000 [11:52<00:31, 2.88it/s, loss=1.9771, acc=0.579, ppl=7.2, lr=9.16e-03]\u001b[A\n",
"Training: 96%|█████████▌| 1910/2000 [11:52<00:31, 2.88it/s, loss=1.8308, acc=0.601, ppl=6.2, lr=9.15e-03]\u001b[A\n",
"Training: 96%|█████████▌| 1920/2000 [11:55<00:28, 2.82it/s, loss=1.8308, acc=0.601, ppl=6.2, lr=9.15e-03]\u001b[A\n",
"Training: 96%|█████████▌| 1920/2000 [11:56<00:28, 2.82it/s, loss=1.7965, acc=0.604, ppl=6.0, lr=9.14e-03]\u001b[A\n",
"Training: 96%|█████████▋| 1930/2000 [11:59<00:24, 2.89it/s, loss=1.7965, acc=0.604, ppl=6.0, lr=9.14e-03]\u001b[A\n",
"Training: 96%|█████████▋| 1930/2000 [11:59<00:24, 2.89it/s, loss=1.8464, acc=0.593, ppl=6.3, lr=9.13e-03]\u001b[A\n",
"Training: 97%|█████████▋| 1940/2000 [12:02<00:21, 2.83it/s, loss=1.8464, acc=0.593, ppl=6.3, lr=9.13e-03]\u001b[A\n",
"Training: 97%|█████████▋| 1940/2000 [12:03<00:21, 2.83it/s, loss=1.7797, acc=0.608, ppl=5.9, lr=9.12e-03]\u001b[A\n",
"Training: 98%|█████████▊| 1950/2000 [12:05<00:17, 2.90it/s, loss=1.7797, acc=0.608, ppl=5.9, lr=9.12e-03]\u001b[A\n",
"Training: 98%|█████████▊| 1950/2000 [12:06<00:17, 2.90it/s, loss=1.8386, acc=0.599, ppl=6.3, lr=9.11e-03]\u001b[A\n",
"Training: 98%|█████████▊| 1960/2000 [12:09<00:14, 2.83it/s, loss=1.8386, acc=0.599, ppl=6.3, lr=9.11e-03]\u001b[A\n",
"Training: 98%|█████████▊| 1960/2000 [12:10<00:14, 2.83it/s, loss=1.7092, acc=0.618, ppl=5.5, lr=9.10e-03]\u001b[A\n",
"Training: 98%|█████████▊| 1970/2000 [12:12<00:10, 2.90it/s, loss=1.7092, acc=0.618, ppl=5.5, lr=9.10e-03]\u001b[A\n",
"Training: 98%|█████████▊| 1970/2000 [12:13<00:10, 2.90it/s, loss=1.8129, acc=0.598, ppl=6.1, lr=9.09e-03]\u001b[A\n",
"Training: 99%|█████████▉| 1980/2000 [12:16<00:07, 2.83it/s, loss=1.8129, acc=0.598, ppl=6.1, lr=9.09e-03]\u001b[A\n",
"Training: 99%|█████████▉| 1980/2000 [12:17<00:07, 2.83it/s, loss=1.8092, acc=0.600, ppl=6.1, lr=9.07e-03]\u001b[A\n",
"Training: 100%|█████████▉| 1990/2000 [12:19<00:03, 2.90it/s, loss=1.8092, acc=0.600, ppl=6.1, lr=9.07e-03]\u001b[A\n",
"Training: 100%|█████████▉| 1990/2000 [12:20<00:03, 2.90it/s, loss=1.8079, acc=0.604, ppl=6.1, lr=9.06e-03]\u001b[A\n",
"Training: 100%|██████████| 2000/2000 [12:23<00:00, 2.69it/s, loss=1.8079, acc=0.604, ppl=6.1, lr=9.06e-03]"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
" ⏱️ Training completed in 743.7 seconds\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
" 📊 Final - Loss: 1.1146, Acc: 0.7512, PPL: 3.05\n",
"💾 Saved final model to final_model.pt\n",
"\n",
"🎉 TRAINING COMPLETED!\n",
"⏱️ Total time: 12.6 minutes\n",
"🏆 Final Results:\n",
" Validation Loss: 1.1146\n",
" Validation Accuracy: 0.7512\n",
" Validation Perplexity: 3.05\n"
]
}
],
"source": [
"if __name__ == \"__main__\":\n",
"    # Check system\n",
"    print(f\"🔍 Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}\")\n",
"    if torch.cuda.is_available():\n",
"        print(f\"GPU: {torch.cuda.get_device_name()}\")\n",
"        print(f\"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n",
"\n",
"    # Set seed\n",
"    set_seed(42)\n",
"\n",
"    # Create config for Small model\n",
"    config = ModelConfig()\n",
"    print(f\"\\n📋 Model Configuration:\")\n",
"    print(f\" Architecture: {config.d_model}d, {config.n_layers}L, {config.n_heads}H, {config.d_ff}ff\")\n",
"    print(f\" Training: {config.max_steps} steps, batch size {config.batch_size}\")\n",
"    print(f\" Data: {config.max_tokens:,} tokens, seq_len {config.max_seq_len}\")\n",
"\n",
"    # Load data\n",
"    texts, tokenizer, tokens = load_and_cache_data(config)\n",
"    dataset = TextTokenDataset(tokens, config.max_seq_len)\n",
"\n",
"    # Train/val split\n",
"    val_size = len(dataset) // 10\n",
"    train_size = len(dataset) - val_size\n",
"    train_dataset, val_dataset = torch.utils.data.random_split(\n",
"        dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)\n",
"    )\n",
"\n",
"    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=2)\n",
"    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=2)\n",
"\n",
"    print(f\"📊 Dataset: {len(train_dataset)} train, {len(val_dataset)} val samples\")\n",
"\n",
"    # Train model\n",
"    start_time = time.time()\n",
"    model, final_metrics = train_model(config, train_loader, val_loader)\n",
"    total_time = time.time() - start_time\n",
"\n",
"    print(f\"\\n🎉 TRAINING COMPLETED!\")\n",
"    print(f\"⏱️ Total time: {total_time/60:.1f} minutes\")\n",
"    print(f\"🏆 Final Results:\")\n",
"    print(f\" Validation Loss: {final_metrics['val_loss']:.4f}\")\n",
"    print(f\" Validation Accuracy: {final_metrics['val_accuracy']:.4f}\")\n",
"    print(f\" Validation Perplexity: {final_metrics['val_perplexity']:.2f}\")"
]
},
{
"cell_type": "markdown",
"source": [
"## 17. Model Loading and Inference\n",
"\n",
"After training, we can load the saved checkpoint (`final_model.pt`) and use the model for text generation.\n",
"The cells below rebuild the model from the stored config, restore its trained weights, and sample new text from a prompt."
],
"metadata": {
"id": "mcHNleswBtA0"
}
},
{
"cell_type": "code",
"source": [
"def load_trained_model(model_path: str = \"final_model.pt\"):\n",
"    \"\"\"Load a trained model from checkpoint\"\"\"\n",
"    print(f\" Loading model from {model_path}\")\n",
"\n",
"    # Allow-list ModelConfig so torch.load can unpickle it with weights_only=True,\n",
"    # which is the default since PyTorch 2.6\n",
"    from torch.serialization import add_safe_globals\n",
"    add_safe_globals([ModelConfig])\n",
"\n",
"    try:\n",
"        checkpoint = torch.load(model_path, map_location='cpu')\n",
"        config = checkpoint['config']\n",
"    except Exception as e:\n",
"        # Fall back to full unpickling; only do this for checkpoints you trust\n",
"        print(f\"⚠️ Loading with weights_only=True failed ({type(e).__name__}), retrying with weights_only=False...\")\n",
"        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)\n",
"        config = checkpoint['config']\n",
"\n",
"    # Create model with same config\n",
"    model = MinimalLLM(config)\n",
"    model.load_state_dict(checkpoint['model_state_dict'])\n",
"\n",
"    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"    model = model.to(device)\n",
"    model.eval()\n",
"\n",
"    print(f\"✅ Model loaded successfully\")\n",
"    print(f\" Parameters: {sum(p.numel() for p in model.parameters()):,}\")\n",
"    print(f\" Device: {device}\")\n",
"\n",
"    return model, config"
],
"metadata": {
"id": "jkJxa1P6Bz2e"
},
"execution_count": null,
"outputs": []
},
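{
"cell_type": "markdown",
"source": [
"A quick way to sanity-check the checkpoint layout that `load_trained_model` relies on (a dict holding at least `'model_state_dict'` and `'config'`) is a save/load round trip on a toy model. This is a minimal sketch under stated assumptions: `ToyConfig`, the `nn.Linear` model and `toy_model.pt` are hypothetical stand-ins for `ModelConfig`, `MinimalLLM` and `final_model.pt`, and PyTorch 1.13 or newer is assumed for the `weights_only` argument. The `toy_` prefixes keep the cell from overwriting the notebook's own `model` and `config`. Because `weights_only=False` unpickles arbitrary Python objects, use it only on checkpoints you trust."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Hypothetical round trip: ToyConfig and nn.Linear stand in for ModelConfig and MinimalLLM\n",
"import os\n",
"import tempfile\n",
"from dataclasses import dataclass\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"@dataclass\n",
"class ToyConfig:\n",
"    d_in: int = 4\n",
"    d_out: int = 2\n",
"\n",
"\n",
"toy_config = ToyConfig()\n",
"toy_model = nn.Linear(toy_config.d_in, toy_config.d_out)\n",
"\n",
"# Save the same keys that load_trained_model() reads back\n",
"toy_path = os.path.join(tempfile.mkdtemp(), 'toy_model.pt')\n",
"torch.save({'model_state_dict': toy_model.state_dict(), 'config': toy_config}, toy_path)\n",
"\n",
"# On PyTorch 2.6+ (weights_only=True by default) the dataclass config needs\n",
"# weights_only=False or an add_safe_globals() allow-list to be unpickled\n",
"toy_checkpoint = torch.load(toy_path, map_location='cpu', weights_only=False)\n",
"restored = nn.Linear(toy_checkpoint['config'].d_in, toy_checkpoint['config'].d_out)\n",
"restored.load_state_dict(toy_checkpoint['model_state_dict'])\n",
"\n",
"print(torch.equal(toy_model.weight, restored.weight))  # True\n",
"print(toy_checkpoint['config'] == toy_config)  # True"
],
"metadata": {},
"execution_count": null,
"outputs": []
},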
{
"cell_type": "markdown",
"source": [
"### 🔮 Text Generation Function\n",
"\n",
"This function generates text with a trained language model. Given a prompt, it tokenizes the input and autoregressively samples up to `max_length` new tokens, one token per step. It supports:\n",
"\n",
"- **Temperature scaling** to control randomness: the logits are divided by the temperature before sampling, so `0.7` makes output more focused and `1.5` makes it more random.\n",
"- **Top-k sampling** to limit candidates to the `k` most likely tokens, e.g., `top_k=50` narrows sampling to the 50 highest-probability tokens.\n",
"- **Top-p (nucleus) sampling** to sample from the smallest set of most likely tokens whose cumulative probability exceeds `p` (e.g., `top_p=0.9`, i.e. 90%).\n",
"\n",
"Generation stops early if the EOS token is produced.\n"
],
"metadata": {
"id": "UoFBxnhXCNPD"
}
},
{
"cell_type": "code",
"source": [
"def generate_text(model: nn.Module, tokenizer, prompt: str, max_length: int = 100,\n",
"                  temperature: float = 0.8, top_k: int = 50, top_p: float = 0.9):\n",
"    \"\"\"Generate up to max_length new tokens continuing the prompt\"\"\"\n",
"    model.eval()\n",
"    device = next(model.parameters()).device\n",
"\n",
"    # Tokenize prompt\n",
"    input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors='pt').to(device)\n",
"\n",
"    generated_ids = input_ids.clone()\n",
"\n",
"    with torch.no_grad():\n",
"        for _ in range(max_length):\n",
"            # Get model predictions for the last position and apply temperature\n",
"            logits = model(generated_ids)\n",
"            next_token_logits = logits[0, -1, :] / temperature\n",
"\n",
"            # Apply top-k filtering: keep only the k highest logits\n",
"            if top_k > 0:\n",
"                k = min(top_k, next_token_logits.size(-1))  # torch.topk raises if k > vocab size\n",
"                top_k_logits, top_k_indices = torch.topk(next_token_logits, k)\n",
"                next_token_logits = torch.full_like(next_token_logits, float('-inf'))\n",
"                next_token_logits[top_k_indices] = top_k_logits\n",
"\n",
"            # Apply top-p (nucleus) filtering\n",
"            if top_p < 1.0:\n",
"                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)\n",
"                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n",
"                sorted_indices_to_remove = cumulative_probs > top_p\n",
"                # Shift right so the token that pushes the sum past top_p is kept\n",
"                sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n",
"                sorted_indices_to_remove[0] = False  # always keep the most likely token\n",
"                indices_to_remove = sorted_indices[sorted_indices_to_remove]\n",
"                next_token_logits[indices_to_remove] = float('-inf')\n",
"\n",
"            # Sample next token\n",
"            probs = F.softmax(next_token_logits, dim=-1)\n",
"            next_token = torch.multinomial(probs, num_samples=1)  # shape: (1,)\n",
"\n",
"            # Add a batch dimension, (1,) -> (1, 1), so it can be appended along dim=1\n",
"            next_token = next_token.unsqueeze(0)\n",
"            generated_ids = torch.cat([generated_ids, next_token], dim=1)\n",
"\n",
"            # Stop if we reach the end token\n",
"            if next_token.item() == tokenizer.eos_token_id:\n",
"                break\n",
"\n",
"    # Decode the generated text\n",
"    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
"    return generated_text"
],
"metadata": {
"id": "SPJIgtbIB0_W"
},
"execution_count": null,
"outputs": []
},
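{
"cell_type": "markdown",
"source": [
"The temperature, top-k and top-p steps in `generate_text` can be checked in isolation on a toy distribution. The sketch below uses a hypothetical standalone helper, `filter_logits`, that repeats the same three steps on a 1-D logits tensor; the 5-token vocabulary and its probabilities (0.5, 0.2, 0.15, 0.1, 0.05) are made up for illustration. With these numbers, `top_k=3` and `top_p=0.8` keep the same three tokens (the cumulative probability 0.85 is the first to exceed 0.8), while `temperature=0.5` sharpens the distribution toward the most likely token."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Hypothetical helper mirroring the filtering steps in generate_text()\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"def filter_logits(logits, temperature=1.0, top_k=0, top_p=1.0):\n",
"    logits = logits / temperature  # < 1 sharpens, > 1 flattens the softmax\n",
"\n",
"    if top_k > 0:\n",
"        k = min(top_k, logits.size(-1))\n",
"        top_k_logits, top_k_indices = torch.topk(logits, k)\n",
"        logits = torch.full_like(logits, float('-inf'))\n",
"        logits[top_k_indices] = top_k_logits\n",
"\n",
"    if top_p < 1.0:\n",
"        sorted_logits, sorted_indices = torch.sort(logits, descending=True)\n",
"        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n",
"        to_remove = cumulative_probs > top_p\n",
"        to_remove[1:] = to_remove[:-1].clone()  # keep the token that crosses top_p\n",
"        to_remove[0] = False  # always keep the most likely token\n",
"        logits[sorted_indices[to_remove]] = float('-inf')\n",
"\n",
"    return logits\n",
"\n",
"\n",
"# Toy 5-token vocabulary: log-probabilities act as logits\n",
"toy_logits = torch.log(torch.tensor([0.5, 0.2, 0.15, 0.1, 0.05]))\n",
"\n",
"print(F.softmax(filter_logits(toy_logits, top_k=3), dim=-1))  # last two tokens get 0\n",
"print(F.softmax(filter_logits(toy_logits, top_p=0.8), dim=-1))  # last two tokens get 0\n",
"print(F.softmax(filter_logits(toy_logits, temperature=0.5), dim=-1))  # mass shifts to token 0"
],
"metadata": {},
"execution_count": null,
"outputs": []
},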
{
"cell_type": "code",
"source": [
"def interactive_inference(model_path: str = \"final_model.pt\"):\n",
"    \"\"\"Interactive inference session\"\"\"\n",
"    print(\"🤖 Starting interactive inference session\")\n",
"    print(\"Type 'quit' to exit\")\n",
"\n",
"    # Load model and tokenizer\n",
"    model, config = load_trained_model(model_path)\n",
"\n",
"    # Load tokenizer (assuming we have the same one used during training)\n",
"    tokenizer = AutoTokenizer.from_pretrained(\"HuggingFaceTB/SmolLM-135M\")\n",
"    if tokenizer.pad_token is None:\n",
"        tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
"    while True:\n",
"        try:\n",
"            prompt = input(\"\\n Enter your prompt: \")\n",
"            if prompt.lower() in ['quit', 'exit', 'q']:\n",
"                print(\"👋 Goodbye!\")\n",
"                break\n",
"\n",
"            if not prompt.strip():\n",
"                continue\n",
"\n",
"            print(\"🔄 Generating...\")\n",
"            generated_text = generate_text(\n",
"                model, tokenizer, prompt,\n",
"                max_length=150,\n",
"                temperature=0.8,\n",
"                top_k=50,\n",
"                top_p=0.9\n",
"            )\n",
"\n",
"            print(f\"\\n Generated text:\")\n",
"            print(f\"📝 {generated_text}\")\n",
"\n",
"        except KeyboardInterrupt:\n",
"            print(\"\\n👋 Goodbye!\")\n",
"            break\n",
"        except Exception as e:\n",
"            print(f\"❌ Error: {e}\")"
],
"metadata": {
"id": "TIxiUUKbB3F2"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def demo_inference(model_path: str = \"final_model.pt\"):\n",
"    \"\"\"Run a quick demo of the model's capabilities\"\"\"\n",
"    print(\"🎭 Running inference demo\")\n",
"\n",
"    # Load model and tokenizer\n",
"    model, config = load_trained_model(model_path)\n",
"    tokenizer = AutoTokenizer.from_pretrained(\"HuggingFaceTB/SmolLM-135M\")\n",
"    if tokenizer.pad_token is None:\n",
"        tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
"    # Demo prompts\n",
"    demo_prompts = [\n",
"        \"The future of artificial intelligence\",\n",
"        \"Once upon a time in a distant galaxy\",\n",
"        \"The most important thing to remember is\",\n",
"        \"In the year 2050, technology will\",\n",
"        \"The best way to learn programming is\"\n",
"    ]\n",
"\n",
"    for i, prompt in enumerate(demo_prompts, 1):\n",
"        print(f\"\\n Demo {i}: '{prompt}'\")\n",
"        print(\"-\" * 50)\n",
"\n",
"        generated_text = generate_text(\n",
"            model, tokenizer, prompt,\n",
"            max_length=100,\n",
"            temperature=0.7,\n",
"            top_k=40,\n",
"            top_p=0.85\n",
"        )\n",
"\n",
"        print(f\"📝 {generated_text}\")\n",
"        print()"
],
"metadata": {
"id": "F57LSWDkB44U"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"if __name__ == \"__main__\":\n",
"    # Check if we have a trained model\n",
"    import os\n",
"\n",
"    if os.path.exists(\"final_model.pt\"):\n",
"        print(\"🎉 Found trained model! Running demo...\")\n",
"        demo_inference(\"final_model.pt\")\n",
"\n",
"        # Optionally run interactive session\n",
"        response = input(\"\\n🤖 Would you like to try interactive inference? (y/n): \")\n",
"        if response.lower() in ['y', 'yes']:\n",
"            interactive_inference(\"final_model.pt\")\n",
"    else:\n",
"        print(\"⚠️ No trained model found. Please run the training cells first.\")\n",
"        print(\"💡 Look for 'final_model.pt' or 'best_model.pt' in your directory.\")"
],
"metadata": {
"id": "oTTx359KB6W_",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "599c1f86-1720-4c03-c1a5-6cca5649e3e8"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"🎉 Found trained model! Running demo...\n",
"🎭 Running inference demo\n",
" Loading model from final_model.pt\n",
"✅ Model loaded successfully\n",
" Parameters: 32,150,976\n",
" Device: cuda\n",
"\n",
" Demo 1: 'The future of artificial intelligence'\n",
"--------------------------------------------------\n",
"📝 The future of artificial intelligence, which deals with different variables. For instance, if a case might sound sound sound like the power of data needs to different variables. Similarly, these algorithms often used to generate data needs and problem-quality data.\n",
"\n",
"Another crucial aspect of data is designing data. For example, QoS ensures no single points like a QFD and video packets detailing how much power instead of information. How do they feel like a reward system creates a class, if you could lead to a solution for everyone can\n",
"\n",
"\n",
" Demo 2: 'Once upon a time in a distant galaxy'\n",
"--------------------------------------------------\n",
"📝 Once upon a time in a distant galaxy National Park, I was one day, and I was excited to be related to be fun but I was using a few years of the power of the Ottoman Empire.\n",
"The Power of the Ottoman Empire. I was indeed be considered the throne as I had always been shown to be taken quite the Ottoman Empire.\n",
"The Ottoman Empire. I decided to be be considered the Ottoman Empire. I was not have been signed the Ottoman Empire.\n",
"\n",
"The Ottoman Empire. I was indeed a line of the\n",
"\n",
"\n",
" Demo 3: 'The most important thing to remember is'\n",
"--------------------------------------------------\n",
"📝 The most important thing to remember is especially when you want to keep your own language.\n",
"\n",
"**How does CMC refers to a foreign language, and power! This means setting up. It teaches kids who learn about ways that everyone has access to different cultures, and sharing resources.\n",
"\n",
"**How Does Government:**\n",
"\n",
"Even though they could lead to better understand different religions can affect their interests, and ways. By doing so, we can work together.\n",
"\n",
"\n",
"Another important aspect of the globe. When encountering various forms\n",
"\n",
"\n",
" Demo 4: 'In the year 2050, technology will'\n",
"--------------------------------------------------\n",
"📝 In the year 2050, technology will be understood that every single event during recess when there was a few years of the United States. During this period, the Battle of New England, a British Empire had been cities like 17770s and were no 180s when Mary was still ruled by 191921700s when she was still divided into battle. Her mother whose son had been debate among all sorts of her career. She asked her mother and saw a renowned poet who\n",
"\n",
"\n",
" Demo 5: 'The best way to learn programming is'\n",
"--------------------------------------------------\n",
"📝 The best way to learn programming is especially when working with others.\n",
"\n",
"**Exploring Gourds**\n",
"\n",
"Now let's talk about something called \"font\" that you know what a \"font\" that that means that is. A font has a specific words and accessories, such as a way of a common design. It can also include letters, but can also affect a common forms, and shapes, objects.\n",
"\n",
"**Why is a character, characters that create a font has its core, and characters that can also include\n",
"\n",
"\n",
"🤖 Would you like to try interactive inference? (y/n): y\n",
"🤖 Starting interactive inference session\n",
"Type 'quit' to exit\n",
" Loading model from final_model.pt\n",
"✅ Model loaded successfully\n",
" Parameters: 32,150,976\n",
" Device: cuda\n",
"\n",
" Enter your prompt: the best LLM architecture is\n",
"🔄 Generating...\n",
"\n",
" Generated text:\n",
"📝 the best LLM architecture is essential in daily life. Today, we will learn about some fundamental concepts and how they are looking at work.\n",
"\n",
"Section 1: What Is Theatre?\n",
"Beyond practicing practicing hand at local laws are three real-1. **Bored**: The Impact on non-American Human Rights**\n",
"\n",
"To understand the Criminal Justice** is essential for justice. This involves studying physical activity is crucial for everyone, and respond to understand themselves reflected in shaping shaping their views.\n",
"\n",
"Now, let's discuss ways we understand what we do this fascinating world to understand what we mean by looking at home?\n",
"** is a *zone*. Mani refers to the country's likelihood of the poor rights, and the environment, including the environment. When someone\n",
"\n",
" Enter your prompt: today I will have a walk and\n",
"🔄 Generating...\n",
"\n",
" Generated text:\n",
"📝 today I will have a walk and discuss the following strict rules of the following command:\n",
"\n",
"### 1. Ik-Contra affair happened during the law rather than anything else, it might be made for those positions of authority to be considered for everyone.\n",
"\n",
"The invasion in 1980s. This situation led them apart the Soviet Union from the USSR; however, the EU EU EU EU EU EU EU EU.\n",
"The term 'Warriors Three.'s, they form one of the government—like a more than just like the free markets filled with the foreign lands, and skills. As they rely on attendees, they pass by their target program.\n",
"\n",
"### Counter Counterfactual History\n",
"There are several distinct category of English as special offer numerous\n",
"\n",
" Enter your prompt: The future of cars is\n",
"🔄 Generating...\n",
"\n",
" Generated text:\n",
"📝 The future of cars is a solution that helps us consider the previous experiences of our own way. Let's learn about how we can help us build things and learn about what it means.\n",
"\n",
"What does \"A good (like learning), I mean?\n",
"C)?\n",
"----------------------------------------------------\n",
"\n",
"1. **Health Information Technology:** If a new ones include things you want to go.\n",
"2. **Jins**: It helps us improve our body, just like a new friends, they are complex systems. They help us, you know what you want to build something new friends, making sure if one of the environment, kids may take care of living beings.\n",
"3. **Jazz** refers to the ground, ragtime, often used to the park. Both are\n"
]
}
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "Ion2bUGNB-As"
},
"execution_count": null,
"outputs": []
}
],
"metadata": {
"language_info": {
"name": "python"
},
"colab": {
"provenance": [],
"gpuType": "T4",
"collapsed_sections": [
"ZH1RhXtTSyvG"
],
"include_colab_link": true
},
"accelerator": "GPU",
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}