Last active
May 5, 2025 18:21
-
-
Save abpai/6c1dfdc45f77238528354e53fab3be6c to your computer and use it in GitHub Desktop.
twotower-infonce.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyPkQpqLKoUJ30YZ99M23YIA", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/abpai/6c1dfdc45f77238528354e53fab3be6c/twotower-infonce.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Pinned to the structlog version this notebook was last run with.\n", | |
"%pip install -q structlog==25.3.0" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"collapsed": true, | |
"id": "NV17kMksXa11", | |
"outputId": "c0eb24e9-168e-4f14-b5ab-f8c12d6d121c" | |
}, | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Requirement already satisfied: structlog in /usr/local/lib/python3.11/dist-packages (25.3.0)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 953 | |
}, | |
"id": "EbSrcQuCXPuB", | |
"outputId": "32175964-d487-46b8-b10c-5cd5a5a553db" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"2025-05-05 18:14:05 [info ] Using device device=device(type='cpu')\n", | |
"2025-05-05 18:14:05 [info ] Building Synthetic Data\n", | |
"2025-05-05 18:14:05 [info ] Creating DataLoader\n", | |
"2025-05-05 18:14:05 [info ] Initializing Model and Optimizer\n", | |
"2025-05-05 18:14:05 [info ] Starting Training\n", | |
"2025-05-05 18:14:05 [info ] epoch_summary epoch=1 loss=2.4593790285289288 lr=0.0003 time=0.052422285079956055\n", | |
"2025-05-05 18:14:05 [info ] epoch_summary epoch=2 loss=2.3996245190501213 lr=0.0003 time=0.05826759338378906\n", | |
"2025-05-05 18:14:06 [info ] epoch_summary epoch=3 loss=2.3641757145524025 lr=0.0003 time=0.05289435386657715\n", | |
"2025-05-05 18:14:06 [info ] epoch_summary epoch=4 loss=2.3509834930300713 lr=0.0003 time=0.05392646789550781\n", | |
"2025-05-05 18:14:06 [info ] epoch_summary epoch=5 loss=2.335583135485649 lr=0.0003 time=0.058527231216430664\n", | |
"2025-05-05 18:14:06 [info ] epoch_summary epoch=6 loss=2.3280965611338615 lr=0.0003 time=0.05180931091308594\n", | |
"2025-05-05 18:14:06 [info ] epoch_summary epoch=7 loss=2.3198156468570232 lr=0.0003 time=0.0527191162109375\n", | |
"2025-05-05 18:14:06 [info ] epoch_summary epoch=8 loss=2.3186416160315275 lr=0.0003 time=0.05267667770385742\n", | |
"2025-05-05 18:14:06 [info ] epoch_summary epoch=9 loss=2.3187887631356716 lr=0.0003 time=0.05245041847229004\n", | |
"2025-05-05 18:14:06 [info ] epoch_summary epoch=10 loss=2.307461339980364 lr=0.0003 time=0.05702614784240723\n", | |
"2025-05-05 18:14:06 [info ] Training Finished total_training_time=0.5518066883087158\n", | |
"2025-05-05 18:14:06 [info ] Running Diagnostics\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 1200x500 with 2 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"2025-05-05 18:14:06 [info ] \n", | |
"--- Calculating Recall@10 ---\n", | |
"2025-05-05 18:14:06 [info ] Encoding 500 documents...\n", | |
"2025-05-05 18:14:06 [info ] Document encoding took 0.00s\n", | |
"2025-05-05 18:14:06 [info ] Performing search for 500 queries...\n", | |
"2025-05-05 18:14:06 [info ] Search took 0.01s\n", | |
"2025-05-05 18:14:06 [info ] \n", | |
"Recall@10 (full corpus): 58.60%\n" | |
] | |
} | |
], | |
"source": [ | |
"\"\"\"\n", | |
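"Minimal ranking demo with InfoNCE loss\n", | |
"--------------------------------------\n", | |
"\n", | |
"• 500 toy queries, each paired with exactly one positive doc (corpus of 500 docs)\n", | |
"• Two-tower encoder (towers share weights): Embedding → mean-pool → Linear → L2-norm\n", | |
"• Trains with InfoNCE: CrossEntropyLoss over in-batch negatives\n", | |
"  (every other doc in the mini-batch serves as a negative)\n", | |
"• Evaluates Recall@k over the full corpus (1 relevant doc per query)\n", | |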
"\"\"\"\n", | |
"\n", | |
"import time\n", | |
"\n", | |
"import matplotlib.pyplot as plt\n", | |
"import structlog\n", | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"from torch.nn.functional import normalize\n", | |
"from torch.utils.data import DataLoader, Dataset\n", | |
"\n", | |
"logger = structlog.get_logger(__name__)\n", | |
"\n", | |
"# Use CUDA when torch can see a GPU, otherwise fall back to the CPU.\n", | |
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n", | |
"logger.info('Using device', device=device)\n", | |
"\n", | |
"\n", | |
"# ---------------------------------------------------------------------\n", | |
"# 1. Synthetic data (N queries == N docs == corpus size)\n", | |
"# ---------------------------------------------------------------------\n", | |
"def build_synthetic(\n", | |
"    n_queries=500, vocab=100, q_len=4, doc_len=16, overlap=0.6, seed=42\n", | |
"):\n", | |
"    \"\"\"Generate random query / positive-doc token pairs.\n", | |
"\n", | |
"    Doc i starts with ``int(overlap * q_len)`` tokens copied (without\n", | |
"    replacement) from query i; the remaining doc tokens are uniform noise.\n", | |
"    All randomness comes from a private generator seeded with ``seed``, so\n", | |
"    the output is reproducible and does not touch torch's global RNG.\n", | |
"\n", | |
"    Args:\n", | |
"        n_queries: number of (query, doc) pairs; doc i is the only relevant\n", | |
"            doc for query i.\n", | |
"        vocab: token ids are drawn from [0, vocab).\n", | |
"        q_len: tokens per query.\n", | |
"        doc_len: tokens per doc.\n", | |
"        overlap: fraction of query tokens copied into the paired doc, in [0, 1].\n", | |
"        seed: seed for the private torch.Generator.\n", | |
"\n", | |
"    Returns:\n", | |
"        (queries, docs_pos): int64 tensors of shape (n_queries, q_len) and\n", | |
"        (n_queries, doc_len).\n", | |
"\n", | |
"    Raises:\n", | |
"        ValueError: if overlap is outside [0, 1] or more tokens would be\n", | |
"            copied than a doc can hold.\n", | |
"    \"\"\"\n", | |
"    if not 0.0 <= overlap <= 1.0:\n", | |
"        raise ValueError(f'overlap must be in [0, 1], got {overlap}')\n", | |
"    # Loop-invariant: how many query tokens are planted in each doc.\n", | |
"    num_copy = int(overlap * q_len)\n", | |
"    if num_copy > doc_len:\n", | |
"        raise ValueError(\n", | |
"            f'cannot copy {num_copy} query tokens into a doc of length {doc_len}'\n", | |
"        )\n", | |
"\n", | |
"    rng = torch.Generator().manual_seed(seed)\n", | |
"    queries = torch.randint(0, vocab, (n_queries, q_len), generator=rng)\n", | |
"    docs_pos = torch.randint(0, vocab, (n_queries, doc_len), generator=rng)\n", | |
"\n", | |
"    for i in range(n_queries):\n", | |
"        # One randperm per query (even when num_copy == 0) keeps the RNG\n", | |
"        # stream, and therefore the generated data, identical to before.\n", | |
"        copy_idx = torch.randperm(q_len, generator=rng)[:num_copy]\n", | |
"        docs_pos[i, :num_copy] = queries[i, copy_idx]\n", | |
"\n", | |
"    return queries, docs_pos\n", | |
"\n", | |
"\n", | |
"# ---------------------------------------------------------------------\n", | |
"# 2. Dataset (pairs: q | pos; negatives come from the rest of the batch)\n", | |
"# ---------------------------------------------------------------------\n", | |
"class TripleDS(Dataset):\n", | |
"    \"\"\"Aligned (query, positive doc) pairs: row i of q matches row i of p.\n", | |
"\n", | |
"    Despite the name, no explicit negative is stored; the InfoNCE training\n", | |
"    loop treats the other docs in the same mini-batch as negatives.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, q, p):\n", | |
"        # Fail fast instead of raising IndexError mid-epoch.\n", | |
"        if len(q) != len(p):\n", | |
"            raise ValueError(\n", | |
"                f'q and p must have the same length, got {len(q)} and {len(p)}'\n", | |
"            )\n", | |
"        self.q, self.p = q, p\n", | |
"\n", | |
"    def __len__(self):\n", | |
"        return len(self.q)\n", | |
"\n", | |
"    def __getitem__(self, i):\n", | |
"        return {'q': self.q[i], 'd_pos': self.p[i]}\n", | |
"\n", | |
"\n", | |
"def collate(batch):\n", | |
"    \"\"\"Stack a list of sample dicts key-wise and move each tensor to `device`.\n", | |
"\n", | |
"    Uses the module-level `device`. Moving tensors to the device inside\n", | |
"    collate assumes the DataLoader runs with the default num_workers=0.\n", | |
"    \"\"\"\n", | |
"    return {\n", | |
"        key: torch.stack([item[key] for item in batch]).to(device)\n", | |
"        for key in batch[0]\n", | |
"    }\n", | |
"\n", | |
"\n", | |
"# ---------------------------------------------------------------------\n", | |
"# 3. Two-tower model\n", | |
"# ---------------------------------------------------------------------\n", | |
"class TwoTower(nn.Module):\n", | |
"    \"\"\"Shared-weight two-tower encoder for queries and documents.\n", | |
"\n", | |
"    Both towers run the same layers: token ids -> Embedding -> mean over the\n", | |
"    token axis -> bias-free Linear projection -> L2 normalisation. Because the\n", | |
"    outputs are unit vectors, their dot product is a cosine similarity.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, vocab, emb_dim=36, proj_dim=18):\n", | |
"        super().__init__()\n", | |
"        self.embedding = nn.Embedding(vocab, emb_dim)\n", | |
"        self.proj = nn.Linear(emb_dim, proj_dim, bias=False)\n", | |
"\n", | |
"    def encode(self, toks):\n", | |
"        \"\"\"Map a (B, L) batch of token ids to (B, proj_dim) unit vectors.\"\"\"\n", | |
"        # Follow the model to whatever device its weights currently live on.\n", | |
"        param_device = next(self.parameters()).device\n", | |
"        pooled = self.embedding(toks.to(param_device)).mean(dim=1)  # (B, emb_dim)\n", | |
"        projected = self.proj(pooled)  # (B, proj_dim)\n", | |
"        return normalize(projected, dim=-1)\n", | |
"\n", | |
"    def forward(self, q, d):\n", | |
"        \"\"\"Row-wise cosine similarity between query i and doc i.\"\"\"\n", | |
"        q_unit = self.encode(q)\n", | |
"        d_unit = self.encode(d)\n", | |
"        return torch.sum(q_unit * d_unit, dim=-1)\n", | |
"\n", | |
"\n", | |
"# ---------------------------------------------------------------------\n", | |
"# 4. Train\n", | |
"# ---------------------------------------------------------------------\n", | |
"\n", | |
"# --- Hyperparameters ---\n", | |
"n_queries = 500\n", | |
"vocab = 50\n", | |
"q_len = 16\n", | |
"doc_len = 16  # tokens per positive doc; must be >= int(overlap * q_len)\n", | |
"overlap = 0.8\n", | |
"seed = 1337\n", | |
"batch_size = 16\n", | |
"# InfoNCE temperature. Logits are cosine similarities in [-1, 1]; without\n", | |
"# scaling, the softmax over in-batch candidates stays nearly flat and the\n", | |
"# loss has a high floor. Values around 0.05-0.1 are common for unit vectors.\n", | |
"temperature = 0.05\n", | |
"lr = 3e-4\n", | |
"epochs = 10\n", | |
"emb_dim = 48\n", | |
"proj_dim = 72\n", | |
"\n", | |
"# Seed torch's global RNG so model initialisation and DataLoader shuffling\n", | |
"# are reproducible; build_synthetic uses its own seeded generator.\n", | |
"torch.manual_seed(seed)\n", | |
"\n", | |
"logger.info('Building Synthetic Data')\n", | |
"queries, docs_pos = build_synthetic(n_queries, vocab, q_len, doc_len, overlap, seed)\n", | |
"\n", | |
"logger.info('Creating DataLoader')\n", | |
"train_dataset = TripleDS(queries, docs_pos)\n", | |
"loader = DataLoader(\n", | |
"    train_dataset,\n", | |
"    batch_size=batch_size,\n", | |
"    shuffle=True,\n", | |
"    collate_fn=collate,  # to handle device placement\n", | |
")\n", | |
"\n", | |
"logger.info('Initializing Model and Optimizer')\n", | |
"model = TwoTower(vocab=vocab, emb_dim=emb_dim, proj_dim=proj_dim).to(device)\n", | |
"opt = torch.optim.AdamW(model.parameters(), lr=lr)\n", | |
"loss_fn = nn.CrossEntropyLoss()\n", | |
"\n", | |
"# Per-epoch metrics, collected for the diagnostic plots.\n", | |
"loss_history = []\n", | |
"lr_history = []\n", | |
"epoch_times = []\n", | |
"\n", | |
"logger.info('Starting Training')\n", | |
"training_start_time = time.time()\n", | |
"\n", | |
"for epoch in range(epochs):\n", | |
"    epoch_start_time = time.time()\n", | |
"    model.train()\n", | |
"    total_loss = 0.0\n", | |
"\n", | |
"    # No LR scheduler is attached, so this stays constant; recorded for the plot.\n", | |
"    current_lr = opt.param_groups[0]['lr']\n", | |
"    lr_history.append(current_lr)\n", | |
"\n", | |
"    for batch in loader:\n", | |
"        q_vec = model.encode(batch['q'])\n", | |
"        d_vec = model.encode(batch['d_pos'])\n", | |
"\n", | |
"        # In-batch negatives: query i's positive is doc i (the diagonal);\n", | |
"        # every other doc in the batch is a negative for it.\n", | |
"        logits = (q_vec @ d_vec.T) / temperature\n", | |
"        labels = torch.arange(logits.size(0), device=logits.device)\n", | |
"        loss = loss_fn(logits, labels)\n", | |
"\n", | |
"        opt.zero_grad()\n", | |
"        loss.backward()\n", | |
"        opt.step()\n", | |
"\n", | |
"        total_loss += loss.item()\n", | |
"\n", | |
"    # Mean of per-batch losses (a smaller final batch, if any, counts as a full one).\n", | |
"    avg_loss = total_loss / len(loader)\n", | |
"    loss_history.append(avg_loss)\n", | |
"\n", | |
"    epoch_end_time = time.time()\n", | |
"    epoch_duration = epoch_end_time - epoch_start_time\n", | |
"    epoch_times.append(epoch_duration)\n", | |
"\n", | |
"    logger.info(\n", | |
"        'epoch_summary',\n", | |
"        epoch=epoch + 1,\n", | |
"        loss=avg_loss,\n", | |
"        lr=current_lr,\n", | |
"        time=epoch_duration,\n", | |
"    )\n", | |
"\n", | |
"training_end_time = time.time()\n", | |
"total_training_time = training_end_time - training_start_time\n", | |
"logger.info('Training Finished', total_training_time=total_training_time)\n", | |
"\n", | |
"# ---------------------------------------------------------------------\n", | |
"# 5. Quick diagnostics: loss curve + learning-rate plot\n", | |
"# ---------------------------------------------------------------------\n", | |
"logger.info('Running Diagnostics')\n", | |
"model.eval()\n", | |
"queries_dev = queries.to(device)\n", | |
"docs_pos_dev = docs_pos.to(device)\n", | |
"\n", | |
"# --- Plotting Section ---\n", | |
"# The 'seaborn-v0_8-*' style names only exist in newer matplotlib releases;\n", | |
"# keep the default style rather than raising on older installs.\n", | |
"if 'seaborn-v0_8-whitegrid' in plt.style.available:\n", | |
"    plt.style.use('seaborn-v0_8-whitegrid')\n", | |
"\n", | |
"# Derive the x-axis from the recorded history so it always matches the data.\n", | |
"epoch_axis = range(1, len(loss_history) + 1)\n", | |
"tick_step = max(1, len(loss_history) // 10)  # keep to roughly 10 x-ticks\n", | |
"\n", | |
"# Figure 1: Loss and Learning Rate\n", | |
"fig, (ax_loss, ax_lr) = plt.subplots(1, 2, figsize=(12, 5))\n", | |
"\n", | |
"# Subplot 1: Loss Curve\n", | |
"ax_loss.plot(epoch_axis, loss_history, marker='o', linestyle='-', color='b')\n", | |
"ax_loss.set_title('Training Loss per Epoch')\n", | |
"ax_loss.set_xlabel('Epoch')\n", | |
"ax_loss.set_ylabel('Cross Entropy Loss')\n", | |
"ax_loss.set_xticks(range(1, len(loss_history) + 1, tick_step))\n", | |
"ax_loss.grid(True)\n", | |
"\n", | |
"# Subplot 2: Learning Rate Schedule\n", | |
"lr_axis = range(1, len(lr_history) + 1)\n", | |
"ax_lr.plot(lr_axis, lr_history, marker='.', linestyle='-', color='r')\n", | |
"ax_lr.set_title('Learning Rate Schedule')\n", | |
"ax_lr.set_xlabel('Epoch')\n", | |
"ax_lr.set_ylabel('Learning Rate')\n", | |
"ax_lr.set_xticks(range(1, len(lr_history) + 1, max(1, len(lr_history) // 10)))\n", | |
"ax_lr.set_ylim(bottom=0)\n", | |
"ax_lr.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))\n", | |
"ax_lr.grid(True)\n", | |
"\n", | |
"fig.tight_layout()\n", | |
"plt.show()\n", | |
"\n", | |
"\n", | |
"# ---------------------------------------------------------------------\n", | |
"# 6. Recall@k *against the entire corpus*\n", | |
"# ---------------------------------------------------------------------\n", | |
"@torch.no_grad()\n", | |
"def recall_at_k_corpus(model, q, docs, k=10, batch_size_eval=256):\n", | |
"    \"\"\"Fraction of queries whose true doc is in the top-k of the whole corpus.\n", | |
"\n", | |
"    Query j has exactly ONE relevant doc, docs[j] (aligned by row index).\n", | |
"    Every row of docs is a search candidate, so docs may have more rows\n", | |
"    than q (the extra rows act as distractors).\n", | |
"\n", | |
"    Args:\n", | |
"        model: nn.Module exposing encode(tokens) -> L2-normalised vectors.\n", | |
"        q: (N, q_len) query token ids.\n", | |
"        docs: (M, doc_len) doc token ids, with M >= N.\n", | |
"        k: cut-off; values larger than M are clamped to M.\n", | |
"        batch_size_eval: rows encoded per forward pass.\n", | |
"\n", | |
"    Returns:\n", | |
"        Recall@k as a float in [0, 1].\n", | |
"\n", | |
"    Raises:\n", | |
"        ValueError: if q is empty, docs has fewer rows than q, or k /\n", | |
"            batch_size_eval is not positive.\n", | |
"    \"\"\"\n", | |
"    n_q, n_docs = len(q), len(docs)\n", | |
"    if n_q == 0:\n", | |
"        raise ValueError('q must contain at least one query')\n", | |
"    if n_docs < n_q:\n", | |
"        # Otherwise targets past the corpus end never match: silently wrong recall.\n", | |
"        raise ValueError(\n", | |
"            f'docs must have at least as many rows as q ({n_docs} < {n_q})'\n", | |
"        )\n", | |
"    if k < 1 or batch_size_eval < 1:\n", | |
"        raise ValueError('k and batch_size_eval must be positive')\n", | |
"    # topk raises when k exceeds the candidate count; recall is 1.0 in that case.\n", | |
"    k_eff = min(k, n_docs)\n", | |
"\n", | |
"    model.eval()\n", | |
"    # Evaluate on the model's own device so every tensor below agrees with it.\n", | |
"    model_device = next(model.parameters()).device\n", | |
"    q_dev = q.to(model_device)\n", | |
"    docs_dev = docs.to(model_device)\n", | |
"\n", | |
"    logger.info(f'\\n--- Calculating Recall@{k} ---')\n", | |
"    logger.info(f'Encoding {len(docs_dev)} documents...')\n", | |
"    encode_start = time.time()\n", | |
"    doc_vecs = torch.cat(\n", | |
"        [\n", | |
"            model.encode(docs_dev[i : i + batch_size_eval])\n", | |
"            for i in range(0, n_docs, batch_size_eval)\n", | |
"        ],\n", | |
"        dim=0,\n", | |
"    )  # (M, dim)\n", | |
"    doc_mat = doc_vecs.t()  # (dim, M): each query batch is then a single matmul\n", | |
"    encode_end = time.time()\n", | |
"    logger.info(f'Document encoding took {encode_end - encode_start:.2f}s')\n", | |
"\n", | |
"    logger.info(f'Performing search for {len(q_dev)} queries...')\n", | |
"    hits = 0\n", | |
"    search_start = time.time()\n", | |
"    for i in range(0, n_q, batch_size_eval):\n", | |
"        qv = model.encode(q_dev[i : i + batch_size_eval])  # (B, dim)\n", | |
"        # Vectors are L2-normalised, so the dot product is cosine similarity.\n", | |
"        sim = qv @ doc_mat  # (B, M)\n", | |
"        topk_indices = sim.topk(k_eff, dim=1).indices  # (B, k_eff)\n", | |
"        # Query at absolute index j is relevant only to doc j, so the batch\n", | |
"        # starting at i targets docs i, i+1, ..., i+B-1.\n", | |
"        target_indices = torch.arange(\n", | |
"            i, i + qv.size(0), device=topk_indices.device\n", | |
"        ).unsqueeze(1)  # (B, 1)\n", | |
"        # Broadcast compare: a hit if the target appears anywhere in the row.\n", | |
"        hits += (topk_indices == target_indices).any(dim=1).sum().item()\n", | |
"    search_end = time.time()\n", | |
"    logger.info(f'Search took {search_end - search_start:.2f}s')\n", | |
"\n", | |
"    return hits / n_q\n", | |
"\n", | |
"\n", | |
"# Score every query against the full doc corpus (tensors already on `device`).\n", | |
"rec10 = recall_at_k_corpus(\n", | |
"    model,\n", | |
"    queries_dev,\n", | |
"    docs_pos_dev,\n", | |
"    k=10,\n", | |
"    batch_size_eval=64,\n", | |
")\n", | |
"logger.info(f'\\nRecall@10 (full corpus): {rec10:.2%}')" | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment