Skip to content

Instantly share code, notes, and snippets.

@sdbuch
Created July 24, 2025 06:34
Show Gist options
  • Save sdbuch/14eacede78e19867b064a90974c1fe69 to your computer and use it in GitHub Desktop.
Save sdbuch/14eacede78e19867b064a90974c1fe69 to your computer and use it in GitHub Desktop.
Simple class-conditional denoiser with cross attention, minimal eval
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"title:\n",
"author: Sam Buchanan\n",
"format:\n",
"  html:\n",
"    code-fold: show\n",
"---\n",
"\n",
"## Code imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model Size Comparison:\n",
"Baseline model (no cross-attn): 605,584 parameters\n",
"Cross-attention model: 811,536 parameters\n",
"Large model (embed_dim=384): 5,355,664 parameters\n",
"\n",
"Input shape: torch.Size([2, 1, 28, 28])\n",
"Class labels: tensor([2, 6])\n",
"Baseline output shape: torch.Size([2, 1, 28, 28])\n",
"Cross-attention output shape: torch.Size([2, 1, 28, 28])\n"
]
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import math\n",
"from typing import Optional\n",
"\n",
"class PatchEmbedding(nn.Module):\n",
" def __init__(self, img_size: int = 28, patch_size: int = 4, embed_dim: int = 128):\n",
" super().__init__()\n",
" self.img_size = img_size\n",
" self.patch_size = patch_size\n",
" self.num_patches = (img_size // patch_size) ** 2\n",
" self.embed_dim = embed_dim\n",
"\n",
" self.proj = nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size)\n",
"\n",
" def forward(self, x):\n",
" # x: (batch_size, 1, 28, 28)\n",
" B, C, H, W = x.shape\n",
" x = self.proj(x) # (batch_size, embed_dim, num_patches_h, num_patches_w)\n",
" x = x.flatten(2).transpose(1, 2) # (batch_size, num_patches, embed_dim)\n",
" return x\n",
"\n",
"class MultiHeadAttention(nn.Module):\n",
" def __init__(self, embed_dim: int = 128, num_heads: int = 4):\n",
" super().__init__()\n",
" self.embed_dim = embed_dim\n",
" self.num_heads = num_heads\n",
" self.head_dim = embed_dim // num_heads\n",
"\n",
" assert embed_dim % num_heads == 0\n",
"\n",
" self.qkv = nn.Linear(embed_dim, embed_dim * 3)\n",
" self.proj = nn.Linear(embed_dim, embed_dim)\n",
" self.dropout = nn.Dropout(0.1)\n",
"\n",
" def forward(self, x):\n",
" B, N, C = x.shape\n",
" qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)\n",
" q, k, v = qkv[0], qkv[1], qkv[2]\n",
"\n",
" attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)\n",
" attn = attn.softmax(dim=-1)\n",
" attn = self.dropout(attn)\n",
"\n",
" x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n",
" x = self.proj(x)\n",
" return x\n",
"\n",
"class CrossAttention(nn.Module):\n",
" def __init__(self, embed_dim: int = 128, num_heads: int = 4):\n",
" super().__init__()\n",
" self.embed_dim = embed_dim\n",
" self.num_heads = num_heads\n",
" self.head_dim = embed_dim // num_heads\n",
"\n",
" assert embed_dim % num_heads == 0\n",
"\n",
" self.q = nn.Linear(embed_dim, embed_dim)\n",
" self.k = nn.Linear(embed_dim, embed_dim)\n",
" self.v = nn.Linear(embed_dim, embed_dim)\n",
" self.proj = nn.Linear(embed_dim, embed_dim)\n",
" self.dropout = nn.Dropout(0.1)\n",
"\n",
" def forward(self, x, kv_input):\n",
" \"\"\"\n",
" x: query input (B, N_patches, embed_dim) - from image patches\n",
" kv_input: key-value input (B, N_class_tokens, embed_dim) - from class embeddings (N_class_tokens tokens per class)\n",
" \"\"\"\n",
" B, N_patches, C = x.shape\n",
" _, N_class_tokens, _ = kv_input.shape\n",
"\n",
" # Compute Q from patches, K,V from class embeddings\n",
" q = self.q(x).reshape(B, N_patches, self.num_heads, self.head_dim).permute(0, 2, 1, 3)\n",
" k = self.k(kv_input).reshape(B, N_class_tokens, self.num_heads, self.head_dim).permute(0, 2, 3, 1)\n",
" v = self.v(kv_input).reshape(B, N_class_tokens, self.num_heads, self.head_dim).permute(0, 2, 1, 3)\n",
"\n",
" attn = (q @ k) * (self.head_dim**-0.5)\n",
" attn = attn.softmax(dim=-1)\n",
" attn = self.dropout(attn)\n",
"\n",
" x = (attn @ v).transpose(1, 2).reshape(B, N_patches, C)\n",
" x = self.proj(x)\n",
" return x\n",
"\n",
"class MLP(nn.Module):\n",
" def __init__(self, embed_dim: int = 128, mlp_ratio: int = 4):\n",
" super().__init__()\n",
" hidden_dim = int(embed_dim * mlp_ratio)\n",
" self.fc1 = nn.Linear(embed_dim, hidden_dim)\n",
" self.fc2 = nn.Linear(hidden_dim, embed_dim)\n",
" self.dropout = nn.Dropout(0.1)\n",
"\n",
" def forward(self, x):\n",
" x = self.fc1(x)\n",
" x = F.gelu(x)\n",
" x = self.dropout(x)\n",
" x = self.fc2(x)\n",
" x = self.dropout(x)\n",
" return x\n",
"\n",
"class TransformerBlock(nn.Module):\n",
" def __init__(self, embed_dim: int = 128, num_heads: int = 4, mlp_ratio: int = 4,\n",
" use_cross_attn: bool = False):\n",
" super().__init__()\n",
" self.use_cross_attn = use_cross_attn\n",
"\n",
" self.norm1 = nn.LayerNorm(embed_dim)\n",
" self.attn = MultiHeadAttention(embed_dim, num_heads)\n",
"\n",
" if use_cross_attn:\n",
" self.norm_cross = nn.LayerNorm(embed_dim)\n",
" self.cross_attn = CrossAttention(embed_dim, num_heads)\n",
"\n",
" self.norm2 = nn.LayerNorm(embed_dim)\n",
" self.mlp = MLP(embed_dim, mlp_ratio)\n",
"\n",
" def forward(self, x, class_embeddings=None):\n",
" # Self-attention\n",
" x = x + self.attn(self.norm1(x))\n",
"\n",
" # Cross-attention (if enabled and class embeddings provided)\n",
" if self.use_cross_attn and class_embeddings is not None:\n",
" x = x + self.cross_attn(self.norm_cross(x), class_embeddings)\n",
"\n",
" # MLP\n",
" x = x + self.mlp(self.norm2(x))\n",
" return x\n",
"\n",
"class VisionTransformerDenoiser(nn.Module):\n",
" def __init__(self, img_size: int = 28, patch_size: int = 4, embed_dim: int = 128,\n",
" num_layers: int = 3, num_heads: int = 4, mlp_ratio: int = 4,\n",
" cross_attn_config: str = \"none\", num_classes: int = 10, num_class_tokens: int = 5):\n",
" super().__init__()\n",
" self.img_size = img_size\n",
" self.patch_size = patch_size\n",
" self.embed_dim = embed_dim\n",
" self.num_patches = (img_size // patch_size) ** 2\n",
" self.cross_attn_config = cross_attn_config\n",
" self.num_classes = num_classes\n",
" self.num_class_tokens = num_class_tokens\n",
"\n",
" self.patch_embed = PatchEmbedding(img_size, patch_size, embed_dim)\n",
" self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, embed_dim) * 0.02)\n",
"\n",
" # Class embeddings for cross-attention: efficient implementation of K independent embeddings\n",
" if cross_attn_config != \"none\":\n",
" self.class_embed = nn.Embedding(num_classes, num_class_tokens * embed_dim)\n",
" # Positional embeddings for class tokens (like patch positional embeddings)\n",
" self.class_pos_embed = nn.Parameter(torch.randn(1, num_class_tokens, embed_dim) * 0.02)\n",
"\n",
" use_cross_attn = cross_attn_config != \"none\"\n",
" self.blocks = nn.ModuleList([\n",
" TransformerBlock(embed_dim, num_heads, mlp_ratio, use_cross_attn)\n",
" for _ in range(num_layers)\n",
" ])\n",
"\n",
" self.norm = nn.LayerNorm(embed_dim)\n",
"\n",
" # Unembedding: project back to patch pixels\n",
" pixels_per_patch = patch_size * patch_size\n",
" self.unembed = nn.Linear(embed_dim, pixels_per_patch)\n",
"\n",
" def forward(self, x, class_labels=None):\n",
" # x: (batch_size, 1, 28, 28)\n",
" # class_labels: (batch_size,) - optional for cross-attention\n",
" B = x.shape[0]\n",
"\n",
" # Patch embedding\n",
" x = self.patch_embed(x) # (B, num_patches, embed_dim)\n",
"\n",
" # Add positional embeddings\n",
" x = x + self.pos_embed\n",
"\n",
" # Prepare class embeddings for cross-attention\n",
" class_embeddings = None\n",
" if self.cross_attn_config != \"none\" and class_labels is not None:\n",
" # Get class embeddings: (B, num_class_tokens, embed_dim)\n",
" # Efficient implementation: embed to (B, num_class_tokens*embed_dim) then reshape\n",
" class_embeds_flat = self.class_embed(class_labels) # (B, num_class_tokens*embed_dim)\n",
" class_embeddings = class_embeds_flat.reshape(B, self.num_class_tokens, self.embed_dim)\n",
" # Add positional embeddings to class tokens (like patch embeddings)\n",
" class_embeddings = class_embeddings + self.class_pos_embed\n",
"\n",
" # Apply transformer blocks\n",
" for block in self.blocks:\n",
" x = block(x, class_embeddings)\n",
"\n",
" x = self.norm(x)\n",
"\n",
" # Unembed to patches\n",
" x = self.unembed(x) # (B, num_patches, pixels_per_patch)\n",
"\n",
" # Reshape back to image\n",
" patches_per_side = self.img_size // self.patch_size\n",
" x = x.reshape(B, patches_per_side, patches_per_side, self.patch_size, self.patch_size)\n",
" x = x.permute(0, 1, 3, 2, 4).contiguous()\n",
" x = x.reshape(B, 1, self.img_size, self.img_size)\n",
"\n",
" return x\n",
"\n",
"# Example usage and model size comparison\n",
"model_baseline = VisionTransformerDenoiser(img_size=28, patch_size=4, embed_dim=128, num_layers=3, cross_attn_config=\"none\")\n",
"model_cross_attn = VisionTransformerDenoiser(img_size=28, patch_size=4, embed_dim=128, num_layers=3, cross_attn_config=\"enabled\", num_classes=10)\n",
"model_large = VisionTransformerDenoiser(img_size=28, patch_size=4, embed_dim=384, num_layers=3, num_heads=6, cross_attn_config=\"none\")\n",
"\n",
"print(\"Model Size Comparison:\")\n",
"print(f\"Baseline model (no cross-attn): {sum(p.numel() for p in model_baseline.parameters()):,} parameters\")\n",
"print(f\"Cross-attention model: {sum(p.numel() for p in model_cross_attn.parameters()):,} parameters\")\n",
"print(f\"Large model (embed_dim=384): {sum(p.numel() for p in model_large.parameters()):,} parameters\")\n",
"\n",
"# Test with random input\n",
"x = torch.randn(2, 1, 28, 28)\n",
"class_labels = torch.randint(0, 10, (2,))\n",
"\n",
"print(f\"\\nInput shape: {x.shape}\")\n",
"print(f\"Class labels: {class_labels}\")\n",
"\n",
"# Test baseline model\n",
"output_baseline = model_baseline(x)\n",
"print(f\"Baseline output shape: {output_baseline.shape}\")\n",
"\n",
"# Test cross-attention model\n",
"output_cross_attn = model_cross_attn(x, class_labels)\n",
"print(f\"Cross-attention output shape: {output_cross_attn.shape}\")\n",
"\n",
"# Use the baseline model for training (can be switched to cross-attention later)\n",
"model = model_baseline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataset Setup"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training samples: 60000\n",
"Test samples: 10000\n"
]
}
],
"source": [
"import numpy as np\n",
"import torch.optim as optim\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"from torch.utils.data import DataLoader\n",
"\n",
"# Download and prepare MNIST dataset\n",
"transform = transforms.Compose(\n",
" [\n",
" transforms.ToTensor(),\n",
" ]\n",
")\n",
"\n",
"train_dataset = torchvision.datasets.MNIST(\n",
" root=\"./data\", train=True, transform=transform, download=True\n",
")\n",
"test_dataset = torchvision.datasets.MNIST(\n",
" root=\"./data\", train=False, transform=transform, download=True\n",
")\n",
"\n",
"batch_size = 512\n",
"train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
"test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
"\n",
"print(f\"Training samples: {len(train_dataset)}\")\n",
"print(f\"Test samples: {len(test_dataset)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Noise Addition and Helper Functions"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Original range: [0.000, 1.000]\n",
"Noisy range: [0.000, 1.000]\n"
]
}
],
"source": [
"def add_gaussian_noise(images, noise_std=0.2):\n",
" \"\"\"Add Gaussian noise to images\"\"\"\n",
" noise = torch.randn_like(images) * noise_std\n",
" noisy_images = images + noise\n",
" return torch.clamp(noisy_images, 0.0, 1.0)\n",
"\n",
"def evaluate_model(model, data_loader, noise_std=0.2, device='cpu', use_class_labels=False):\n",
" \"\"\"Evaluate model on test set\"\"\"\n",
" model.eval()\n",
" total_loss = 0.0\n",
" num_batches = 0\n",
"\n",
" with torch.no_grad():\n",
" for images, labels in data_loader:\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
" noisy_images = add_gaussian_noise(images, noise_std)\n",
"\n",
" if use_class_labels:\n",
" outputs = model(noisy_images, labels)\n",
" else:\n",
" outputs = model(noisy_images)\n",
" loss = F.mse_loss(outputs, images)\n",
"\n",
" total_loss += loss.item()\n",
" num_batches += 1\n",
"\n",
" return total_loss / num_batches\n",
"\n",
"# Test noise addition\n",
"sample_batch, _ = next(iter(train_loader))\n",
"noisy_batch = add_gaussian_noise(sample_batch, 0.2)\n",
"print(f\"Original range: [{sample_batch.min():.3f}, {sample_batch.max():.3f}]\")\n",
"print(f\"Noisy range: [{noisy_batch.min():.3f}, {noisy_batch.max():.3f}]\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training Loop"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cpu\n",
"Baseline model parameters: 605,584\n",
"Conditional model parameters: 811,536\n",
"Starting dual model training...\n",
"Epoch 1/10, Batch 0:\n",
" Baseline Loss: 0.589925\n",
" Conditional Loss: 0.605246\n",
"Epoch 1/10, Batch 100:\n",
" Baseline Loss: 0.014165\n",
" Conditional Loss: 0.014485\n",
"Epoch 1/10 Results:\n",
" Baseline - Train: 0.053855, Test: 0.009175\n",
" Conditional - Train: 0.047980, Test: 0.009853\n",
"\n",
"Epoch 2/10, Batch 0:\n",
" Baseline Loss: 0.011881\n",
" Conditional Loss: 0.012587\n",
"Epoch 2/10, Batch 100:\n",
" Baseline Loss: 0.008089\n",
" Conditional Loss: 0.008545\n",
"Epoch 2/10 Results:\n",
" Baseline - Train: 0.009310, Test: 0.006402\n",
" Conditional - Train: 0.009884, Test: 0.006695\n",
"\n",
"Epoch 3/10, Batch 0:\n",
" Baseline Loss: 0.007999\n",
" Conditional Loss: 0.008325\n",
"Epoch 3/10, Batch 100:\n",
" Baseline Loss: 0.006997\n",
" Conditional Loss: 0.007027\n",
"Epoch 3/10 Results:\n",
" Baseline - Train: 0.007326, Test: 0.005807\n",
" Conditional - Train: 0.007506, Test: 0.005720\n",
"\n",
"Epoch 4/10, Batch 0:\n",
" Baseline Loss: 0.006903\n",
" Conditional Loss: 0.006876\n",
"Epoch 4/10, Batch 100:\n",
" Baseline Loss: 0.006435\n",
" Conditional Loss: 0.006302\n",
"Epoch 4/10 Results:\n",
" Baseline - Train: 0.006622, Test: 0.005496\n",
" Conditional - Train: 0.006545, Test: 0.005232\n",
"\n",
"Epoch 5/10, Batch 0:\n",
" Baseline Loss: 0.006404\n",
" Conditional Loss: 0.006246\n",
"Epoch 5/10, Batch 100:\n",
" Baseline Loss: 0.005970\n",
" Conditional Loss: 0.005712\n",
"Epoch 5/10 Results:\n",
" Baseline - Train: 0.006207, Test: 0.005271\n",
" Conditional - Train: 0.006014, Test: 0.004947\n",
"\n",
"Epoch 6/10, Batch 0:\n",
" Baseline Loss: 0.006038\n",
" Conditional Loss: 0.005783\n",
"Epoch 6/10, Batch 100:\n",
" Baseline Loss: 0.005747\n",
" Conditional Loss: 0.005415\n",
"Epoch 6/10 Results:\n",
" Baseline - Train: 0.005912, Test: 0.005131\n",
" Conditional - Train: 0.005625, Test: 0.004747\n",
"\n",
"Epoch 7/10, Batch 0:\n",
" Baseline Loss: 0.005777\n",
" Conditional Loss: 0.005478\n",
"Epoch 7/10, Batch 100:\n",
" Baseline Loss: 0.005559\n",
" Conditional Loss: 0.005242\n",
"Epoch 7/10 Results:\n",
" Baseline - Train: 0.005674, Test: 0.004941\n",
" Conditional - Train: 0.005353, Test: 0.004561\n",
"\n",
"Epoch 8/10, Batch 0:\n",
" Baseline Loss: 0.005522\n",
" Conditional Loss: 0.005220\n",
"Epoch 8/10, Batch 100:\n",
" Baseline Loss: 0.005412\n",
" Conditional Loss: 0.005134\n",
"Epoch 8/10 Results:\n",
" Baseline - Train: 0.005451, Test: 0.004751\n",
" Conditional - Train: 0.005147, Test: 0.004432\n",
"\n",
"Epoch 9/10, Batch 0:\n",
" Baseline Loss: 0.005348\n",
" Conditional Loss: 0.005100\n",
"Epoch 9/10, Batch 100:\n",
" Baseline Loss: 0.005188\n",
" Conditional Loss: 0.004924\n",
"Epoch 9/10 Results:\n",
" Baseline - Train: 0.005254, Test: 0.004590\n",
" Conditional - Train: 0.004992, Test: 0.004340\n",
"\n",
"Epoch 10/10, Batch 0:\n",
" Baseline Loss: 0.005163\n",
" Conditional Loss: 0.004878\n",
"Epoch 10/10, Batch 100:\n",
" Baseline Loss: 0.005044\n",
" Conditional Loss: 0.004792\n",
"Epoch 10/10 Results:\n",
" Baseline - Train: 0.005085, Test: 0.004460\n",
" Conditional - Train: 0.004854, Test: 0.004259\n",
"\n",
"Training completed!\n"
]
}
],
"source": [
"# Set device\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(f\"Using device: {device}\")\n",
"\n",
"# Initialize both models\n",
"model_baseline = VisionTransformerDenoiser(img_size=28, patch_size=4, embed_dim=128, num_layers=3,\n",
" cross_attn_config=\"none\")\n",
"model_conditional = VisionTransformerDenoiser(img_size=28, patch_size=4, embed_dim=128, num_layers=3,\n",
" cross_attn_config=\"enabled\", num_classes=10, num_class_tokens=5)\n",
"\n",
"model_baseline = model_baseline.to(device)\n",
"model_conditional = model_conditional.to(device)\n",
"\n",
"print(f\"Baseline model parameters: {sum(p.numel() for p in model_baseline.parameters()):,}\")\n",
"print(f\"Conditional model parameters: {sum(p.numel() for p in model_conditional.parameters()):,}\")\n",
"\n",
"# Training parameters\n",
"learning_rate = 1e-3\n",
"num_epochs = 10\n",
"noise_std = 0.2\n",
"\n",
"# Optimizers and loss function\n",
"optimizer_baseline = optim.Adam(model_baseline.parameters(), lr=learning_rate)\n",
"optimizer_conditional = optim.Adam(model_conditional.parameters(), lr=learning_rate)\n",
"criterion = nn.MSELoss()\n",
"\n",
"# Training history\n",
"train_losses_baseline = []\n",
"test_losses_baseline = []\n",
"train_losses_conditional = []\n",
"test_losses_conditional = []\n",
"\n",
"print(\"Starting dual model training...\")\n",
"for epoch in range(num_epochs):\n",
" model_baseline.train()\n",
" model_conditional.train()\n",
"\n",
" epoch_loss_baseline = 0.0\n",
" epoch_loss_conditional = 0.0\n",
" num_batches = 0\n",
"\n",
" for batch_idx, (images, labels) in enumerate(train_loader):\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
"\n",
" # Add noise to create input\n",
" noisy_images = add_gaussian_noise(images, noise_std)\n",
"\n",
" # Train baseline model (unconditional)\n",
" optimizer_baseline.zero_grad()\n",
" outputs_baseline = model_baseline(noisy_images)\n",
" loss_baseline = criterion(outputs_baseline, images)\n",
" loss_baseline.backward()\n",
" optimizer_baseline.step()\n",
"\n",
" # Train conditional model (with class labels)\n",
" optimizer_conditional.zero_grad()\n",
" outputs_conditional = model_conditional(noisy_images, labels)\n",
" loss_conditional = criterion(outputs_conditional, images)\n",
" loss_conditional.backward()\n",
" optimizer_conditional.step()\n",
"\n",
" epoch_loss_baseline += loss_baseline.item()\n",
" epoch_loss_conditional += loss_conditional.item()\n",
" num_batches += 1\n",
"\n",
" if batch_idx % 100 == 0:\n",
" print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}:')\n",
" print(f' Baseline Loss: {loss_baseline.item():.6f}')\n",
" print(f' Conditional Loss: {loss_conditional.item():.6f}')\n",
"\n",
" # Calculate average training losses\n",
" avg_train_loss_baseline = epoch_loss_baseline / num_batches\n",
" avg_train_loss_conditional = epoch_loss_conditional / num_batches\n",
" train_losses_baseline.append(avg_train_loss_baseline)\n",
" train_losses_conditional.append(avg_train_loss_conditional)\n",
"\n",
" # Evaluate on test set\n",
" test_loss_baseline = evaluate_model(model_baseline, test_loader, noise_std, device, use_class_labels=False)\n",
" test_loss_conditional = evaluate_model(model_conditional, test_loader, noise_std, device, use_class_labels=True)\n",
" test_losses_baseline.append(test_loss_baseline)\n",
" test_losses_conditional.append(test_loss_conditional)\n",
"\n",
" print(f'Epoch {epoch+1}/{num_epochs} Results:')\n",
" print(f' Baseline - Train: {avg_train_loss_baseline:.6f}, Test: {test_loss_baseline:.6f}')\n",
" print(f' Conditional - Train: {avg_train_loss_conditional:.6f}, Test: {test_loss_conditional:.6f}')\n",
" print()\n",
"\n",
"print(\"Training completed!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Learning Curves"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1500x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Final Results:\n",
"Baseline - Train: 0.005085, Test: 0.004460\n",
"Conditional - Train: 0.004854, Test: 0.004259\n"
]
}
],
"source": [
"# Plot learning curves for both models\n",
"plt.figure(figsize=(15, 5))\n",
"\n",
"plt.subplot(1, 2, 1)\n",
"plt.plot(range(1, num_epochs+1), train_losses_baseline, 'b-', label='Baseline Train')\n",
"plt.plot(range(1, num_epochs+1), test_losses_baseline, 'b--', label='Baseline Test')\n",
"plt.plot(range(1, num_epochs+1), train_losses_conditional, 'r-', label='Conditional Train')\n",
"plt.plot(range(1, num_epochs+1), test_losses_conditional, 'r--', label='Conditional Test')\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('MSE Loss')\n",
"plt.yscale('log')\n",
"plt.title('Learning Curves Comparison (Log Scale)')\n",
"plt.legend()\n",
"plt.grid(True)\n",
"\n",
"plt.subplot(1, 2, 2)\n",
"plt.plot(range(1, num_epochs+1), train_losses_baseline, 'b-', label='Baseline')\n",
"plt.plot(range(1, num_epochs+1), train_losses_conditional, 'r-', label='Conditional')\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Training MSE Loss')\n",
"plt.yscale('log')\n",
"plt.title('Training Loss Comparison (Log Scale)')\n",
"plt.legend()\n",
"plt.grid(True)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"# Final evaluation summary\n",
"print(\"Final Results:\")\n",
"print(f\"Baseline - Train: {train_losses_baseline[-1]:.6f}, Test: {test_losses_baseline[-1]:.6f}\")\n",
"print(f\"Conditional - Train: {train_losses_conditional[-1]:.6f}, Test: {test_losses_conditional[-1]:.6f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Baseline Model Denoising Results"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1600x600 with 24 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Visualize baseline model denoising results\n",
"model_baseline.eval()\n",
"with torch.no_grad():\n",
" # Get a batch of test images\n",
" test_images, test_labels = next(iter(test_loader))\n",
" test_images = test_images[:8].to(device)\n",
" test_labels = test_labels[:8]\n",
"\n",
" # Add noise\n",
" noisy_images = add_gaussian_noise(test_images, noise_std)\n",
"\n",
" # Denoise with baseline model\n",
" denoised_images = model_baseline(noisy_images)\n",
"\n",
" # Move to CPU for plotting\n",
" test_images = test_images.cpu()\n",
" noisy_images = noisy_images.cpu()\n",
" denoised_images = denoised_images.cpu()\n",
"\n",
" # Plot comparison\n",
" fig, axes = plt.subplots(3, 8, figsize=(16, 6))\n",
"\n",
" for i in range(8):\n",
" # Original\n",
" axes[0, i].imshow(test_images[i, 0], cmap='gray')\n",
" axes[0, i].set_title(f'Orig: {test_labels[i].item()}' if i < 4 else '')\n",
" axes[0, i].axis('off')\n",
"\n",
" # Noisy\n",
" axes[1, i].imshow(noisy_images[i, 0], cmap='gray')\n",
" axes[1, i].set_title('Noisy' if i == 0 else '')\n",
" axes[1, i].axis('off')\n",
"\n",
" # Denoised\n",
" axes[2, i].imshow(denoised_images[i, 0], cmap='gray')\n",
" axes[2, i].set_title('Baseline' if i == 0 else '')\n",
" axes[2, i].axis('off')\n",
"\n",
" plt.suptitle('Baseline Model (Unconditional) Denoising')\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conditional Model Denoising Results"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1600x600 with 24 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Visualize conditional model denoising results\n",
"model_conditional.eval()\n",
"with torch.no_grad():\n",
" # Use same test images for fair comparison\n",
" test_images, test_labels = next(iter(test_loader))\n",
" test_images = test_images[:8].to(device)\n",
" test_labels = test_labels[:8].to(device)\n",
"\n",
" # Add noise\n",
" noisy_images = add_gaussian_noise(test_images, noise_std)\n",
"\n",
" # Denoise with conditional model (using correct class labels)\n",
" denoised_images = model_conditional(noisy_images, test_labels)\n",
"\n",
" # Move to CPU for plotting\n",
" test_images = test_images.cpu()\n",
" noisy_images = noisy_images.cpu()\n",
" denoised_images = denoised_images.cpu()\n",
" test_labels = test_labels.cpu()\n",
"\n",
" # Plot comparison\n",
" fig, axes = plt.subplots(3, 8, figsize=(16, 6))\n",
"\n",
" for i in range(8):\n",
" # Original\n",
" axes[0, i].imshow(test_images[i, 0], cmap='gray')\n",
" axes[0, i].set_title(f'Orig: {test_labels[i].item()}' if i < 4 else '')\n",
" axes[0, i].axis('off')\n",
"\n",
" # Noisy\n",
" axes[1, i].imshow(noisy_images[i, 0], cmap='gray')\n",
" axes[1, i].set_title('Noisy' if i == 0 else '')\n",
" axes[1, i].axis('off')\n",
"\n",
" # Denoised\n",
" axes[2, i].imshow(denoised_images[i, 0], cmap='gray')\n",
" axes[2, i].set_title('Conditional' if i == 0 else '')\n",
" axes[2, i].axis('off')\n",
"\n",
" plt.suptitle('Conditional Model (Cross-Attention) Denoising')\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Per-Class Test Error Analysis"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Calculating per-class test errors...\n",
"\n",
"Per-Class Test Errors (MSE):\n",
"Class | Baseline | Conditional | Improvement\n",
"------|-----------|-------------|------------\n",
" 0 | 0.005144 | 0.004942 | +0.000202 (+3.9%)\n",
" 1 | 0.002510 | 0.002312 | +0.000198 (+7.9%)\n",
" 2 | 0.004917 | 0.004758 | +0.000159 (+3.2%)\n",
" 3 | 0.004828 | 0.004662 | +0.000166 (+3.4%)\n",
" 4 | 0.004371 | 0.004185 | +0.000187 (+4.3%)\n",
" 5 | 0.004803 | 0.004592 | +0.000211 (+4.4%)\n",
" 6 | 0.004783 | 0.004596 | +0.000187 (+3.9%)\n",
" 7 | 0.003980 | 0.003791 | +0.000189 (+4.7%)\n",
" 8 | 0.005209 | 0.005017 | +0.000192 (+3.7%)\n",
" 9 | 0.004379 | 0.004194 | +0.000186 (+4.2%)\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1200x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Overall Comparison:\n",
"Baseline average: 0.004492\n",
"Conditional average: 0.004305\n",
"Overall improvement: +0.000188 (+4.2%)\n"
]
}
],
"source": [
"def evaluate_per_class(model, data_loader, noise_std=0.2, device='cpu', use_class_labels=False):\n",
" \"\"\"Evaluate model per class\"\"\"\n",
" model.eval()\n",
" class_losses = {i: [] for i in range(10)}\n",
"\n",
" with torch.no_grad():\n",
" for images, labels in data_loader:\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
" noisy_images = add_gaussian_noise(images, noise_std)\n",
"\n",
" if use_class_labels:\n",
" outputs = model(noisy_images, labels)\n",
" else:\n",
" outputs = model(noisy_images)\n",
"\n",
" # Calculate loss for each sample\n",
" losses = F.mse_loss(outputs, images, reduction='none').mean(dim=[1,2,3])\n",
"\n",
" # Group by class\n",
" for i, label in enumerate(labels):\n",
" class_losses[label.item()].append(losses[i].item())\n",
"\n",
" # Calculate mean loss per class\n",
" class_mean_losses = {}\n",
" for class_id in range(10):\n",
" if class_losses[class_id]:\n",
" class_mean_losses[class_id] = np.mean(class_losses[class_id])\n",
" else:\n",
" class_mean_losses[class_id] = 0.0\n",
"\n",
" return class_mean_losses\n",
"\n",
"# Calculate per-class errors\n",
"print(\"Calculating per-class test errors...\")\n",
"baseline_class_errors = evaluate_per_class(model_baseline, test_loader, noise_std, device, use_class_labels=False)\n",
"conditional_class_errors = evaluate_per_class(model_conditional, test_loader, noise_std, device, use_class_labels=True)\n",
"\n",
"# Display results\n",
"print(\"\\nPer-Class Test Errors (MSE):\")\n",
"print(\"Class | Baseline | Conditional | Improvement\")\n",
"print(\"------|-----------|-------------|------------\")\n",
"for class_id in range(10):\n",
" baseline_err = baseline_class_errors[class_id]\n",
" conditional_err = conditional_class_errors[class_id]\n",
" improvement = baseline_err - conditional_err\n",
" improvement_pct = (improvement / baseline_err * 100) if baseline_err > 0 else 0\n",
" print(f\" {class_id} | {baseline_err:.6f} | {conditional_err:.6f} | {improvement:+.6f} ({improvement_pct:+.1f}%)\")\n",
"\n",
"# Plot comparison\n",
"plt.figure(figsize=(12, 5))\n",
"\n",
"plt.subplot(1, 2, 1)\n",
"classes = list(range(10))\n",
"baseline_errors = [baseline_class_errors[i] for i in classes]\n",
"conditional_errors = [conditional_class_errors[i] for i in classes]\n",
"\n",
"x = np.arange(len(classes))\n",
"width = 0.35\n",
"\n",
"plt.bar(x - width/2, baseline_errors, width, label='Baseline', alpha=0.8)\n",
"plt.bar(x + width/2, conditional_errors, width, label='Conditional', alpha=0.8)\n",
"plt.xlabel('Digit Class')\n",
"plt.ylabel('MSE Loss')\n",
"plt.title('Per-Class Test Error Comparison')\n",
"plt.xticks(x, classes)\n",
"plt.legend()\n",
"plt.grid(True, alpha=0.3)\n",
"\n",
"plt.subplot(1, 2, 2)\n",
"improvements = [baseline_class_errors[i] - conditional_class_errors[i] for i in classes]\n",
"colors = ['green' if imp > 0 else 'red' for imp in improvements]\n",
"plt.bar(classes, improvements, color=colors, alpha=0.7)\n",
"plt.xlabel('Digit Class')\n",
"plt.ylabel('Error Reduction (Baseline - Conditional)')\n",
"plt.title('Per-Class Improvement')\n",
"plt.grid(True, alpha=0.3)\n",
"plt.axhline(y=0, color='black', linestyle='-', alpha=0.5)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"# Overall comparison\n",
"overall_baseline = np.mean(list(baseline_class_errors.values()))\n",
"overall_conditional = np.mean(list(conditional_class_errors.values()))\n",
"overall_improvement = overall_baseline - overall_conditional\n",
"print(f\"\\nOverall Comparison:\")\n",
"print(f\"Baseline average: {overall_baseline:.6f}\")\n",
"print(f\"Conditional average: {overall_conditional:.6f}\")\n",
"print(f\"Overall improvement: {overall_improvement:+.6f} ({overall_improvement/overall_baseline*100:+.1f}%)\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Experimental Cross-Class Conditioning"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1600x800 with 32 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Cross-Class Conditioning Analysis:\n",
"Digit 7: Correct=0.002774, Wrong=0.002977, Diff=+0.000203\n",
"Digit 1: Correct=0.001754, Wrong=0.001750, Diff=-0.000004\n",
"Digit 1: Correct=0.002420, Wrong=0.002650, Diff=+0.000230\n",
"Digit 1: Correct=0.003456, Wrong=0.003633, Diff=+0.000176\n",
"Digit 7: Correct=0.004570, Wrong=0.004869, Diff=+0.000299\n",
"Digit 7: Correct=0.002819, Wrong=0.003031, Diff=+0.000212\n",
"Digit 1: Correct=0.001746, Wrong=0.001936, Diff=+0.000190\n",
"Digit 7: Correct=0.005002, Wrong=0.005136, Diff=+0.000134\n",
"\n",
"Average - Correct: 0.003068, Wrong: 0.003248\n",
"Wrong conditioning increases error by +0.000180 on average\n"
]
}
],
"source": [
"# Experiment: Condition model on wrong class labels to see learned effects\n",
"model_conditional.eval()\n",
"\n",
"# Find some 7s and 1s in test set for the experiment\n",
"sevens_and_ones = []\n",
"for images, labels in test_loader:\n",
" for i, label in enumerate(labels):\n",
" if label.item() == 7 and len([x for x in sevens_and_ones if x[1] == 7]) < 4:\n",
" sevens_and_ones.append((images[i:i+1], label.item()))\n",
" elif label.item() == 1 and len([x for x in sevens_and_ones if x[1] == 1]) < 4:\n",
" sevens_and_ones.append((images[i:i+1], label.item()))\n",
" if len(sevens_and_ones) >= 8:\n",
" break\n",
" if len(sevens_and_ones) >= 8:\n",
" break\n",
"\n",
"with torch.no_grad():\n",
" fig, axes = plt.subplots(4, 8, figsize=(16, 8))\n",
"\n",
" for idx, (image, true_label) in enumerate(sevens_and_ones):\n",
" image = image.to(device)\n",
" noisy_image = add_gaussian_noise(image, noise_std)\n",
"\n",
" # Denoise with correct label\n",
" correct_label = torch.tensor([true_label]).to(device)\n",
" denoised_correct = model_conditional(noisy_image, correct_label)\n",
"\n",
" # Denoise with wrong label (7->1, 1->7)\n",
" wrong_label = torch.tensor([1 if true_label == 7 else 7]).to(device)\n",
" denoised_wrong = model_conditional(noisy_image, wrong_label)\n",
"\n",
" # Move to CPU\n",
" image = image.cpu()\n",
" noisy_image = noisy_image.cpu()\n",
" denoised_correct = denoised_correct.cpu()\n",
" denoised_wrong = denoised_wrong.cpu()\n",
"\n",
" # Plot\n",
" # Original\n",
" axes[0, idx].imshow(image[0, 0], cmap='gray')\n",
" axes[0, idx].set_title(f'Original: {true_label}' if idx < 4 else '')\n",
" axes[0, idx].axis('off')\n",
"\n",
" # Noisy\n",
" axes[1, idx].imshow(noisy_image[0, 0], cmap='gray')\n",
" axes[1, idx].set_title('Noisy' if idx == 0 else '')\n",
" axes[1, idx].axis('off')\n",
"\n",
" # Correct conditioning\n",
" axes[2, idx].imshow(denoised_correct[0, 0], cmap='gray')\n",
" axes[2, idx].set_title(f'Cond: {true_label}' if idx < 4 else '')\n",
" axes[2, idx].axis('off')\n",
"\n",
" # Wrong conditioning\n",
" axes[3, idx].imshow(denoised_wrong[0, 0], cmap='gray')\n",
" axes[3, idx].set_title(f'Cond: {wrong_label.item()}' if idx < 4 else '')\n",
" axes[3, idx].axis('off')\n",
"\n",
" plt.suptitle('Cross-Class Conditioning Experiment: 7s vs 1s\\n(Row 3: Correct class, Row 4: Wrong class)', fontsize=14)\n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
"# Quantitative comparison for the experiment\n",
"print(\"\\nCross-Class Conditioning Analysis:\")\n",
"with torch.no_grad():\n",
" # Calculate MSE for correct vs wrong conditioning\n",
" correct_losses = []\n",
" wrong_losses = []\n",
"\n",
" for image, true_label in sevens_and_ones:\n",
" image = image.to(device)\n",
" noisy_image = add_gaussian_noise(image, noise_std)\n",
"\n",
" # Correct conditioning\n",
" correct_label = torch.tensor([true_label]).to(device)\n",
" denoised_correct = model_conditional(noisy_image, correct_label)\n",
" correct_loss = F.mse_loss(denoised_correct, image).item()\n",
" correct_losses.append(correct_loss)\n",
"\n",
" # Wrong conditioning\n",
" wrong_label = torch.tensor([1 if true_label == 7 else 7]).to(device)\n",
" denoised_wrong = model_conditional(noisy_image, wrong_label)\n",
" wrong_loss = F.mse_loss(denoised_wrong, image).item()\n",
" wrong_losses.append(wrong_loss)\n",
"\n",
" print(f\"Digit {true_label}: Correct={correct_loss:.6f}, Wrong={wrong_loss:.6f}, Diff={wrong_loss-correct_loss:+.6f}\")\n",
"\n",
"print(f\"\\nAverage - Correct: {np.mean(correct_losses):.6f}, Wrong: {np.mean(wrong_losses):.6f}\")\n",
"print(f\"Wrong conditioning increases error by {np.mean(wrong_losses)-np.mean(correct_losses):+.6f} on average\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment