Created
April 16, 2024 20:00
-
-
Save ariG23498/2f02dbd3a6417b66570c56fdaa3641aa to your computer and use it in GitHub Desktop.
autoregressive-diffusion-lstm.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"gpuType": "T4", | |
"authorship_tag": "ABX9TyM5ZVCOxhE+W71bDZtd3Ii+", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/ariG23498/2f02dbd3a6417b66570c56fdaa3641aa/autoregressive-diffusion-lstm.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
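"This notebook trains a sequence model (CNN encoder -> LSTM -> CNN decoder) on the forward-diffusion (noising) trajectories of Fashion-MNIST images, ordered from the noisiest step to the cleanest: given the frames so far, it predicts the next, less-noisy frame. Sampling then rolls the model forward autoregressively from Gaussian noise.\n", | |
"\n", | |
"It is heavily inspired by: https://huggingface.co/blog/annotated-diffusion" | |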
], | |
"metadata": { | |
"id": "02vydzVb6VWO" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Setup and Imports" | |
], | |
"metadata": { | |
"id": "-CFKktKV2qqo" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install --upgrade -qq datasets" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "BmiGcVlQ64dD", | |
"outputId": "b5078cae-cbd3-4b61-ad7b-6eb403a6e205" | |
}, | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m510.5/510.5 kB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m10.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25h" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"id": "hzUHzDmy2kOy" | |
}, | |
"outputs": [], | |
"source": [ | |
"import random\n", | |
"import numpy as np\n", | |
"from matplotlib import pyplot as plt\n", | |
"\n", | |
"from datasets import load_dataset\n", | |
"\n", | |
"import torch\n", | |
"from torch import nn\n", | |
"from torch.nn import functional as F\n", | |
"from torch.utils.data import DataLoader\n", | |
"\n", | |
"from torchvision import transforms" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Configurations" | |
], | |
"metadata": { | |
"id": "rlNouEPu7ART" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Data / model dimensions\n", | |
"B = 128          # batch size\n", | |
"H = W = 28       # Fashion-MNIST image height and width\n", | |
"C = 1            # grayscale channel\n", | |
"T = 100          # number of forward-diffusion steps\n", | |
"\n", | |
"hidden_dim = 128  # LSTM hidden size\n", | |
"epochs = 5\n", | |
"\n", | |
"# Seed every RNG the notebook uses (python `random`, numpy, torch incl. CUDA) so reruns are\n", | |
"# repeatable (up to nondeterministic CUDA kernels).\n", | |
"SEED = 42\n", | |
"random.seed(SEED)\n", | |
"np.random.seed(SEED)\n", | |
"torch.manual_seed(SEED)\n", | |
"\n", | |
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", | |
"print(f\"{device=}\")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "3HNObPXp7Bff", | |
"outputId": "f8a7e55c-ce8d-4f62-8265-94d0d16cf895" | |
}, | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"device='cuda'\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Dataset and Loaders" | |
], | |
"metadata": { | |
"id": "JGEzhe8s7B1p" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Fashion-MNIST from the Hugging Face Hub, indexed by split name (e.g. \"train\")\n", | |
"dataset = load_dataset(path=\"fashion_mnist\")" | |
], | |
"metadata": { | |
"id": "PdIqVYuj6axU" | |
}, | |
"execution_count": 11, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Image preprocessing: random horizontal flip (augmentation), PIL image -> float tensor in [0, 1],\n", | |
"# then rescale to [-1, 1], the range the forward diffusion process below operates on.\n", | |
"transform = transforms.Compose([\n", | |
"    transforms.RandomHorizontalFlip(),\n", | |
"    transforms.ToTensor(),\n", | |
"    transforms.Lambda(lambda x: (x * 2) - 1)\n", | |
"])\n", | |
"\n", | |
"# Named `apply_transforms` rather than `transforms`: reusing that name would shadow the\n", | |
"# torchvision `transforms` module and make this cell fail when it is re-run.\n", | |
"def apply_transforms(examples):\n", | |
"    \"\"\"On-the-fly transform for `Dataset.with_transform`; `examples` maps column names to lists.\n", | |
"\n", | |
"    Converts each image in `examples[\"image\"]` to grayscale, applies `transform`, stores the\n", | |
"    resulting tensors under \"pixel_values\" and drops the raw \"image\" column.\n", | |
"    \"\"\"\n", | |
"    examples[\"pixel_values\"] = [transform(image.convert(\"L\")) for image in examples[\"image\"]]\n", | |
"    del examples[\"image\"]\n", | |
"    return examples\n", | |
"\n", | |
"transformed_dataset = (\n", | |
"    dataset\n", | |
"    .with_transform(apply_transforms)\n", | |
"    .remove_columns(\"label\")\n", | |
")\n", | |
"\n", | |
"# create a shuffling dataloader over the training split\n", | |
"dataloader = DataLoader(\n", | |
"    transformed_dataset[\"train\"],\n", | |
"    batch_size=B,\n", | |
"    shuffle=True\n", | |
")" | |
], | |
"metadata": { | |
"id": "SbnfSuvR6eel" | |
}, | |
"execution_count": 12, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Pull one shuffled batch to sanity-check the pipeline; `batch` is reused by the cells below.\n", | |
"batch = next(iter(dataloader))\n", | |
"for summary in (batch.keys(), batch[\"pixel_values\"].shape):\n", | |
"    print(summary)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "GneDUV5J74KJ", | |
"outputId": "a4d06196-84e2-439a-b314-1a7b7e38faf3" | |
}, | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"dict_keys(['pixel_values'])\n", | |
"torch.Size([128, 1, 28, 28])\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Forward Diffusion Process" | |
], | |
"metadata": { | |
"id": "Ggg6okjO8NFS" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"id": "9_7ipOKVIbFn" | |
}, | |
"outputs": [], | |
"source": [ | |
"def linear_beta_schedule(T, beta_start=0.0001, beta_end=0.02):\n", | |
"    \"\"\"Return `T` noise variances (betas) spaced linearly from `beta_start` to `beta_end`.\"\"\"\n", | |
"    return torch.linspace(beta_start, beta_end, T)\n", | |
"\n", | |
"# Noise schedule and the closed-form forward-process coefficients used by `q_sample`:\n", | |
"#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,  alpha_bar_t = prod_{s<=t} (1 - beta_s)\n", | |
"betas = linear_beta_schedule(T=T)\n", | |
"alphas = 1.0 - betas\n", | |
"alphas_cumprod = torch.cumprod(alphas, dim=0)\n", | |
"sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)\n", | |
"sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)\n", | |
"\n", | |
"def extract(a, t, x_shape):\n", | |
"    \"\"\"Gather coefficients `a` at indices `t` and reshape them to broadcast against `x_shape`.\n", | |
"\n", | |
"    `a` and `t` must have the same number of dims (gathered along the last one), e.g.\n", | |
"    a: (B, T) with t: (B, T), or a: (T,) with t: (B,). The result has shape `t.shape` followed\n", | |
"    by singleton dims up to `len(x_shape)` and is placed on `t.device`.\n", | |
"    \"\"\"\n", | |
"    out = a.gather(dim=-1, index=t.to(a.device))\n", | |
"    trailing = (1,) * (len(x_shape) - t.dim())\n", | |
"    # reshape by t.shape instead of a hard-coded (batch_size, T) so any number of timesteps works\n", | |
"    return out.reshape(*t.shape, *trailing).to(t.device)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# forward diffusion: sample x_t ~ q(x_t | x_0) in closed form\n", | |
"def q_sample(x_start, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, noise=None):\n", | |
"    \"\"\"Return sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise.\n", | |
"\n", | |
"    The two coefficient tables are indexed at `t` via `extract`, which also reshapes them to\n", | |
"    broadcast against `x_start`. If `noise` is None a standard-normal tensor shaped like\n", | |
"    `x_start` is drawn, so any broadcast (size-1) axis of `x_start` shares one noise sample.\n", | |
"    \"\"\"\n", | |
"    noise = torch.randn_like(x_start) if noise is None else noise\n", | |
"    signal_scale, noise_scale = (\n", | |
"        extract(table, t, x_start.shape)\n", | |
"        for table in (sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod)\n", | |
"    )\n", | |
"    return signal_scale * x_start + noise_scale * noise" | |
], | |
"metadata": { | |
"id": "pmHZGKa8Lhbz" | |
}, | |
"execution_count": 15, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Noise the sampled batch to every diffusion step. Columns run from the noisiest step (T-1) down to\n", | |
"# step 0, so input_images[:, 0] is the noisiest frame and input_images[:, -1] the least noisy one.\n", | |
"t = torch.arange(T - 1, -1, -1).repeat(B, 1)  # (B, T): [T-1, ..., 1, 0] in every row\n", | |
"x0 = batch[\"pixel_values\"].unsqueeze(1)  # (B, 1, C, H, W): the size-1 time axis broadcasts over T\n", | |
"input_images = q_sample(\n", | |
"    x_start=x0,\n", | |
"    t=t,\n", | |
"    sqrt_alphas_cumprod=sqrt_alphas_cumprod.repeat(B, 1),  # (B, T), matches t\n", | |
"    sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod.repeat(B, 1),\n", | |
")  # (B, T, C, H, W); one noise draw per image, shared by all T steps" | |
], | |
"metadata": { | |
"id": "uWgxpZPj8sY6" | |
}, | |
"execution_count": 16, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Visualize the forward process" | |
], | |
"metadata": { | |
"id": "aQ0H8DVxEJqF" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Plot 5 snapshots of one random image's trajectory, from the noisiest step (T-1) towards cleaner ones.\n", | |
"idx = random.randint(0, B-1)\n", | |
"num_panels = 5\n", | |
"fig, axes = plt.subplots(1, num_panels, figsize=(2 * num_panels, 2.4))\n", | |
"for i, ax in enumerate(axes):\n", | |
"    col = T // num_panels * i\n", | |
"    ax.imshow(input_images[idx, col].permute(1, 2, 0), cmap=\"gray\")\n", | |
"    ax.set_title(f\"step {t[idx, col].item()}\")\n", | |
"    ax.axis(\"off\")\n", | |
"plt.show()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 122 | |
}, | |
"id": "I-hoBcj3DNSf", | |
"outputId": "f4dd8b4d-4dff-4a2d-c8b9-11594daa2318" | |
}, | |
"execution_count": 20, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 640x480 with 5 Axes>" | |
], | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgQAAABpCAYAAABF9zs7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAApZ0lEQVR4nO3dd7RdVdU28An2XrA3RGxYAQkigpAYIqIYSkIXCEgAKaFEAyIIgg5EIEggFEMIhBokEppUJZhIE8Gu2LvYe1fev97pb+1xdsa999wx3jG+bz5/zXFy7t5rr7nWOjvPM8tqDz300ENRKBQKhULh/2us/n89gEKhUCgUCv/3qBeCQqFQKBQK9UJQKBQKhUKhXggKhUKhUChEvRAUCoVCoVCIeiEoFAqFQqEQ9UJQKBQKhUIh6oWgUCgUCoVCRDx8pF984xvfmPb222+f9ic/+cm0995777SXLl2a9tvf/va0f/vb36b9y1/+srnHM57xjLTvvPPOgZ//9Kc/TXujjTZK+/DDD0/7gQceSPuCCy5I+1Of+lTakydPTvtFL3pR2r/5zW/S/uc//5n2v//972as//nPf9L+9a9/nfZPfvKTtLfbbru0b7jhhrTXXHPNtE877bQYFs7D1KlT07766qvT3mOPPQZ+/ta3vjXt3/3ud2l3ffP0pz897XvuuSftPt9suOGGac+ZMyft73znO2kvXLgwbX0zceLEtF/4whem7dr517/+NdCOiLDWln+jb97xjnekffPNN6f9/Oc/P+358+fHMFhvvfXSnjJlStq33HJL2tOmTRs4js033zztP/zhD2m71iIi1lhjjbS/+MUvDvz8wQcfTHvddddNW7/ouwULFqR90003pb3xxhun/bznPS9t1437ZFV7xmf6+c9/nvYWW2yR9ooVK9J+znOek7Z7eqxYZ5110n7DG96Q9h133DFwLH4+YcKEtP/85z+n/fvf/765x5Oe9KS0PZP8XH++/OUvT/s973nPwOvqm8985jNpv/a1r037mc98Ztp/+tOf0l6Vb9wz/s2vfvWrtJ2nL3zhC2l7Bni2jAXPfe5z037JS16S9re//e20X/3qV6f9rW99K23P8b/+9a9p/+1vf2vu8djHPjZt1/3jHve4tP/4xz+m7dqbPXt22quv/t//T3/84x9P+3Of+1zanl9PfOITB47Pue/6Rfgcju+lL31p2t/73vcG3u/ee+/tva4ohqBQKBQKhUK9EBQKhUKhUIhYbaS9DN73vvcN/PzJT35y2k94whMG2lKk0h7rr79+c63bb7897Z122intRzziEWmfeuqpaZ9wwglpS7UsWbIk7be97W1pS7N8//vfT/vxj3/8wOeR7rnxxhubsX79619PW8psv/32S9vn/sUvfpH2ox71qLQ/8pGPxLCQXhRSkz6j9qc//em0//GPf6QtBRkRsXLlyrSnT5+etr6ZN29e2h/60IcGXlffSJc7P9/97nfTlt7TN8out956azPWb37zm2krdShpSbf2+Wbu3LkxDPbdd9+0V1tttbTdGz6flKUU9d///ve0X/GKVzT3+PznP5+2a/3hD/+vGrho0aK0jz322IHX1S+777572lKQyj2OW2pSelXKP6L161Of+tS03es+t/Kd60zafKxwDQt98OhHPzrtxzzmMWkrzSgrSllHRHz1q19N+01velPa+kbJ9aijjkpbSvnKK69M2/PFcUidO1afx3Pqvvvua8b64x//OG3PjS233HLg/ZSJ9I1S8Vig/Cncl4985CMH3vuHP/xh2sqISigR7bO+7GUvG3itu+++O+1DDjkkbX9Dli1blra+87fiG9/4xsBn0Hb/+LsU0e4B95zn8w9+8IO0//KXv6StpKHMuyoUQ1AoFAqFQqFeCAqFQqFQKIxCMrj00kvT/trXvpa2lMvll1+e9pFHHpm29Nd1112XthRWREtlS7d7D+ksqRajeI2Qlz6T0nvBC16Qts
9mJLTR2VJnES3davTt/fffn7aSw89+9rO0jZhdvHhxDAspYaUMo3SlHY2U1TdSXVLtES31eNtttw28h5S+NJ1zYtSyvnnlK1+Z9lOe8pS0P/GJT6QtDaiPjQaPaGUpKcgvf/nLaet/fSNF7roYC0455ZS09flaa62VttkVM2fOTNtnUEpzbiL6ZQbvoV/WXnvttJWBjHaXmtxggw1iENzHHiFeZ5NNNmn+RunIbAclHiUHfSwdLFU7VrgHvLZR7koeylvufaO3PZsiWsnhS1/6Uto+o9KJc2JWklkG+sZsHDM1li9fHoMgzW8GTES7tzxvpbA9E8zecf94NowFZqS5llz3rpfXv/71afsMylOeMxGtNPCjH/0obffJ0572tLTNxDr//PPTdo/q6x122CHtz372s2l7NrtnpPmVQiPa36O+veHvoFkvyg3ee1UohqBQKBQKhUK9EBQKhUKhUBiFZCDFZvEWI08nTZqUtlSTVNq1116bdvfWFj+SvrnrrrvSlr6RtpIOk0LpozCNLJeGs3iN9HY3SlMKS9pdKtpIUiNbpW2PP/74GBazZs1KW8pTinyzzTZLW4rUuZKmlKqKaKl3n90o96233jptfaP00ucbaX/HZ0SxtKi+sUhKRLt2pPWcG32jz5WVhs0A2WuvvdJ+1rOelbaRx86rRVKUz6Rhu3vGTB2zF/T9Nttsk7b7xOhk76fv9KnjNlLevSRN6Rgi2nNA2t25kc71up45Z511VgwL50R6WKrZSG4pWs8Lz4Wubyx+pG+kvM120DfKst7Pgk5mZ/SNQ38oPShhRbT+1zfOjXvR+Xj2s5+dttkqY4FFn5wz14JnrPPh+PRjF54DygmeA7vuumvaFmpSAvAM8Tdg//33T/v6668f+B2pfW3PuIhWilAS0V/uRcfqOaicuCoUQ1AoFAqFQqFeCAqFQqFQKIyil4H1o4UUqzWVlRKkjK3vb1GOiIgrrrgibQuwGBErrSMF5veNspVOcRzSzFJNUlN+3o3uthCSkaHWI7eQizSedPd4wGe3AI6+MRtAKtfntb7/Nddc09xDX0mFvuY1r0lbukuq0nEYuW3hDP/2K1/5StpKF86b1LLUWERLrVlzv6/mt+j2ChgG3k9IeZoNICXvsyqnmH0T0co8L37xi9N2HfpM+tu5UfrRL+4ZpRypWqlu91s36r4vYto58B5+XxljPOCedc84LqllJVDXdl8fhIi2pr0+N/vIQjfuAbN0+iLE9Y1UsWekftKvyhAR7Xlm9pdz0KWzB91jWChRCM9fMzOULqTe3XsWbYroL1zmvT2vfW7XtNkVrlX3mOP298CMG/9WiaA7DuVW58D9J5QoRopiCAqFQqFQKNQLQaFQKBQKhVFIBtJk0h3Sweedd97Az6VNpEcOOOCA5h5S9NYrlxKRnpL2NRPhYQ97WNoWvJEqNwLcngjKExYT6Y7VbAnnQ1pOGk8KsEsLDQvpVGurS4vaMravaI2+sUhOREvRX3jhhWnrD8dhoSgL4OgbaW0lJn1zzDHHpG3Ut7JCd6wW++nzjfSitOF4+sZ12ydjWHjJSHrpQf2yyy67NPfQLxafcu2ZvWD0upKbVL9ZMBaV2nPPPdO2t4kSkrJHd6xKea5TKXuLtOgXI7rHA1L1fXS57ajNhLAYlJSuvSQi2jm134ZrT8pb3/i513Ht6Bvbm7/3ve9NW18aeb/VVls1YzVLwfnQN1LQnumOb1iYFeSe6cugcZ+YMaBfLPgU0Z5B9pvw3p5rUv1m5ngdKXzlaP1i3xgzM/rak0e0Puvzi3tJKWgsfimGoFAoFAqFQr0QFAqFQqFQGIVkcO6556YtfW5hHaloa/8bOWxU7vvf//7mHkbLS4etscYaaUvdSSNJoRhhauSvhY+kA7fddtuBn1srXGoqIuLOO+9MWyrbGu9Gm1oYyHGMB6yvLX1+0kknpa18YNS6tbw33HDDtD/4wQ8295DeV0pRclAmUv6R6jISXpnHbA
UL8VgYx5bFtpN1HUS09KfZMUbkG3mvP1yfw8KsGelzC+tIAVs4ywhmacSPfexjzT2khI1GNxpaKcI69VLARr6799yTrvnJkyen7R6z0FLXL17X9aSUoOxkH4Vu7f1h4T632Jk+U+azAJR0vvN58cUXN/dw3qV+lRz0kz5wXxnpryQjPayE5j6WXncvdLMMjMR3HO4lZRbXRbcl9zDwbPFc9azwLO7L2FhVG25/T/SlZ5byg99XVvBzMzuU69zHZnopPbhOHENEK/f1FdBSZvNazsFIUQxBoVAoFAqFeiEoFAqFQqEwil4GRx99dNpGbkvPW8xGuk3qUHpDSiiizWSwJrZRzEbHWs9Zim3u3LlpG1kuHXniiSemffLJJ6ctRW1rTT+PaGtdKw0YDa2M8apXvSptJZRjjz02hsWcOXPSNuJUet6a2kZSO4fSYY43op07ixRJmVof3ghX73HmmWem7fwYkX/aaaelrSQltSudvOOOOzZjtf2oBX6k47y368ueEx/+8IdjGLz73e9O2z0jzWnksVH1UrrWlne+I9pMAdtXK984H17L+bCFtnvG659xxhlpm/2x+eabp+2e9vkj2ta27nUzJfSL69fiMI5jrHDNuGekWaXLPV+kk/v8GtFS2BYpUibafffd0/YcUWYzo0mYreC+8jxQPpCy7q5tJTgj7N3H+sZCS9Lul1122cCxjhRKdmYZGMVv1L/yi+vZv+0WO/KcUyoxc0JZXNnFc6rbQ2XQ9c2WO/DAA9N2bZtx0C3Wp7Tmdc1wED6re7d6GRQKhUKhUBgx6oWgUCgUCoVCvRAUCoVCoVAYRdqhKRamN5m6pK5rWoS2mlZXb1L7Mo1Dzc40HTUVq9OpO1999dVp28tabUxN3XS7WbNmpT179uxmrOq9pu6ZHqP25FinTZsW4wn1NRt36JvXve51aesPm7boP9OvIiIOP/zwtE2F0TdqcGrEpniZampcgxqrWqXzbNyH+vTBBx/cjFUdzZgAG/hYFU5d3TiIYaHubP9572e6pbEljtt0MecsIuJd73pX2ursVlzUL17XNELXunEN9957b9pq7a65+fPnp21lthkzZjRj9W9MxzLVTE1XfX7KlCkxnjCuyf3rPdXcnVsry/kdz5eItpGbaa7GGfXNifr0zjvvnLbnS9857Lq7/PLL07Z5WXedm8pn3I3nifPkWth4441jvOB54px5PzV3Y5vU901P9xki2t8gUwT9TfD3zkqhnp2mK7u//e3yGYx3UNP33J0wYUIzVufDmBTXiv4y9sS1OVIUQ1AoFAqFQqFeCAqFQqFQKIxCMjAlxvQOKT4bCS1btixtKXKpOtOVIiL222+/tE2LkbIxBUdqTBqvr6KT6SpSr6bv+H1pOCtARbSUz3333Ze2zTJMkZQClO6ZNGlSDIu+alvSWFYTU0axSqO+2XTTTZt7SCl/4AMfSNusVSsBSr+5dqTibE6lrcSgb5Q3lGBM44qIuPvuu9O2Op4Uq5U09Z++6TaAGS2kC/WLzyGtd8stt6QtRW7apnRnRJu2euihh6atX6SypR0dn7KC1LC2so57QxlCv5h6HNE2sHKfuC9Nf/M7riGbLI0V0sumcCmHmVqnvCJFbnpoN1XXipS77bbbwHE4J549jk8633lQxlCeVJLyXPRz92FES6vrD/2v7Oh3lAe71WdHC/d7X8q4MqK/P+4lZZnu2e350lc11nPA9e26l843xU/fOf/KHn7fz/1Ni2jXo75371oh2Psph4wUxRAUCoVCoVCoF4JCoVAoFAqjkAykXqWWpdWM4pa+6aNTulG59nO//fbb0+6rWOV1zT649NJL07bymP3KHYeVuaT2pOqswtgdh5Gh9iK3171Vz7rVqIaFjU2kMPWNkoVVsqSQnZNuZauLLroo7ZUrV6atb6SdjbR1Hs1eMFpYitxqlo7P6nKuQTNMuuMwgvewww5LWx84N66dYdHXiMbn1ndWupM6lfoz6j8i4vTTT0
+7r6d9nxygFGFTLrNupPDN/nBMVvH0mjajimjpUCsa7rPPPmkrOzk38+bNi/GEso0ZGdLRSm76TNrYeVCSiog44ogj0vZZlEBtXGS0uBHpnoWeO1L4nn/uH5uquQZdKxGt36w+a6aEkpFUvc85LPqqvrrfpfNd21LvZot4FkW02RbKAfpFOUBpxqqa+lt/uY/9HXRMViBUTlQ2imjXpmtQ6fCBBx4YONapU6fGaFEMQaFQKBQKhXohKBQKhUKhMArJYIsttkhbqt6iItIYNqKROrR5hRHJEW1EpRGx0tJSetKLXkv6xuh4aS7pNmkgi8M4BumoiJaik3Jec80103bObJTk/I0HzNZQ/pCmlDp3XMo2Sh/SWxEtFeW8SLMZ7b906dKB35eKs7iQUctmOCgZSHd7ze5YlSiM4jba2OwO52Y8fWM0us9hRLdFh6TnzZSQXve5I1qKtTsP/wulo8WLFw/8vtSkDXfs2+7+lk7v84v7NqKlpaVbpUKN+rZok/M3HvDaSgB9GSDrr79+2u4rm011n9foe/9N6UV63qZNfl/q3AZRUueuI8+gkZypEa2cIG3tebvuuusOHLcSxbCQbvdMV06x6JBNvFyHfqf7rMo8/puZUeutt17aSg5Kcf5uuD7MiPB5HJPP4xi6Y9X3ZhAooehv91L3rBgJiiEoFAqFQqFQLwSFQqFQKBRGIRlYCGGnnXZK2x7sy5cvT9sa2lKhUlvd4jfnnntu2lIn73znO9O2Br0FGSZOnJi2kadXXXXVwOtIR0n/90WtdgvCiO233z5t68A7JmlG7+2YxgqpyenTp6d96623pq00IJ0vrW3xoo022qi5x8KFC9PWN2ZlSGFKY+l/59fMDdeUFK6+kerTN9363xblMaLYaG3XnvS1NJvR72OB1KFUr1kaRnu7l6Qgpa6VuiIilixZkrZ+cR24d5V1lLGcZwu3OH/6RcrS7AHHIC3fhbKVRX8ckzLSWOjPVcG15FqwSJXFfvSZ+8R17ucRrT+dF7MyLCompWyRI+fa7B/PF2l7JVP/VnlPyj+i3TPKus6BZ7fSaldOHQausXXWWSftm266KW0zA4ywN6PM4nTdwmXK3M6J60B63r2h7KjM1legT6lLudS/dQxS/l14DiiPOKa+vgYjRTEEhUKhUCgU6oWgUCgUCoVCxGoPyRWtArvsskvaRgsbmWmBBWkaI1ilsCy+0v0b6RUjLW+77ba0ja6XtpISmjlzZtpG8Vq0wYhpKX8pq24hDyO0HbcR4WeffXba1kuXPj7wwANjWOywww5pSxcaBSutK8XURy92KVqfUbpL/69YsSJtZQKpK32z9957p93nG6k4ZQ8pT7M8Itp1JY1oNsGZZ56ZthKO2TEWMhoL9LN7wDnTF9ruMe0ufD79bSS7BcOk5P2OkpYykGvY3g5Gors3lPHsRRDRFr6y+JGSj62UpdOlx21RPlZYsMfzRd+4l5UG9IfnlHZE28vB/a98oASqxOIeMyvFvaGM9+Y3vzltM0MsjqUkYTZFRCv3WfxI6fCcc85JW98qp/qdscAzV19oW2DO9aZf9Gm3p4bP58+f54ayhPdQ8nQOlRssbqb8YsaHa9txK/9GtL9Nrq9NNtkkbdeBRdrWXnvttO2TsioUQ1AoFAqFQqFeCAqFQqFQKIwiy6Cv9aU0l9SflI2R/lLUXdrKyGhbh/q50bennHJK2lLiRphecsklaUsHG61rMRqj7o1g7bY2lX4zo0Iq0r4G0rYWzhkPWDDENsJGu0pdSV9bw76vTndExFve8pa0pQWll63PfcIJJ6StP6TrFi1alPbkyZPTloI2Ql4fSFlbnCSipQjtJ2HRmwMOOCBto9xXRc+PFs6/UdnKHVJ/rgspPteaEkpES2tfeOGFac+YMSNtKc/jjz8+bYvOuH8WLFiQtvtBmliKuW98+jqiXR+uO4vA7LzzzmlbO99zYzzgmrHoltKA45Iqdr24tt17ERHbbLNN2q5DC3
I5d8cdd1zaRpv77NLOnslGnSsZKG3aj6a7zi2wZDaOtLO0uHJTVyoZBu5dJSbn2bXq74yyl98x+y2inTcl6NmzZ6ft/Bx99NFpW+TL89Izq6/olQWOHKtyd9cvngnKS2Y+9GVUKOGPFMUQFAqFQqFQqBeCQqFQKBQKo5AMpKT6IiSl8KUapeGMIjYSPaKNIrc9rZS4VLb0kjX8pfqkMI3Gl76x6IjFfKR8bVMa0c6HtK0UltGfju/aa69N27rxY4URsRam0Ddz585NW99ITVoYSio7op2jZcuWpW0UshKLdLS0mXSaEbTbbrvtwGcw68M5t0iK94poqVuLxRjprcxjASGLyZihMhZIRRs9LAXsfnAu11hjjbSVAroFsqQqjTA2A0dpTqrRqGqzD5w/pSL9eP311w/8vnvGrJWIdi9KbUqZurY8Z6R258yZE8NCKUpqWvr7/PPPT1s5tK+dt+dUREsLS8ObfWIBHDMLHJNSi3OtDCh9bYtsswyUBYzaj2jPVZ/Pc0sZyzNHOXRYWMjH9ebz+dvgWe95Z3Et271HtPKdhaiUwdwbZkm5Js2I8Czzfquv/t//c991111p6yNlGQtmda9rdorP4BngmdNtpTwSFENQKBQKhUKhXggKhUKhUCiMQjKQtjKCU0pXmlPaxChPizZYnzqipW+ktS26IVUlbWzkrwVypEKloKSPDz300LSlEo3KldaJaClAqUGfW0pP6mg8KbaIlgrUN47RKGnH6HeUOKRoI9o2ws6p1LTRzdKiUlr6yWhfaWPpTAs32eq3b84jWmrOyFxrgUvPe299PiyM7pc+d3xG4rvGlGhOOumktG2LHNFSvc6n+0fJwB4CzqFrUglM6Je99torbfekGUjOd0RL9SovSb3aw8H6/I51PKAEaNaNtfuVovSl69mzo5uZ41wopUhzS+uaBSMl7Bo2Q8Uocuvh21tFmUa6Wxk3om1R73rRN8pvyrU+57BwP/h8fbX7tc2CUeryuSNamcZ5cL3ecMMNafs745r0vHM9eQYrAynDKhtaHMlzOqL1vVKqa9YWy64Dn3OkKIagUCgUCoVCvRAUCoVCoVAYhWRg9KNR2VIzygGXXXZZ2lK60r5eJ6KlcY2mnTdvXtpSNtJqRk9L/ZhNoG00phLDHnvskbaRplKcERFbbrll2lOmTElbOtFCSPvvv3/a3eyKYWFkrtHJkyZNStuWx0ZGW6BFSqpbU1s62nk/9dRT09Y3+k+6XFpcGtAIXyNrtXfdddeB9/I7EW2RI2vg2/PAAlIHH3xw2t3simHgGI3Wlho2GtpiPVKh7pOuX8za0C9m/LiOpSelLaVO+7JWlI2korfbbru0lV+6Uc7OrXvd6ypXHHPMMWl3W1wPC+VD58dMASPBldaUEixq1j3PPIec6z7fWOxHqlmZQN+475WSpPDdC8qkrpWI9rktACZlbd+M8847L+1u2+dhYCElbWlxf0OUCZxj/euZGNE+h1LuRz/60bSd277eLV7H3xZp+wsuuCBtz0RlI/vG+LsU0e5LM8I88+zFovTn+EaKYggKhUKhUCjUC0GhUCgUCoVRtD/uq78tPSKdIu0kDWIEptHWES0tbX1m6VaLM0hFS9cZKWwUsJSs15ESknqzRauSRPffvJZFLKStLBCi/NDNtBgLLKDjuPSNBZOkZfWNBUm6mRBGf5tNom+ki/WNtKrR4gcddFDaRjlLtTtu5RijnLu9IVw7+kZbStD2slLhZtCMBbYMd0z6xUI4zrF+kQp1bUe0vnRv+XyuSa/ld1yfZhCY+eMec687Jp+zW5fdCGrpYMenJKIvjOg2A2OssJWz89YnqVgQyuht12c32t7MLPeD69tndy14dnjdadOmpa18IGWt9OAcum+7/QccnxS0a9U+CmYiKPe5XsYCqXQj6R2H68j+CvpFKaCbZaAc5z2k7l0HFkXyDHHOHYdza7t3n8E5c/15r4g208J14/j8PfHccL92+wb1oRiCQq
FQKBQK9UJQKBQKhUJhFFkG0iPSWUbSSm9IDRsla718i15EtHTJkUcembbtJy3+YaS447NFpTXrpVmk0swY6KtZ3qXQpa36Wl9KR0mnW9hnPCB1qG+M6Nc3Rg6bZWDtdosdRbQRykZ/H3HEEWlLvRtRLo1lDXszH4QRtBal2mCDDdKW4jR7IKKVd/Sb0dDKRPpmPKPZfW5pWCOHnVepP+dpyZIlaXdbPUuNGr1uVovR/j6f0drKgBZlcXxGSUuh9/Uv6BYT8t+UGcxC8gyQTu8rljRWSCPrGwvMSL1LuSqLKPl1o7r9+8WLF6dtO3f3rjKrdK9Uvy2hnSspaM9b51YJpJs1Jc3t2eh8uBb0jfcbFspY7kvPW2n7PvnGzJVuG25h/xwzX/qKOOkXJTHnzM89j13/Sp7Ovb9jEe1a08f+Nvm7pixhVsJIUQxBoVAoFAqFeiEoFAqFQqEwCsnACEyL90gLGpUt5WxE8ezZswf+bURLce+zzz5pn3766WnbLtiCExbaMGpTakxax4yGs88+O22LrDgeI2wjWrpVSPU6H15r2EjcLqTWjGxXGrCwihSvfztr1qyBfxvRUtUzZsxI27nbaaed0jazwAJE0vlKA9LaFklZsGBB2lOnTh04nq5vLKYjpa5vbHPttaQah4Vzu80226TtOjTTxghr/9aof/82on0m23tfcskladtXwqwQKUhpTveSvlPiU4qzMJf0arfIin1CpNPNjDHbxGjw7nocFtLOFvDyeY0Qty+JdLl+NRMion0ue0jYOtpeIBZ/k7KW6nefuM6dK9tgWwRL+r97fnWL9/wvpOqVeZSYxtJmtw/OreeU86F01Sc3KY05TxEtdW+fAv3t75e0v/tHicK9pATlvcwWcS8pC3QLj42kcJXP529f32/UqlAMQaFQKBQKhXohKBQKhUKhMArJQErRyEyjZHfccce0rTF9yy23pG2xDv82oqUUjS43wtco20MOOSRt6Syj++0nIE0jtWlEfF+Uc7ftrNGcUrUXXXRR2kaCSkVKP44HpBSdHyNWlUKk+KRozRiRAotofWMvBull6W/bFrsWjKiVRpVuNUvEyF994/x3swz8Nylde230ZYAYPT8svIfZDlLpZrhI/d1xxx1pW8e9W2TFqGezV5RRnB/lByUb15BFVtyjRtEr6/S14dXvES1Nbw+UZcuWpW3BI+fJ+40HpFyVI1zn0u0+l7KS1LkyWURLw0tN62fv7Vr1vHDPOA9SwvZXkO7u8420e0S7BzyfbrzxxrT1n9/3TB8W+t/zy+wD7+czeYZ49krhR7S0v5Kb9LznmvtKn7qG3KN+py+zQL8I119EK7crO3vOK+f7m1VZBoVCoVAoFMaEeiEoFAqFQqFQLwSFQqFQKBRGEUNgiqDpNGuttVba6p7qLvbkvvLKK9NW/4lotR7T2dRO1L5MrTLGwdQNm7aoT6k3WQlMncf0HVN2Itq0RXUe/950FfVFPx8PqNXqG3Ut9f2NN9447YkTJ6Z91VVXpd3teaVOtXDhwrRN2TPW47LLLku7rwmVKXRWQVM/tcqYc3vzzTen3fWNeq2+UWdVL1RT7OqNw8CUV9eYuqV6tLEs6riuwy7Ucp1zY2Scz2uuuSZt4wbUPY0xcc8YV2L6qPNqeqv6ekSbIqZfTJUy9U7d2O+PB9SITa1zjek/42bU8VesWJG2Ka4R7flm9UfXq3O6fPnytI0bcB6M6XDPGJfgWei6u/fee9PWZxHtczvXrpG+M2w8faOG7nMYE2B8jPEErnljM7oNg1zT999/f9rOifNpdVxTDW025G+O6YH6zrgB17y/aX7eHavz7OfC73g2jBTFEBQKhUKhUKgXgkKhUCgUCmOsVCjtK20ibHixaNGitG3+Y7OaiLZSm+lVVvaSUjENSqpESk7KRur1uOOOS1sZwwY9VjAzzSOipZZNZ1NKkO6WjurSdcNCakjfSL9KZ0rdOnab//gcEW1zGWk905L6fOP4TI
VxfDaJOeqoo9I2Jc00OSlc5ZCI1jdXXHFF2q49KUj9YdWwYeFzS8n3SUZWm7v22mvTdi8ph0S01dz0vdS9ko2SkuOTMlZOMcXYVFL9rvxm2mY3HVSpRApdSVCK2v0jrT8eMIXOM6JPMnLeVq5cmbbP1G3y4z5zrUvdS4v7vI5P2t/z1vndeeed01bGMHXT5kndyo9KQO4n7y1N7ffdS8PCtDl90Ud/m9rr2e166UqKnkH+rkndm1LoPZRSfW7nWelUedY5dz04PuW97r39N+Uix6HU5N+OFMUQFAqFQqFQqBeCQqFQKBQKEas91A0n78HcuXPTlpo0WlfqwihNqWjpxW7TiZkzZ6ZtZTejlaU/pZHsNy6NasVDqTqjU82gkM41mtUGLt1/MxrZOZBuPeigg9KWyrEh0Fhx4oknpt3XVMh58zv6Rl92fWOFOyNzpZqlP6U8zQaxWqQZDjZ8stKaGRRSiFJ0SjsRbSS2DUWk06R999tvv4Hf2XPPPWMYHHnkkWk750aWS5H7fK5h/ejei2jXj5Sk0ofrTUrWrBLpbjMcnH8zilzz0q6um25FTp/DKnQeQffcc0/au+66a9r65fDDD49h4XpWIlGqdN6MHHfP9FXTi4jYaqut0jZzQzrbqHXnoa/Jk7KqEprShdS3+7CvcmJEu/eVBJ0bm//YMMtnOPnkk2MYeCa4H5xbswaUKT1brOLYlYGUQ31W96KUvHKrFXGVGs3oevDBB9NWBvJzfa0MpB8jWvnBvevvqN9xfSjxKEGuCsUQFAqFQqFQqBeCQqFQKBQKo8gysHiLlLHFIKQzpW9sEiL9NX/+/OYe0pxmDdi7XjpZ+uawww5L20ISV199ddo2OpLalGKWBjILQlosoi3wYpS0NKeRyX2FMsZDMtA3FrqxSZB0oXNoMRypxgULFjT36PON1Ka0s8+rXGKhqOuuuy5tm930+cb5lBrrNgoxgt3CSdOnTx/4PI6pW8RkGEhzeg8zMPyOzaUs7OWavPzyy5t7SEkqofU1R1JCURLxc/eb2RwWX1GyMRJdqaPrFzMfXJtTp05NWzq9r0DSeEBp0Ps410aC65v1118/bSlri2V1/0261/PTtaCUtPXWW6et1GIGwIQJE9LWN+4fx+Be6M6nkp1R754PUuqOaTx9I72vfKPk6Xcck1KO53U3o8JreQ+pd+fNParc4N8q35hd4tms7bhdc0pTEa2U4/nnb45nlhJIt1DWSFAMQaFQKBQKhXohKBQKhUKhMArJQPpB6t2CPX29DIxUNjpZmjiijUbffPPN01Z+kEKR+pbitta4kZ1STUbHK3tIuRgJ2u3tLtV/1llnpW1RECNGHUe3XvWwMCLWbIg777wzbaO3pRrNsLC3gNkZERFLlixJWxpZmUB6X9qsrzCVkcBSuFJoI/GNElFEKw3YE8N1JFXrOCy4MizcM5tuumna0rNKPBZ/UmKQjlR+i2ijh/WrGQFS0c6nRZucGylnMwiUJyzuohxl/X7XU0RLg1944YVpKyOaUaS/pXnHA1Kz0sDOu4VuPGuUV5TZupH79ibw31x70s6eSX2Fm5QGjE53fpTDlD30jedoRJups3Tp0rSltr2uMstYCuD0wT3jGvZsUa5S3vA7nrfdwkSe5e4Hz2XpfefWvau//L5+UTbzOp53zrFrPqJdm8qzynRe16wQrztSFENQKBQKhUKhXggKhUKhUCiMQjI444wz0rbghpTXvvvum7YRm9Kw0ovdwi9SVdI3ShFGjyoNTJs2Le2777477X322SdtI3Qt2GF7ZiPfpfy7tJPP57X6ooal7vzOeOCcc85JW8lA6nfGjBlp97V+lvbabbfdmnvYU8DiS8oS0mNSpEaRWxxK/1uIymh5C5VIozqH0tcRbQS0vpGOlnbUN91rDYOLL7447c022yxtad8ddtghbSlPxy0Fud122zX3MLJd6tA69/pFicd1r3TkXrIYjWOyRrvUuJHQ2hHtflCK8HueJ9
Kf3WsNC/ujmDVgASklTSlozyCfoyuzuTcsrCRlLUVu5oU9IZSM9JmSjHvGc9H9ZgZHtzeEa1I623PYs9s1qZQ0LMzAUDJQ7lV21F/uH+e122PC+TRDwnNRSdLnU5ZzTdgnwnWuX3wepSYll6784nN7Til1uE+8XzdjYSQohqBQKBQKhUK9EBQKhUKhUBhFL4NCoVAoFAr/76IYgkKhUCgUCvVCUCgUCoVCoV4ICoVCoVAoRL0QFAqFQqFQiHohKBQKhUKhEPVCUCgUCoVCIeqFoFAoFAqFQtQLQaFQKBQKhagXgkKhUCgUChHxPyROwaPX8B1KAAAAAElFTkSuQmCC\n" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Define the model: CNN encoder, LSTM, CNN decoder" | |
], | |
"metadata": { | |
"id": "-CeUWnzpEoqb" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class ImageEncoder(nn.Module):\n", | |
"    \"\"\"Encode (N, in_channels, H, W) images into (N, 128, 1, 1) feature vectors.\n", | |
"\n", | |
"    Three conv(3x3, stride 2) -> BatchNorm -> ReLU stages (32, 64, 128 channels; 28 -> 14 -> 7 -> 4\n", | |
"    spatially for 28x28 inputs) followed by global average pooling.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, in_channels):\n", | |
"        super().__init__()\n", | |
"        channels = (in_channels, 32, 64, 128)\n", | |
"        # registers conv{k}, bn{k}, relu{k} for k = 1, 2, 3, in that order\n", | |
"        for k, (c_in, c_out) in enumerate(zip(channels, channels[1:]), start=1):\n", | |
"            setattr(self, f\"conv{k}\", nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))\n", | |
"            setattr(self, f\"bn{k}\", nn.BatchNorm2d(c_out))\n", | |
"            setattr(self, f\"relu{k}\", nn.ReLU())\n", | |
"\n", | |
"        # Global Average Pooling to reduce spatial dimensions to 1x1\n", | |
"        self.gap = nn.AdaptiveAvgPool2d((1, 1))\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        stages = (\n", | |
"            (self.conv1, self.bn1, self.relu1),\n", | |
"            (self.conv2, self.bn2, self.relu2),\n", | |
"            (self.conv3, self.bn3, self.relu3),\n", | |
"        )\n", | |
"        for conv, norm, act in stages:\n", | |
"            x = act(norm(conv(x)))\n", | |
"        return self.gap(x)" | |
], | |
"metadata": { | |
"id": "z81MMxklEqhO" | |
}, | |
"execution_count": 29, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class ImageDecoder(nn.Module):\n", | |
"    \"\"\"Decode (N, 128, 1, 1) latents into (N, out_channels, initial_height, initial_width) images.\n", | |
"\n", | |
"    Three stride-2 transposed convolutions grow the spatial size 1 -> 2 -> 4 -> 8, then adaptive\n", | |
"    average pooling resizes the 8x8 map to the requested output size.\n", | |
"    The output layer is linear (no BatchNorm / ReLU): the model is trained to regress noisy images\n", | |
"    produced by `q_sample`, which contain negative values, and a final ReLU would clamp every\n", | |
"    prediction to >= 0.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, out_channels, initial_height, initial_width):\n", | |
"        super().__init__()\n", | |
"        self.conv_transpose1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\n", | |
"        self.bn1 = nn.BatchNorm2d(64)\n", | |
"        self.relu1 = nn.ReLU()\n", | |
"\n", | |
"        self.conv_transpose2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)\n", | |
"        self.bn2 = nn.BatchNorm2d(32)\n", | |
"        self.relu2 = nn.ReLU()\n", | |
"\n", | |
"        # Output layer: deliberately no normalization or activation so predictions can take any sign.\n", | |
"        self.conv_transpose3 = nn.ConvTranspose2d(32, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)\n", | |
"\n", | |
"        # The transposed convs always produce 8x8 maps, so this layer resizes (for 28x28: upsamples)\n", | |
"        # them to the target size.\n", | |
"        self.final_resize = nn.AdaptiveAvgPool2d((initial_height, initial_width))\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        x = self.relu1(self.bn1(self.conv_transpose1(x)))  # (N, 64, 2, 2)\n", | |
"        x = self.relu2(self.bn2(self.conv_transpose2(x)))  # (N, 32, 4, 4)\n", | |
"        x = self.conv_transpose3(x)  # (N, out_channels, 8, 8)\n", | |
"        return self.final_resize(x)  # (N, out_channels, initial_height, initial_width)" | |
], | |
"metadata": { | |
"id": "gzQ35KxeE2Cz" | |
}, | |
"execution_count": 30, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class CustomRecurrence(nn.Module):\n", | |
"    \"\"\"Per-frame CNN encoder -> LSTM over time -> per-frame CNN decoder.\n", | |
"\n", | |
"    forward() takes frames of shape (batch, timesteps, channels, height, width) and returns a tensor\n", | |
"    of the same shape holding, for each position, the predicted next frame.\n", | |
"    `training` is nn.Module's own flag: in training mode forward() returns only the predictions,\n", | |
"    otherwise it returns (predictions, hidden_states) so generation can continue frame by frame.\n", | |
"    Hence `model.train()` / `model.eval()` also switch the return type.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    latent_dim = 128  # channels produced by ImageEncoder and expected by ImageDecoder\n", | |
"    pos_dim = 8  # size of the learned position embedding\n", | |
"\n", | |
"    def __init__(self, in_channels, initial_height, initial_width, hidden_dim, num_layers=2, training=False, num_positions=None):\n", | |
"        super().__init__()\n", | |
"        num_positions = T if num_positions is None else num_positions  # default: one position per diffusion step\n", | |
"        self.image_encoder = ImageEncoder(in_channels=in_channels)\n", | |
"        self.positional_encoder = nn.Embedding(num_positions, self.pos_dim)\n", | |
"        self.rnn = nn.LSTM(\n", | |
"            input_size=self.latent_dim + self.pos_dim,\n", | |
"            hidden_size=hidden_dim,\n", | |
"            num_layers=num_layers,\n", | |
"            batch_first=True\n", | |
"        )\n", | |
"        self.image_decoder = ImageDecoder(\n", | |
"            out_channels=in_channels,\n", | |
"            initial_height=initial_height,\n", | |
"            initial_width=initial_width\n", | |
"        )\n", | |
"        # Map LSTM outputs to the decoder's fixed input width (a no-op when hidden_dim == latent_dim),\n", | |
"        # so hidden_dim can differ from 128.\n", | |
"        self.to_latent = (\n", | |
"            nn.Identity() if hidden_dim == self.latent_dim else nn.Linear(hidden_dim, self.latent_dim)\n", | |
"        )\n", | |
"        self._next_pos = 0  # position of the next frame when generating step by step\n", | |
"        # Set the mode on all submodules, not just this one, so BatchNorm agrees with `training`.\n", | |
"        self.train(bool(training))\n", | |
"\n", | |
"    def forward(self, x, hidden_states=None, start_pos=None):\n", | |
"        \"\"\"Predict the next frame for every frame in `x`.\n", | |
"\n", | |
"        Args:\n", | |
"            x: frames of shape (batch, timesteps, channels, height, width).\n", | |
"            hidden_states: LSTM (h, c) from a previous call; ignored in training mode.\n", | |
"            start_pos: position-embedding index of x[:, 0]. By default 0 when starting fresh\n", | |
"                (training mode or `hidden_states is None`), otherwise it continues where the\n", | |
"                previous call stopped, so feeding one frame at a time sees positions 0, 1, 2, ...\n", | |
"                exactly as in training.\n", | |
"\n", | |
"        Raises:\n", | |
"            ValueError: if the positions run past the embedding table.\n", | |
"        \"\"\"\n", | |
"        batch_size, timesteps, channels, height, width = x.shape\n", | |
"\n", | |
"        frames = x.reshape(batch_size * timesteps, channels, height, width)  # (b*t, c, h, w)\n", | |
"        latent_vectors = self.image_encoder(frames)  # (b*t, 128, 1, 1)\n", | |
"        latent_vectors = latent_vectors.reshape(batch_size, timesteps, -1)  # (b, t, 128)\n", | |
"\n", | |
"        if start_pos is None:\n", | |
"            start_pos = 0 if (self.training or hidden_states is None) else self._next_pos\n", | |
"        if start_pos + timesteps > self.positional_encoder.num_embeddings:\n", | |
"            raise ValueError(\n", | |
"                f\"positions {start_pos}..{start_pos + timesteps - 1} exceed the \"\n", | |
"                f\"{self.positional_encoder.num_embeddings} learned position embeddings\"\n", | |
"            )\n", | |
"        pos = torch.arange(start_pos, start_pos + timesteps, device=x.device).unsqueeze(0).repeat(batch_size, 1)  # (b, t)\n", | |
"        pos_embeds = self.positional_encoder(pos)  # (b, t, 8)\n", | |
"        latent_vectors = torch.cat([latent_vectors, pos_embeds], dim=-1)  # (b, t, 128+8)\n", | |
"\n", | |
"        if self.training:\n", | |
"            rnn_outputs, _ = self.rnn(latent_vectors)  # (b, t, hidden_dim)\n", | |
"        else:\n", | |
"            rnn_outputs, hidden_states = self.rnn(latent_vectors, hidden_states)  # (b, t, hidden_dim)\n", | |
"            self._next_pos = start_pos + timesteps\n", | |
"\n", | |
"        latents = self.to_latent(rnn_outputs)  # (b, t, 128)\n", | |
"        latents = latents.reshape(batch_size * timesteps, self.latent_dim, 1, 1)  # (b*t, 128, 1, 1)\n", | |
"        reconstructed_x = self.image_decoder(latents)\n", | |
"        reconstructed_x = reconstructed_x.reshape(batch_size, timesteps, channels, height, width)\n", | |
"\n", | |
"        if self.training:\n", | |
"            return reconstructed_x\n", | |
"        return reconstructed_x, hidden_states" | |
], | |
"metadata": { | |
"id": "chqbXzZaE5Ok" | |
}, | |
"execution_count": 34, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Training Loop" | |
], | |
"metadata": { | |
"id": "KD_os31TMhEN" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# forward() returns only predictions while `training` is True, which the loss below relies on.\n", | |
"model = CustomRecurrence(\n", | |
"    in_channels=C,\n", | |
"    initial_height=H,\n", | |
"    initial_width=W,\n", | |
"    hidden_dim=hidden_dim,\n", | |
"    training=True,\n", | |
").to(device)\n", | |
"\n", | |
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", | |
"loss_fn = nn.MSELoss()  # per-pixel regression onto the next frame" | |
], | |
"metadata": { | |
"id": "TYo28dSj49K_" | |
}, | |
"execution_count": 35, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Teacher-forced training. Each image's noising trajectory is ordered from the noisiest step (T-1)\n", | |
"# down to step 0; the model sees frames 0..T-2 and regresses frames 1..T-1 (the next, less noisy one).\n", | |
"# Loop variables are named so they don't clobber `batch`, `t` and `input_images` used by the cells above.\n", | |
"model.train()  # restore training mode in case the inference cells below have already run\n", | |
"for epoch in range(epochs):\n", | |
"    for step, batch_dict in enumerate(dataloader):\n", | |
"        optimizer.zero_grad()\n", | |
"\n", | |
"        clean = batch_dict[\"pixel_values\"].to(device)  # (b, C, H, W)\n", | |
"        batch_size = clean.shape[0]  # the last batch can be smaller than B\n", | |
"\n", | |
"        steps = torch.arange(0, T).flip((-1,)).repeat(batch_size, 1).to(device)  # (b, T): T-1, ..., 0\n", | |
"\n", | |
"        noisy_seq = q_sample(\n", | |
"            x_start=clean.unsqueeze(1),  # (b, 1, C, H, W)\n", | |
"            t=steps,  # (b, T)\n", | |
"            sqrt_alphas_cumprod=sqrt_alphas_cumprod.repeat(batch_size, 1),\n", | |
"            sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod.repeat(batch_size, 1),\n", | |
"        )  # (b, T, C, H, W)\n", | |
"\n", | |
"        predicted = model(noisy_seq[:, :-1])  # predictions for frames 1..T-1\n", | |
"        loss = loss_fn(predicted, noisy_seq[:, 1:])\n", | |
"\n", | |
"        loss.backward()\n", | |
"        optimizer.step()\n", | |
"\n", | |
"        if step % 100 == 0:\n", | |
"            print(f\"epoch {epoch} step {step}: loss {loss.item():.4f}\")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "ImaLvHwjMiDf", | |
"outputId": "f0a12e6c-775d-4fad-8ac3-c15935f6ba5f" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Loss: 1.4094734191894531\n", | |
"Loss: 0.7404128909111023\n", | |
"Loss: 0.7145070433616638\n", | |
"Loss: 0.7277837991714478\n", | |
"Loss: 0.7170003652572632\n", | |
"Loss: 0.7365358471870422\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Inference Loop" | |
], | |
"metadata": { | |
"id": "PP6CFN4vMidh" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# train(False) is equivalent to eval(): it clears nn.Module.training on the model and every submodule,\n", | |
"# so BatchNorm uses running statistics and CustomRecurrence.forward returns (predictions, hidden_states).\n", | |
"model = model.train(False)" | |
], | |
"metadata": { | |
"id": "iNwii9gJ-Utu" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Autoregressive sampling: start from Gaussian noise and feed each prediction back in, carrying the\n", | |
"# LSTM state, for the T - 1 transitions the model was trained on (frames 0..T-2 -> 1..T-1).\n", | |
"# NOTE(review): the noisiest training frame is sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * noise with\n", | |
"# alphas_cumprod[-1] ~ 0.36 for this schedule, not pure noise, so a pure-noise start is off-distribution.\n", | |
"with torch.no_grad():\n", | |
"    current_input = torch.randn(1, 1, C, H, W, device=device)  # (1 sample, 1 frame, C, H, W)\n", | |
"    hidden_state = None  # Start with no hidden state\n", | |
"\n", | |
"    for _ in range(T - 1):\n", | |
"        output, hidden_state = model(current_input, hidden_state)\n", | |
"        current_input = output  # the prediction becomes the next input frame" | |
], | |
"metadata": { | |
"id": "8NxLub4ZMjpH" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# one sample, one frame: (1, 1, C, H, W)\n", | |
"output.size()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "rNmgRa7m-ZEW", | |
"outputId": "8cafec2a-40ec-4e8f-ca55-94d6e78512ea" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"torch.Size([1, 1, 1, 28, 28])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 18 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"fig, ax = plt.subplots()\n", | |
"ax.imshow(output[0, 0, 0].detach().cpu().numpy(), cmap=\"gray\")  # sample 0, frame 0, channel 0\n", | |
"ax.axis(\"off\")\n", | |
"plt.show()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 406 | |
}, | |
"id": "OkJeUd45JooW", | |
"outputId": "be0a63f0-2584-43f2-bee7-a5fb0aa0bd2d" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 640x480 with 1 Axes>" | |
], | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAHXUlEQVR4nO3cMY6NfR/HYUcmkZEggw6NiERiDyqFBViDBViBxjroprQAQUGrkikUFoDMIGGGDOftPnkLuTPHc875z/O4rvr+Jd8pZj7uwj2bz+fzEwBw4sSJk6MHAHB8iAIAEQUAIgoARBQAiCgAEFEAIKIAQDaO+uBsNlvlDhju6tWroydM2tvbW8vNumxtbY2eMGl/f3/hm4ODgxUsWZ6j/F9lbwoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgDZGD0Ajou9vb3REybt7++PnrBUx/3nOTw8HD1hCG8KAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgs/l8Pj/Sg7PZqrcAsEJH+XPvTQGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoAJCN0QP4vSdPnoyeMGl7e3stN+v0+fPn0RMm7e7uLnyzt7e3giXL8fDhw9ETJu3s7Kzl5rjxpgBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoAJCN0QP4ve3t7dETJr1582bhm5Mnj/e/QR49ejR6wqRPnz6t5WZddnZ2Rk+Y9OHDh9EThjjev6UArJUoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBAZvP5fH6kB2ezVW/hX+RPPm533D+Id/PmzdETJn358mUtN+vy8ePH0RP+Okf5c3+8f0sBWCtRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoAZGP0AH7v2rVroydM2t3dXcvNOt25c2f0hEnfvn1b+Obr168rWLIcV65cGT1h0vPnzxe+efHixfKHrJk3BQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgGwc9cFTp06tcsc/dnh4uPDNz58/V7BkOXZ3d0dPmLS/vz96wtK9evVq9IRJv379Gj1hqc6dOzd6wqSDg4PRE4bwpgBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFAPKf+SDen/BBPP7fy5cvR0+YtLm5uZabdblx48boCZN+/PgxesIQ3hQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAY
CIAgDZGD1gWW7fvr3wza1bt1awZDkeP348esKk9+/fr+Vmne7duzd6wqTz588vfLO1tbWCJctx//790RMmPXjwYOGb169fr2DJenlTACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAyMZRH/z+/fsqd/xjb9++XfjmOP9MFy9eHD1h0oULFxa+uX79+gqWLM/ly5dHT5h09uzZhW/OnDmzgiXL8ezZs9ETJr179270hCG8KQAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgPzVH8T7k5t1uXv37ugJk/7kg3h/crNOly5dGj1h0unTpxe+2dzcXMGS5Xj69OnoCZN8EA+Av54oABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAMpvP5/PRIwA4HrwpABBRACCiAEBEAYCIAgARBQAiCgBEFACIKACQ/wHzssF0wLHC6wAAAABJRU5ErkJggg==\n" | |
}, | |
"metadata": {} | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment