Skip to content

Instantly share code, notes, and snippets.

@ariG23498
Created April 16, 2024 20:00
Show Gist options
  • Save ariG23498/2f02dbd3a6417b66570c56fdaa3641aa to your computer and use it in GitHub Desktop.
Save ariG23498/2f02dbd3a6417b66570c56fdaa3641aa to your computer and use it in GitHub Desktop.
autoregressive-diffusion-lstm.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"authorship_tag": "ABX9TyM5ZVCOxhE+W71bDZtd3Ii+",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ariG23498/2f02dbd3a6417b66570c56fdaa3641aa/autoregressive-diffusion-lstm.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"This notebook is heavily inspired from: https://huggingface.co/blog/annotated-diffusion"
],
"metadata": {
"id": "02vydzVb6VWO"
}
},
{
"cell_type": "markdown",
"source": [
"## Setup and Imports"
],
"metadata": {
"id": "-CFKktKV2qqo"
}
},
{
"cell_type": "code",
"source": [
"!pip install --upgrade -qq datasets"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "BmiGcVlQ64dD",
"outputId": "b5078cae-cbd3-4b61-ad7b-6eb403a6e205"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m510.5/510.5 kB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m10.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h"
]
}
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "hzUHzDmy2kOy"
},
"outputs": [],
"source": [
"import random\n",
"import numpy as np\n",
"from matplotlib import pyplot as plt\n",
"\n",
"from datasets import load_dataset\n",
"\n",
"import torch\n",
"from torch import nn\n",
"from torch.nn import functional as F\n",
"from torch.utils.data import DataLoader\n",
"\n",
"from torchvision import transforms"
]
},
{
"cell_type": "markdown",
"source": [
"## Configurations"
],
"metadata": {
"id": "rlNouEPu7ART"
}
},
{
"cell_type": "code",
"source": [
"B = 128\n",
"H = W = 28\n",
"C = 1\n",
"T = 100\n",
"\n",
"hidden_dim = 128\n",
"epochs = 5\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"print(f\"{device=}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3HNObPXp7Bff",
"outputId": "f8a7e55c-ce8d-4f62-8265-94d0d16cf895"
},
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"device='cuda'\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Dataset and Loaders"
],
"metadata": {
"id": "JGEzhe8s7B1p"
}
},
{
"cell_type": "code",
"source": [
"# load dataset from the hub\n",
"dataset = load_dataset(\"fashion_mnist\")"
],
"metadata": {
"id": "PdIqVYuj6axU"
},
"execution_count": 11,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# define image transformations\n",
"transform = transforms.Compose([\n",
" transforms.RandomHorizontalFlip(),\n",
" transforms.ToTensor(),\n",
" transforms.Lambda(lambda x: (x * 2) - 1)\n",
"])\n",
"\n",
"# define function\n",
"def transforms(examples):\n",
" examples[\"pixel_values\"] = [transform(image.convert(\"L\")) for image in examples[\"image\"]]\n",
" del examples[\"image\"]\n",
" return examples\n",
"\n",
"transformed_dataset = (\n",
" dataset\n",
" .with_transform(transforms)\n",
" .remove_columns(\"label\")\n",
")\n",
"\n",
"# create dataloader\n",
"dataloader = DataLoader(\n",
" transformed_dataset[\"train\"],\n",
" batch_size=B,\n",
" shuffle=True\n",
")"
],
"metadata": {
"id": "SbnfSuvR6eel"
},
"execution_count": 12,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Get a batch of data and check the shape\n",
"batch = next(iter(dataloader))\n",
"print(batch.keys())\n",
"print(batch[\"pixel_values\"].shape)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GneDUV5J74KJ",
"outputId": "a4d06196-84e2-439a-b314-1a7b7e38faf3"
},
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"dict_keys(['pixel_values'])\n",
"torch.Size([128, 1, 28, 28])\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Forward Diffusion Process"
],
"metadata": {
"id": "Ggg6okjO8NFS"
}
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "9_7ipOKVIbFn"
},
"outputs": [],
"source": [
"# Define a linear schedule\n",
"def linear_beta_schedule(T):\n",
" beta_start = 0.0001\n",
" beta_end = 0.02\n",
" return torch.linspace(beta_start, beta_end, T)\n",
"\n",
"# define beta schedule\n",
"betas = linear_beta_schedule(T=T)\n",
"\n",
"# define alphas = 1 - beta\n",
"alphas = 1.0 - betas\n",
"alphas_cumprod = torch.cumprod(alphas, axis=0)\n",
"sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)\n",
"sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)\n",
"\n",
"# sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).repeat(batch_size, 1)\n",
"# sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).repeat(batch_size, 1)\n",
"\n",
"def extract(a, t, x_shape):\n",
" batch_size = t.shape[0]\n",
" out = a.gather(dim=-1, index=t.cpu())\n",
" return out.reshape(\n",
" batch_size, T, *((1,) * (len(x_shape) - 2))\n",
" ).to(t.device)"
]
},
{
"cell_type": "code",
"source": [
"# forward diffusion\n",
"def q_sample(x_start, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, noise=None):\n",
" if noise is None:\n",
" noise = torch.randn_like(x_start)\n",
"\n",
" sqrt_alphas_cumprod_t = extract(\n",
" sqrt_alphas_cumprod,\n",
" t,\n",
" x_start.shape\n",
" )\n",
" sqrt_one_minus_alphas_cumprod_t = extract(\n",
" sqrt_one_minus_alphas_cumprod,\n",
" t,\n",
" x_start.shape\n",
" )\n",
"\n",
" return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise"
],
"metadata": {
"id": "pmHZGKa8Lhbz"
},
"execution_count": 15,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# create a batch of noise images\n",
"t = torch.arange(0, T).flip((-1,)).repeat(B, 1)\n",
"input_images = q_sample(\n",
" x_start=batch[\"pixel_values\"].unsqueeze(1), # (B, 1, C, H, W)\n",
" t=t, # (B, timesteps)\n",
" sqrt_alphas_cumprod=sqrt_alphas_cumprod.repeat(B, 1),\n",
" sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod.repeat(B, 1),\n",
")"
],
"metadata": {
"id": "uWgxpZPj8sY6"
},
"execution_count": 16,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Viz the forward process"
],
"metadata": {
"id": "aQ0H8DVxEJqF"
}
},
{
"cell_type": "code",
"source": [
"idx = random.randint(0, B-1)\n",
"for i in range(5):\n",
" plt.subplot(1, 5, i+1)\n",
" plt.imshow(input_images[idx, T//5 * i].permute(1, 2, 0), cmap=\"gray\")\n",
" plt.axis(\"off\")\n",
"plt.show()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 122
},
"id": "I-hoBcj3DNSf",
"outputId": "f4dd8b4d-4dff-4a2d-c8b9-11594daa2318"
},
"execution_count": 20,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 5 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgQAAABpCAYAAABF9zs7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAApZ0lEQVR4nO3dd7RdVdU28An2XrA3RGxYAQkigpAYIqIYSkIXCEgAKaFEAyIIgg5EIEggFEMIhBokEppUJZhIE8Gu2LvYe1fev97pb+1xdsa999wx3jG+bz5/zXFy7t5rr7nWOjvPM8tqDz300ENRKBQKhULh/2us/n89gEKhUCgUCv/3qBeCQqFQKBQK9UJQKBQKhUKhXggKhUKhUChEvRAUCoVCoVCIeiEoFAqFQqEQ9UJQKBQKhUIh6oWgUCgUCoVCRDx8pF984xvfmPb222+f9ic/+cm0995777SXLl2a9tvf/va0f/vb36b9y1/+srnHM57xjLTvvPPOgZ//9Kc/TXujjTZK+/DDD0/7gQceSPuCCy5I+1Of+lTakydPTvtFL3pR2r/5zW/S/uc//5n2v//972as//nPf9L+9a9/nfZPfvKTtLfbbru0b7jhhrTXXHPNtE877bQYFs7D1KlT07766qvT3mOPPQZ+/ta3vjXt3/3ud2l3ffP0pz897XvuuSftPt9suOGGac+ZMyft73znO2kvXLgwbX0zceLEtF/4whem7dr517/+NdCOiLDWln+jb97xjnekffPNN6f9/Oc/P+358+fHMFhvvfXSnjJlStq33HJL2tOmTRs4js033zztP/zhD2m71iIi1lhjjbS/+MUvDvz8wQcfTHvddddNW7/ouwULFqR90003pb3xxhun/bznPS9t1437ZFV7xmf6+c9/nvYWW2yR9ooVK9J+znOek7Z7eqxYZ5110n7DG96Q9h133DFwLH4+YcKEtP/85z+n/fvf/765x5Oe9KS0PZP8XH++/OUvT/s973nPwOvqm8985jNpv/a1r037mc98Ztp/+tOf0l6Vb9wz/s2vfvWrtJ2nL3zhC2l7Bni2jAXPfe5z037JS16S9re//e20X/3qV6f9rW99K23P8b/+9a9p/+1vf2vu8djHPjZt1/3jHve4tP/4xz+m7dqbPXt22quv/t//T3/84x9P+3Of+1zanl9PfOITB47Pue/6Rfgcju+lL31p2t/73vcG3u/ee+/tva4ohqBQKBQKhUK9EBQKhUKhUIhYbaS9DN73vvcN/PzJT35y2k94whMG2lKk0h7rr79+c63bb7897Z122intRzziEWmfeuqpaZ9wwglpS7UsWbIk7be97W1pS7N8//vfT/vxj3/8wOeR7rnxxhubsX79619PW8psv/32S9vn/sUvfpH2ox71qLQ/8pGPxLCQXhRSkz6j9qc//em0//GPf6QtBRkRsXLlyrSnT5+etr6ZN29e2h/60IcGXlffSJc7P9/97nfTlt7TN8out956azPWb37zm2krdShpSbf2+Wbu3LkxDPbdd9+0V1tttbTdGz6flKUU9d///ve0X/GKVzT3+PznP5+2a/3hD/+vGrho0aK0jz322IHX1S+777572lKQyj2OW2pSelXKP6L161Of+tS03es+t/Kd60zafKxwDQt98OhHPzrtxzzmMWkrzSgrSllHRHz1q19N+01velPa+kbJ9aijjkpbSvnKK69M2/PFcUidO1afx3Pqvvvua8b64x//OG3PjS233HLg/ZSJ9I1S8Vig/Cncl4985CMH3vuHP/xh2sqISigR7bO+7GUvG3itu+++O+1DDjkkbX9Dli1blra+87fiG9/4xsBn0Hb/+LsU0e4B95zn8w9+8IO0//KXv6StpKHMuyoUQ1AoFAqFQqFeCAqFQqFQKIxCMrj00kvT/trXvpa2lMvll1+e9pFHHpm29Nd1112XthRWREtlS7d7D+ksqRajeI2Qlz6T0nvBC16Qts9mJLTR2VJnES3davTt/fffn7aSw89+9rO0jZhdvHhxDAspYaUMo3SlHY2U1TdSXVLtES31eNtttw28h5S+NJ1zYtSyvnnlK1+Z9lOe8pS0P/GJT6QtDaiPjQaPaGUpKcgvf/nLaet/fSNF7roYC0455ZS09flaa62VttkVM2fOTNtnUEpzbiL6ZQbvoV/WXnvttJWBjHaXmtxggw1iENzHHiFeZ5NNNmn+RunIbAclHiUHfSwdLFU7VrgHvLZR7koeylvufaO3PZsiWsnhS1/6Uto+o9KJc2JWklkG+sZsHDM1li9fHoMgzW8GTES7tzxvpbA9E8zecf94NowFZqS5llz3rpfXv/71afsMylOeMxGtNPCjH/0obffJ0572tLTNxDr//PPTdo/q6x122CHtz372s2l7NrtnpPmVQiPa36O+veHvoFkvyg3ee1UohqBQKBQKhUK9EBQKhUKhUBiFZCDFZvEWI08nTZqUtlSTVNq1116bdvfWFj+SvrnrrrvSlr6RtpIOk0LpozCNLJeGs3iN9HY3SlMKS9pdKtpIUiNbpW2PP/74GBazZs1KW8pTinyzzTZLW4rUuZKmlKqKaKl3n90o96233jptfaP00ucbaX/HZ0SxtKi+sUhKRLt2pPWcG32jz5WVhs0A2WuvvdJ+1rOelbaRx86rRVKUz6Rhu3vGTB2zF/T9Nttsk7b7xOhk76fv9KnjNlLevSRN6Rgi2nNA2t25kc71up45Z511VgwL50R6WKrZSG4pWs8Lz4Wubyx+pG+kvM120DfKst7Pgk5mZ/SNQ38oPShhRbT+1zfOjXvR+Xj2s5+dttkqY4FFn5wz14JnrPPh+PRjF54DygmeA7vuumvaFmpSAvAM8Tdg//33T/v6668f+B2pfW3PuIhWilAS0V/uRcfqOaicuCoUQ1AoFAqFQqFeCAqFQqFQKIyil4H1o4UUqzWVlRKkjK3vb1GOiIgrrrgibQuwGBErrSMF5veNspVOcRzSzFJNUlN+3o3uthCSkaHWI7eQizSedPd4wGe3AI6+MRtAKtfntb7/Nddc09xDX0mFvuY1r0lbukuq0nEYuW3hDP/2K1/5StpKF86b1LLUWERLrVlzv6/mt+j2ChgG3k9IeZoNICXvsyqnmH0T0co8L37xi9N2HfpM+tu5UfrRL+4ZpRypWqlu91s36r4vYto58B5+XxljPOCedc84LqllJVDXdl8fhIi2pr0+N/vIQjfuAbN0+iLE9Y1UsWekftKvyhAR7Xlm9pdz0KWzB91jWChRCM9fMzOULqTe3XsWbYroL1zmvT2vfW7XtNkVrlX3mOP298CMG/9WiaA7DuVW58D9J5QoRopiCAqFQqFQKNQLQaFQKBQKhVFIBtJk0h3Sweedd97Az6VNpEcOOOCA5h5S9NYrlxKRnpL2NRPhYQ97WNoWvJEqNwLcngjKExYT6Y7VbAnnQ1pOGk8KsEsLDQvpVGurS4vaMravaI2+sUhOREvRX3jhhWnrD8dhoSgL4OgbaW0lJn1zzDHHpG3Ut7JCd6wW++nzjfSitOF4+sZ12ydjWHjJSHrpQf2yyy67NPfQLxafcu2ZvWD0upKbVL9ZMBaV2nPPPdO2t4kSkrJHd6xKea5TKXuLtOgXI7rHA1L1fXS57ajNhLAYlJSuvSQi2jm134ZrT8pb3/i513Ht6Bvbm7/3ve9NW18aeb/VVls1YzVLwfnQN1LQnumOb1iYFeSe6cugcZ+YMaBfLPgU0Z5B9pvw3p5rUv1m5ngdKXzlaP1i3xgzM/rak0e0Puvzi3tJKWgsfimGoFAoFAqFQr0QFAqFQqFQGIVkcO6556YtfW5hHaloa/8bOWxU7vvf//7mHkbLS4etscYaaUvdSSNJoRhhauSvhY+kA7fddtuBn1srXGoqIuLOO+9MWyrbGu9Gm1oYyHGMB6yvLX1+0kknpa18YNS6tbw33HDDtD/4wQ8295DeV0pRclAmUv6R6jISXpnHbAUL8VgYx5bFtpN1HUS09KfZMUbkG3mvP1yfw8KsGelzC+tIAVs4ywhmacSPfexjzT2khI1GNxpaKcI69VLARr6799yTrvnJkyen7R6z0FLXL17X9aSUoOxkH4Vu7f1h4T632Jk+U+azAJR0vvN58cUXN/dw3qV+lRz0kz5wXxnpryQjPayE5j6WXncvdLMMjMR3HO4lZRbXRbcl9zDwbPFc9azwLO7L2FhVG25/T/SlZ5byg99XVvBzMzuU69zHZnopPbhOHENEK/f1FdBSZvNazsFIUQxBoVAoFAqFeiEoFAqFQqEwil4GRx99dNpGbkvPW8xGuk3qUHpDSiiizWSwJrZRzEbHWs9Zim3u3LlpG1kuHXniiSemffLJJ6ctRW1rTT+PaGtdKw0YDa2M8apXvSptJZRjjz02hsWcOXPSNuJUet6a2kZSO4fSYY43op07ixRJmVof3ghX73HmmWem7fwYkX/aaaelrSQltSudvOOOOzZjtf2oBX6k47y368ueEx/+8IdjGLz73e9O2z0jzWnksVH1UrrWlne+I9pMAdtXK984H17L+bCFtnvG659xxhlpm/2x+eabp+2e9vkj2ta27nUzJfSL69fiMI5jrHDNuGekWaXLPV+kk/v8GtFS2BYpUibafffd0/YcUWYzo0mYreC+8jxQPpCy7q5tJTgj7N3H+sZCS9Lul1122cCxjhRKdmYZGMVv1L/yi+vZv+0WO/KcUyoxc0JZXNnFc6rbQ2XQ9c2WO/DAA9N2bZtx0C3Wp7Tmdc1wED6re7d6GRQKhUKhUBgx6oWgUCgUCoVCvRAUCoVCoVAYRdqhKRamN5m6pK5rWoS2mlZXb1L7Mo1Dzc40HTUVq9OpO1999dVp28tabUxN3XS7WbNmpT179uxmrOq9pu6ZHqP25FinTZsW4wn1NRt36JvXve51aesPm7boP9OvIiIOP/zwtE2F0TdqcGrEpniZampcgxqrWqXzbNyH+vTBBx/cjFUdzZgAG/hYFU5d3TiIYaHubP9572e6pbEljtt0MecsIuJd73pX2ursVlzUL17XNELXunEN9957b9pq7a65+fPnp21lthkzZjRj9W9MxzLVTE1XfX7KlCkxnjCuyf3rPdXcnVsry/kdz5eItpGbaa7GGfXNifr0zjvvnLbnS9857Lq7/PLL07Z5WXedm8pn3I3nifPkWth4441jvOB54px5PzV3Y5vU901P9xki2t8gUwT9TfD3zkqhnp2mK7u//e3yGYx3UNP33J0wYUIzVufDmBTXiv4y9sS1OVIUQ1AoFAqFQqFeCAqFQqFQKIxCMjAlxvQOKT4bCS1btixtKXKpOtOVIiL222+/tE2LkbIxBUdqTBqvr6KT6SpSr6bv+H1pOCtARbSUz3333Ze2zTJMkZQClO6ZNGlSDIu+alvSWFYTU0axSqO+2XTTTZt7SCl/4AMfSNusVSsBSr+5dqTibE6lrcSgb5Q3lGBM44qIuPvuu9O2Op4Uq5U09Z++6TaAGS2kC/WLzyGtd8stt6QtRW7apnRnRJu2euihh6atX6SypR0dn7KC1LC2so57QxlCv5h6HNE2sHKfuC9Nf/M7riGbLI0V0sumcCmHmVqnvCJFbnpoN1XXipS77bbbwHE4J549jk8633lQxlCeVJLyXPRz92FES6vrD/2v7Oh3lAe71WdHC/d7X8q4MqK/P+4lZZnu2e350lc11nPA9e26l843xU/fOf/KHn7fz/1Ni2jXo75371oh2Psph4wUxRAUCoVCoVCoF4JCoVAoFAqjkAykXqWWpdWM4pa+6aNTulG59nO//fbb0+6rWOV1zT649NJL07bymP3KHYeVuaT2pOqswtgdh5Gh9iK3171Vz7rVqIaFjU2kMPWNkoVVsqSQnZNuZauLLroo7ZUrV6atb6SdjbR1Hs1eMFpYitxqlo7P6nKuQTNMuuMwgvewww5LWx84N66dYdHXiMbn1ndWupM6lfoz6j8i4vTTT0+7r6d9nxygFGFTLrNupPDN/nBMVvH0mjajimjpUCsa7rPPPmkrOzk38+bNi/GEso0ZGdLRSm76TNrYeVCSiog44ogj0vZZlEBtXGS0uBHpnoWeO1L4nn/uH5uquQZdKxGt36w+a6aEkpFUvc85LPqqvrrfpfNd21LvZot4FkW02RbKAfpFOUBpxqqa+lt/uY/9HXRMViBUTlQ2imjXpmtQ6fCBBx4YONapU6fGaFEMQaFQKBQKhXohKBQKhUKhMArJYIsttkhbqt6iItIYNqKROrR5hRHJEW1EpRGx0tJSetKLXkv6xuh4aS7pNmkgi8M4BumoiJaik3Jec80103bObJTk/I0HzNZQ/pCmlDp3XMo2Sh/SWxEtFeW8SLMZ7b906dKB35eKs7iQUctmOCgZSHd7ze5YlSiM4jba2OwO52Y8fWM0us9hRLdFh6TnzZSQXve5I1qKtTsP/wulo8WLFw/8vtSkDXfs2+7+lk7v84v7NqKlpaVbpUKN+rZok/M3HvDaSgB9GSDrr79+2u4rm011n9foe/9N6UV63qZNfl/q3AZRUueuI8+gkZypEa2cIG3tebvuuusOHLcSxbCQbvdMV06x6JBNvFyHfqf7rMo8/puZUeutt17aSg5Kcf5uuD7MiPB5HJPP4xi6Y9X3ZhAooehv91L3rBgJiiEoFAqFQqFQLwSFQqFQKBRGIRlYCGGnnXZK2x7sy5cvT9sa2lKhUlvd4jfnnntu2lIn73znO9O2Br0FGSZOnJi2kadXXXXVwOtIR0n/90WtdgvCiO233z5t68A7JmlG7+2YxgqpyenTp6d96623pq00IJ0vrW3xoo022qi5x8KFC9PWN2ZlSGFKY+l/59fMDdeUFK6+kerTN9363xblMaLYaG3XnvS1NJvR72OB1KFUr1kaRnu7l6Qgpa6VuiIilixZkrZ+cR24d5V1lLGcZwu3OH/6RcrS7AHHIC3fhbKVRX8ckzLSWOjPVcG15FqwSJXFfvSZ+8R17ucRrT+dF7MyLCompWyRI+fa7B/PF2l7JVP/VnlPyj+i3TPKus6BZ7fSaldOHQausXXWWSftm266KW0zA4ywN6PM4nTdwmXK3M6J60B63r2h7KjM1legT6lLudS/dQxS/l14DiiPOKa+vgYjRTEEhUKhUCgU6oWgUCgUCoVCxGoPyRWtArvsskvaRgsbmWmBBWkaI1ilsCy+0v0b6RUjLW+77ba0ja6XtpISmjlzZtpG8Vq0wYhpKX8pq24hDyO0HbcR4WeffXba1kuXPj7wwANjWOywww5pSxcaBSutK8XURy92KVqfUbpL/69YsSJtZQKpK32z9957p93nG6k4ZQ8pT7M8Itp1JY1oNsGZZ56ZthKO2TEWMhoL9LN7wDnTF9ruMe0ufD79bSS7BcOk5P2OkpYykGvY3g5Gors3lPHsRRDRFr6y+JGSj62UpdOlx21RPlZYsMfzRd+4l5UG9IfnlHZE28vB/a98oASqxOIeMyvFvaGM9+Y3vzltM0MsjqUkYTZFRCv3WfxI6fCcc85JW98qp/qdscAzV19oW2DO9aZf9Gm3p4bP58+f54ayhPdQ8nQOlRssbqb8YsaHa9txK/9GtL9Nrq9NNtkkbdeBRdrWXnvttO2TsioUQ1AoFAqFQqFeCAqFQqFQKIwiy6Cv9aU0l9SflI2R/lLUXdrKyGhbh/q50bennHJK2lLiRphecsklaUsHG61rMRqj7o1g7bY2lX4zo0Iq0r4G0rYWzhkPWDDENsJGu0pdSV9bw76vTndExFve8pa0pQWll63PfcIJJ6StP6TrFi1alPbkyZPTloI2Ql4fSFlbnCSipQjtJ2HRmwMOOCBto9xXRc+PFs6/UdnKHVJ/rgspPteaEkpES2tfeOGFac+YMSNtKc/jjz8+bYvOuH8WLFiQtvtBmliKuW98+jqiXR+uO4vA7LzzzmlbO99zYzzgmrHoltKA45Iqdr24tt17ERHbbLNN2q5DC3I5d8cdd1zaRpv77NLOnslGnSsZKG3aj6a7zi2wZDaOtLO0uHJTVyoZBu5dJSbn2bXq74yyl98x+y2inTcl6NmzZ6ft/Bx99NFpW+TL89Izq6/olQWOHKtyd9cvngnKS2Y+9GVUKOGPFMUQFAqFQqFQqBeCQqFQKBQKo5AMpKT6IiSl8KUapeGMIjYSPaKNIrc9rZS4VLb0kjX8pfqkMI3Gl76x6IjFfKR8bVMa0c6HtK0UltGfju/aa69N27rxY4URsRam0Ddz585NW99ITVoYSio7op2jZcuWpW0UshKLdLS0mXSaEbTbbrvtwGcw68M5t0iK94poqVuLxRjprcxjASGLyZihMhZIRRs9LAXsfnAu11hjjbSVAroFsqQqjTA2A0dpTqrRqGqzD5w/pSL9eP311w/8vnvGrJWIdi9KbUqZurY8Z6R258yZE8NCKUpqWvr7/PPPT1s5tK+dt+dUREsLS8ObfWIBHDMLHJNSi3OtDCh9bYtsswyUBYzaj2jPVZ/Pc0sZyzNHOXRYWMjH9ebz+dvgWe95Z3Et271HtPKdhaiUwdwbZkm5Js2I8Czzfquv/t//c991111p6yNlGQtmda9rdorP4BngmdNtpTwSFENQKBQKhUKhXggKhUKhUCiMQjKQtjKCU0pXmlPaxChPizZYnzqipW+ktS26IVUlbWzkrwVypEKloKSPDz300LSlEo3KldaJaClAqUGfW0pP6mg8KbaIlgrUN47RKGnH6HeUOKRoI9o2ws6p1LTRzdKiUlr6yWhfaWPpTAs32eq3b84jWmrOyFxrgUvPe299PiyM7pc+d3xG4rvGlGhOOumktG2LHNFSvc6n+0fJwB4CzqFrUglM6Je99torbfekGUjOd0RL9SovSb3aw8H6/I51PKAEaNaNtfuVovSl69mzo5uZ41wopUhzS+uaBSMl7Bo2Q8Uocuvh21tFmUa6Wxk3om1R73rRN8pvyrU+57BwP/h8fbX7tc2CUeryuSNamcZ5cL3ecMMNafs745r0vHM9eQYrAynDKhtaHMlzOqL1vVKqa9YWy64Dn3OkKIagUCgUCoVCvRAUCoVCoVAYhWRg9KNR2VIzygGXXXZZ2lK60r5eJ6KlcY2mnTdvXtpSNtJqRk9L/ZhNoG00phLDHnvskbaRplKcERFbbrll2lOmTElbOtFCSPvvv3/a3eyKYWFkrtHJkyZNStuWx0ZGW6BFSqpbU1s62nk/9dRT09Y3+k+6XFpcGtAIXyNrtXfdddeB9/I7EW2RI2vg2/PAAlIHH3xw2t3simHgGI3Wlho2GtpiPVKh7pOuX8za0C9m/LiOpSelLaVO+7JWlI2korfbbru0lV+6Uc7OrXvd6ypXHHPMMWl3W1wPC+VD58dMASPBldaUEixq1j3PPIec6z7fWOxHqlmZQN+475WSpPDdC8qkrpWI9rktACZlbd+M8847L+1u2+dhYCElbWlxf0OUCZxj/euZGNE+h1LuRz/60bSd277eLV7H3xZp+wsuuCBtz0RlI/vG+LsU0e5LM8I88+zFovTn+EaKYggKhUKhUCjUC0GhUCgUCoVRtD/uq78tPSKdIu0kDWIEptHWES0tbX1m6VaLM0hFS9cZKWwUsJSs15ESknqzRauSRPffvJZFLKStLBCi/NDNtBgLLKDjuPSNBZOkZfWNBUm6mRBGf5tNom+ki/WNtKrR4gcddFDaRjlLtTtu5RijnLu9IVw7+kZbStD2slLhZtCMBbYMd0z6xUI4zrF+kQp1bUe0vnRv+XyuSa/ld1yfZhCY+eMec687Jp+zW5fdCGrpYMenJKIvjOg2A2OssJWz89YnqVgQyuht12c32t7MLPeD69tndy14dnjdadOmpa18IGWt9OAcum+7/QccnxS0a9U+CmYiKPe5XsYCqXQj6R2H68j+CvpFKaCbZaAc5z2k7l0HFkXyDHHOHYdza7t3n8E5c/15r4g208J14/j8PfHccL92+wb1oRiCQqFQKBQK9UJQKBQKhUJhFFkG0iPSWUbSSm9IDRsla718i15EtHTJkUcembbtJy3+YaS447NFpTXrpVmk0swY6KtZ3qXQpa36Wl9KR0mnW9hnPCB1qG+M6Nc3Rg6bZWDtdosdRbQRykZ/H3HEEWlLvRtRLo1lDXszH4QRtBal2mCDDdKW4jR7IKKVd/Sb0dDKRPpmPKPZfW5pWCOHnVepP+dpyZIlaXdbPUuNGr1uVovR/j6f0drKgBZlcXxGSUuh9/Uv6BYT8t+UGcxC8gyQTu8rljRWSCPrGwvMSL1LuSqLKPl1o7r9+8WLF6dtO3f3rjKrdK9Uvy2hnSspaM9b51YJpJs1Jc3t2eh8uBb0jfcbFspY7kvPW2n7PvnGzJVuG25h/xwzX/qKOOkXJTHnzM89j13/Sp7Ovb9jEe1a08f+Nvm7pixhVsJIUQxBoVAoFAqFeiEoFAqFQqEwCsnACEyL90gLGpUt5WxE8ezZswf+bURLce+zzz5pn3766WnbLtiCExbaMGpTakxax4yGs88+O22LrDgeI2wjWrpVSPU6H15r2EjcLqTWjGxXGrCwihSvfztr1qyBfxvRUtUzZsxI27nbaaed0jazwAJE0vlKA9LaFklZsGBB2lOnTh04nq5vLKYjpa5vbHPttaQah4Vzu80226TtOjTTxghr/9aof/82on0m23tfcskladtXwqwQKUhpTveSvlPiU4qzMJf0arfIin1CpNPNjDHbxGjw7nocFtLOFvDyeY0Qty+JdLl+NRMion0ue0jYOtpeIBZ/k7KW6nefuM6dK9tgWwRL+r97fnWL9/wvpOqVeZSYxtJmtw/OreeU86F01Sc3KY05TxEtdW+fAv3t75e0v/tHicK9pATlvcwWcS8pC3QLj42kcJXP529f32/UqlAMQaFQKBQKhXohKBQKhUKhMArJQErRyEyjZHfccce0rTF9yy23pG2xDv82oqUUjS43wtco20MOOSRt6Syj++0nIE0jtWlEfF+Uc7ftrNGcUrUXXXRR2kaCSkVKP44HpBSdHyNWlUKk+KRozRiRAotofWMvBull6W/bFrsWjKiVRpVuNUvEyF994/x3swz8Nylde230ZYAYPT8svIfZDlLpZrhI/d1xxx1pW8e9W2TFqGezV5RRnB/lByUb15BFVtyjRtEr6/S14dXvES1Nbw+UZcuWpW3BI+fJ+40HpFyVI1zn0u0+l7KS1LkyWURLw0tN62fv7Vr1vHDPOA9SwvZXkO7u8420e0S7BzyfbrzxxrT1n9/3TB8W+t/zy+wD7+czeYZ49krhR7S0v5Kb9LznmvtKn7qG3KN+py+zQL8I119EK7crO3vOK+f7m1VZBoVCoVAoFMaEeiEoFAqFQqFQLwSFQqFQKBRGEUNgiqDpNGuttVba6p7qLvbkvvLKK9NW/4lotR7T2dRO1L5MrTLGwdQNm7aoT6k3WQlMncf0HVN2Itq0RXUe/950FfVFPx8PqNXqG3Ut9f2NN9447YkTJ6Z91VVXpd3teaVOtXDhwrRN2TPW47LLLku7rwmVKXRWQVM/tcqYc3vzzTen3fWNeq2+UWdVL1RT7OqNw8CUV9eYuqV6tLEs6riuwy7Ucp1zY2Scz2uuuSZt4wbUPY0xcc8YV2L6qPNqeqv6ekSbIqZfTJUy9U7d2O+PB9SITa1zjek/42bU8VesWJG2Ka4R7flm9UfXq3O6fPnytI0bcB6M6XDPGJfgWei6u/fee9PWZxHtczvXrpG+M2w8faOG7nMYE2B8jPEErnljM7oNg1zT999/f9rOifNpdVxTDW025G+O6YH6zrgB17y/aX7eHavz7OfC73g2jBTFEBQKhUKhUKgXgkKhUCgUCmOsVCjtK20ibHixaNGitG3+Y7OaiLZSm+lVVvaSUjENSqpESk7KRur1uOOOS1sZwwY9VjAzzSOipZZNZ1NKkO6WjurSdcNCakjfSL9KZ0rdOnab//gcEW1zGWk905L6fOP4TIVxfDaJOeqoo9I2Jc00OSlc5ZCI1jdXXHFF2q49KUj9YdWwYeFzS8n3SUZWm7v22mvTdi8ph0S01dz0vdS9ko2SkuOTMlZOMcXYVFL9rvxm2mY3HVSpRApdSVCK2v0jrT8eMIXOM6JPMnLeVq5cmbbP1G3y4z5zrUvdS4v7vI5P2t/z1vndeeed01bGMHXT5kndyo9KQO4n7y1N7ffdS8PCtDl90Ud/m9rr2e166UqKnkH+rkndm1LoPZRSfW7nWelUedY5dz04PuW97r39N+Uix6HU5N+OFMUQFAqFQqFQqBeCQqFQKBQKEas91A0n78HcuXPTlpo0WlfqwihNqWjpxW7TiZkzZ6ZtZTejlaU/pZHsNy6NasVDqTqjU82gkM41mtUGLt1/MxrZOZBuPeigg9KWyrEh0Fhx4oknpt3XVMh58zv6Rl92fWOFOyNzpZqlP6U8zQaxWqQZDjZ8stKaGRRSiFJ0SjsRbSS2DUWk06R999tvv4Hf2XPPPWMYHHnkkWk750aWS5H7fK5h/ejei2jXj5Sk0ofrTUrWrBLpbjMcnH8zilzz0q6um25FTp/DKnQeQffcc0/au+66a9r65fDDD49h4XpWIlGqdN6MHHfP9FXTi4jYaqut0jZzQzrbqHXnoa/Jk7KqEprShdS3+7CvcmJEu/eVBJ0bm//YMMtnOPnkk2MYeCa4H5xbswaUKT1brOLYlYGUQ31W96KUvHKrFXGVGs3oevDBB9NWBvJzfa0MpB8jWvnBvevvqN9xfSjxKEGuCsUQFAqFQqFQqBeCQqFQKBQKo8gysHiLlLHFIKQzpW9sEiL9NX/+/OYe0pxmDdi7XjpZ+uawww5L20ISV199ddo2OpLalGKWBjILQlosoi3wYpS0NKeRyX2FMsZDMtA3FrqxSZB0oXNoMRypxgULFjT36PON1Ka0s8+rXGKhqOuuuy5tm930+cb5lBrrNgoxgt3CSdOnTx/4PI6pW8RkGEhzeg8zMPyOzaUs7OWavPzyy5t7SEkqofU1R1JCURLxc/eb2RwWX1GyMRJdqaPrFzMfXJtTp05NWzq9r0DSeEBp0Ps410aC65v1118/bSlri2V1/0261/PTtaCUtPXWW6et1GIGwIQJE9LWN+4fx+Be6M6nkp1R754PUuqOaTx9I72vfKPk6Xcck1KO53U3o8JreQ+pd+fNParc4N8q35hd4tms7bhdc0pTEa2U4/nnb45nlhJIt1DWSFAMQaFQKBQKhXohKBQKhUKhMArJQPpB6t2CPX29DIxUNjpZmjiijUbffPPN01Z+kEKR+pbitta4kZ1STUbHK3tIuRgJ2u3tLtV/1llnpW1RECNGHUe3XvWwMCLWbIg777wzbaO3pRrNsLC3gNkZERFLlixJWxpZmUB6X9qsrzCVkcBSuFJoI/GNElFEKw3YE8N1JFXrOCy4MizcM5tuumna0rNKPBZ/UmKQjlR+i2ijh/WrGQFS0c6nRZucGylnMwiUJyzuohxl/X7XU0RLg1944YVpKyOaUaS/pXnHA1Kz0sDOu4VuPGuUV5TZupH79ibw31x70s6eSX2Fm5QGjE53fpTDlD30jedoRJups3Tp0rSltr2uMstYCuD0wT3jGvZsUa5S3vA7nrfdwkSe5e4Hz2XpfefWvau//L5+UTbzOp53zrFrPqJdm8qzynRe16wQrztSFENQKBQKhUKhXggKhUKhUCiMQjI444wz0rbghpTXvvvum7YRm9Kw0ovdwi9SVdI3ShFGjyoNTJs2Le2777477X322SdtI3Qt2GF7ZiPfpfy7tJPP57X6ooal7vzOeOCcc85JW8lA6nfGjBlp97V+lvbabbfdmnvYU8DiS8oS0mNSpEaRWxxK/1uIymh5C5VIozqH0tcRbQS0vpGOlnbUN91rDYOLL7447c022yxtad8ddtghbSlPxy0Fud122zX3MLJd6tA69/pFicd1r3TkXrIYjWOyRrvUuJHQ2hHtflCK8HueJ9Kf3WsNC/ujmDVgASklTSlozyCfoyuzuTcsrCRlLUVu5oU9IZSM9JmSjHvGc9H9ZgZHtzeEa1I623PYs9s1qZQ0LMzAUDJQ7lV21F/uH+e122PC+TRDwnNRSdLnU5ZzTdgnwnWuX3wepSYll6784nN7Til1uE+8XzdjYSQohqBQKBQKhUK9EBQKhUKhUBhFL4NCoVAoFAr/76IYgkKhUCgUCvVCUCgUCoVCoV4ICoVCoVAoRL0QFAqFQqFQiHohKBQKhUKhEPVCUCgUCoVCIeqFoFAoFAqFQtQLQaFQKBQKhagXgkKhUCgUChHxPyROwaPX8B1KAAAAAElFTkSuQmCC\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"## Define the RNN goodies"
],
"metadata": {
"id": "-CeUWnzpEoqb"
}
},
{
"cell_type": "code",
"source": [
"class ImageEncoder(nn.Module):\n",
" def __init__(self, in_channels):\n",
" super().__init__()\n",
" self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1)\n",
" self.bn1 = nn.BatchNorm2d(32)\n",
" self.relu1 = nn.ReLU()\n",
"\n",
" self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)\n",
" self.bn2 = nn.BatchNorm2d(64)\n",
" self.relu2 = nn.ReLU()\n",
"\n",
" self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)\n",
" self.bn3 = nn.BatchNorm2d(128)\n",
" self.relu3 = nn.ReLU()\n",
"\n",
" # Global Average Pooling to reduce spatial dimensions to 1x1\n",
" self.gap = nn.AdaptiveAvgPool2d((1, 1))\n",
"\n",
" def forward(self, x):\n",
" x = self.conv1(x)\n",
" x = self.bn1(x)\n",
" x = self.relu1(x)\n",
"\n",
" x = self.conv2(x)\n",
" x = self.bn2(x)\n",
" x = self.relu2(x)\n",
"\n",
" x = self.conv3(x)\n",
" x = self.bn3(x)\n",
" x = self.relu3(x)\n",
"\n",
" x = self.gap(x)\n",
" return x"
],
"metadata": {
"id": "z81MMxklEqhO"
},
"execution_count": 29,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class ImageDecoder(nn.Module):\n",
" def __init__(self, out_channels, initial_height, initial_width):\n",
" super(ImageDecoder, self).__init__()\n",
" self.conv_transpose1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\n",
" self.bn1 = nn.BatchNorm2d(64)\n",
" self.relu1 = nn.ReLU()\n",
"\n",
" self.conv_transpose2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)\n",
" self.bn2 = nn.BatchNorm2d(32)\n",
" self.relu2 = nn.ReLU()\n",
"\n",
" self.conv_transpose3 = nn.ConvTranspose2d(32, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)\n",
" self.bn3 = nn.BatchNorm2d(out_channels)\n",
" self.relu3 = nn.ReLU()\n",
"\n",
" # Additional layer to ensure correct output dimensions\n",
" # This layer is only needed if the initial size cannot be exactly achieved through the strides and paddings chosen\n",
" self.final_resize = nn.AdaptiveAvgPool2d((initial_height, initial_width))\n",
"\n",
" def forward(self, x):\n",
" x = self.conv_transpose1(x)\n",
" x = self.bn1(x)\n",
" x = self.relu1(x)\n",
"\n",
" x = self.conv_transpose2(x)\n",
" x = self.bn2(x)\n",
" x = self.relu2(x)\n",
"\n",
" x = self.conv_transpose3(x)\n",
" x = self.bn3(x)\n",
" x = self.relu3(x)\n",
"\n",
" x = self.final_resize(x) # Ensure the output has the same HxW dimensions as the original input\n",
" return x"
],
"metadata": {
"id": "gzQ35KxeE2Cz"
},
"execution_count": 30,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class CustomRecurrence(nn.Module):\n",
" def __init__(self, in_channels, initial_height, initial_width, hidden_dim, num_layers=2, training=False):\n",
" super().__init__()\n",
" self.image_encoder = ImageEncoder(in_channels=in_channels)\n",
" self.positional_encoder = nn.Embedding(T, 8)\n",
" self.rnn = nn.LSTM(\n",
" input_size=128+8, # hardcoded for the time being\n",
" hidden_size=hidden_dim,\n",
" num_layers=num_layers,\n",
" batch_first=True\n",
" )\n",
" self.image_decoder = ImageDecoder(\n",
" out_channels=in_channels,\n",
" initial_height=initial_height,\n",
" initial_width=initial_width\n",
" )\n",
" self.training=training\n",
"\n",
" def forward(self, x, hidden_states=None):\n",
" batch_size, timesteps, channels, height, width = x.shape\n",
"\n",
" x = x.reshape(batch_size * timesteps, channels, height, width) # (b*t, c, h, w)\n",
"\n",
" latent_vectors = self.image_encoder(x) # (b*t, hidden_dim, 1, 1)\n",
" latent_vectors = latent_vectors.reshape(batch_size, timesteps, -1) # (b, t, 128)\n",
"\n",
" pos = torch.arange(timesteps).unsqueeze(0).repeat(batch_size, 1).to(device) # (b, t)\n",
" pos_embeds = self.positional_encoder(pos) # (b, t, 8)\n",
" latent_vectors = torch.cat([latent_vectors, pos_embeds], dim=-1) # (b, t, 128+8)\n",
"\n",
" if self.training:\n",
" rnn_outputs, _ = self.rnn(latent_vectors) # (b, t, hidden_dim)\n",
" else:\n",
" rnn_outputs, hidden_states = self.rnn(latent_vectors, hidden_states) # (b, t, hidden_dim)\n",
"\n",
" rnn_outputs = rnn_outputs.reshape(batch_size * timesteps, 128, 1, 1) # (b*t, hidden_dim, 1, 1)\n",
" reconstructed_x = self.image_decoder(rnn_outputs)\n",
" reconstructed_x = reconstructed_x.reshape(batch_size, timesteps, channels, height, width)\n",
"\n",
" if self.training:\n",
" return reconstructed_x\n",
" else:\n",
" return reconstructed_x, hidden_states"
],
"metadata": {
"id": "chqbXzZaE5Ok"
},
"execution_count": 34,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Training Loop"
],
"metadata": {
"id": "KD_os31TMhEN"
}
},
{
"cell_type": "code",
"source": [
"model = CustomRecurrence(\n",
" in_channels=C,\n",
" initial_height=H,\n",
" initial_width=W,\n",
" hidden_dim=hidden_dim,\n",
" training=True, # important parameter\n",
")\n",
"model.to(device)\n",
"\n",
"optimizer = torch.optim.Adam(\n",
" model.parameters(),\n",
" lr=1e-3\n",
")\n",
"\n",
"loss_fn = nn.MSELoss()"
],
"metadata": {
"id": "TYo28dSj49K_"
},
"execution_count": 35,
"outputs": []
},
{
"cell_type": "code",
"source": [
"for epoch in range(epochs):\n",
" for step, batch in enumerate(dataloader):\n",
" optimizer.zero_grad()\n",
"\n",
" batch_size = batch[\"pixel_values\"].shape[0]\n",
" batch = batch[\"pixel_values\"].to(device)\n",
"\n",
" t = torch.arange(0, T).flip((-1,)).repeat(batch_size, 1).to(device)\n",
"\n",
" input_images = q_sample(\n",
" x_start=batch.unsqueeze(1), # (B, 1, C, H, W)\n",
" t=t, # (B, t)\n",
" sqrt_alphas_cumprod=sqrt_alphas_cumprod.repeat(batch_size, 1),\n",
" sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod.repeat(batch_size, 1),\n",
" )\n",
"\n",
" reconstructed_images = model(input_images[:, :T-1, ...])\n",
" loss = loss_fn(reconstructed_images, input_images[:, 1:, ...])\n",
"\n",
"\n",
" if step % 100 == 0:\n",
" print(\"Loss:\", loss.item())\n",
"\n",
" loss.backward()\n",
" optimizer.step()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ImaLvHwjMiDf",
"outputId": "f0a12e6c-775d-4fad-8ac3-c15935f6ba5f"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Loss: 1.4094734191894531\n",
"Loss: 0.7404128909111023\n",
"Loss: 0.7145070433616638\n",
"Loss: 0.7277837991714478\n",
"Loss: 0.7170003652572632\n",
"Loss: 0.7365358471870422\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Inference Loop"
],
"metadata": {
"id": "PP6CFN4vMidh"
}
},
{
"cell_type": "code",
"source": [
"model.training=False\n",
"model = model.eval()"
],
"metadata": {
"id": "iNwii9gJ-Utu"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"with torch.no_grad():\n",
" current_input = torch.randn(\n",
" 1, 1, C, H, W\n",
" ).to(device)\n",
" hidden_state = None # Start with no hidden state\n",
"\n",
" for _ in range(timesteps):\n",
" # Forward pass\n",
" output, hidden_state = model(current_input, hidden_state)\n",
"\n",
" # Use output as next input\n",
" current_input = output"
],
"metadata": {
"id": "8NxLub4ZMjpH"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"output.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "rNmgRa7m-ZEW",
"outputId": "8cafec2a-40ec-4e8f-ca55-94d6e78512ea"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([1, 1, 1, 28, 28])"
]
},
"metadata": {},
"execution_count": 18
}
]
},
{
"cell_type": "code",
"source": [
"plt.imshow(output[0, 0, 0].detach().cpu().numpy(), cmap=\"gray\")\n",
"plt.axis(\"off\")\n",
"plt.show()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 406
},
"id": "OkJeUd45JooW",
"outputId": "be0a63f0-2584-43f2-bee7-a5fb0aa0bd2d"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAHXUlEQVR4nO3cMY6NfR/HYUcmkZEggw6NiERiDyqFBViDBViBxjroprQAQUGrkikUFoDMIGGGDOftPnkLuTPHc875z/O4rvr+Jd8pZj7uwj2bz+fzEwBw4sSJk6MHAHB8iAIAEQUAIgoARBQAiCgAEFEAIKIAQDaO+uBsNlvlDhju6tWroydM2tvbW8vNumxtbY2eMGl/f3/hm4ODgxUsWZ6j/F9lbwoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgDZGD0Ajou9vb3REybt7++PnrBUx/3nOTw8HD1hCG8KAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgs/l8Pj/Sg7PZqrcAsEJH+XPvTQGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoAJCN0QP4vSdPnoyeMGl7e3stN+v0+fPn0RMm7e7uLnyzt7e3giXL8fDhw9ETJu3s7Kzl5rjxpgBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoAJCN0QP4ve3t7dETJr1582bhm5Mnj/e/QR49ejR6wqRPnz6t5WZddnZ2Rk+Y9OHDh9EThjjev6UArJUoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBAZvP5fH6kB2ezVW/hX+RPPm533D+Id/PmzdETJn358mUtN+vy8ePH0RP+Okf5c3+8f0sBWCtRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoAZGP0AH7v2rVroydM2t3dXcvNOt25c2f0hEnfvn1b+Obr168rWLIcV65cGT1h0vPnzxe+efHixfKHrJk3BQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgGwc9cFTp06tcsc/dnh4uPDNz58/V7BkOXZ3d0dPmLS/vz96wtK9evVq9IRJv379Gj1hqc6dOzd6wqSDg4PRE4bwpgBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFAPKf+SDen/BBPP7fy5cvR0+YtLm5uZabdblx48boCZN+/PgxesIQ3hQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgDZGD1gWW7fvr3wza1bt1awZDkeP348esKk9+/fr+Vmne7duzd6wqTz588vfLO1tbWCJctx//790RMmPXjwYOGb169fr2DJenlTACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAyMZRH/z+/fsqd/xjb9++XfjmOP9MFy9eHD1h0oULFxa+uX79+gqWLM/ly5dHT5h09uzZhW/OnDmzgiXL8ezZs9ETJr179270hCG8KQAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgPzVH8T7k5t1uXv37ugJk/7kg3h/crNOly5dGj1h0unTpxe+2dzcXMGS5Xj69OnoCZN8EA+Av54oABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAMpvP5/PRIwA4HrwpABBRACCiAEBEAYCIAgARBQAiCgBEFACIKACQ/wHzssF0wLHC6wAAAABJRU5ErkJggg==\n"
},
"metadata": {}
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment