rnn-diffusion.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyMrSehNpQdJyMcli3XY+QIi",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ariG23498/8939e8878fb981dbecd7b25615c3f77e/rnn-diffusion.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Imports"
],
"metadata": {
"id": "X_gRMLvAjLAt"
}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"from torch import nn"
],
"metadata": {
"id": "isxlqVBrivp7"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Get the images"
],
"metadata": {
"id": "dzOB-2QhjMHl"
}
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"id": "4SctFzFHiq3C"
},
"outputs": [],
"source": [
"batch_size = 32\n",
"H = 128\n",
"W = 128\n",
"C = 3\n",
"\n",
"dim = 32\n",
"\n",
"# placeholder for batch of images\n",
"images = torch.ones(batch_size, C, H, W)"
]
},
{
"cell_type": "markdown",
"source": [
"# Create noisy images according to a schedule\n",
"\n",
"We have the original image $\\mathbf{x}_{0}$. Using the \"nice property\" of the forward diffusion process we use a schedule and get the noisy images $\\mathbf{x}_t$."
],
"metadata": {
"id": "cwcZP7TVjOOF"
}
},
{
"cell_type": "code",
"source": [
"t = 10 # defining the time steps\n",
"\n",
"# placeholder for noisy images using the variance schedule\n",
"# (https://huggingface.co/blog/annotated-diffusion#defining-the-forward-diffusion-process)\n",
"# (batch_size, t, C, H, W)\n",
"noisy_images = torch.randn(batch_size, t, C, H, W)"
],
"metadata": {
"id": "VtWEDMlqjDx4"
},
"execution_count": 20,
"outputs": []
},
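{
"cell_type": "markdown",
"source": [
"The cell above uses pure Gaussian noise as a stand-in. A minimal sketch of what the actual forward process would look like, assuming a linear $\\beta$ schedule as in the annotated-diffusion post (the bounds `1e-4` and `0.02` are the usual defaults, not values from this notebook):"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# sketch: sample x_t directly from x_0 via the \"nice property\"\n",
"betas = torch.linspace(1e-4, 0.02, t)  # linear variance schedule\n",
"alphas = 1.0 - betas\n",
"alpha_bars = torch.cumprod(alphas, dim=0)  # cumulative products \\bar{alpha}_t\n",
"\n",
"eps = torch.randn_like(images)  # epsilon ~ N(0, I)\n",
"timestep = 5  # any index in [0, t)\n",
"x_t = (\n",
"    alpha_bars[timestep].sqrt() * images\n",
"    + (1 - alpha_bars[timestep]).sqrt() * eps\n",
")  # shape: (batch_size, C, H, W)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},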
{
"cell_type": "code",
"source": [
"class ImageEncoder(nn.Module):\n",
" def __init__(self, in_channels):\n",
" super().__init__()\n",
" # Initialize convolutional layers with appropriate kernel size, stride, and padding\n",
" self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1)\n",
" self.bn1 = nn.BatchNorm2d(64)\n",
" self.relu1 = nn.ReLU()\n",
"\n",
" self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)\n",
" self.bn2 = nn.BatchNorm2d(128)\n",
" self.relu2 = nn.ReLU()\n",
"\n",
" self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)\n",
" self.bn3 = nn.BatchNorm2d(256)\n",
" self.relu3 = nn.ReLU()\n",
"\n",
" # Global Average Pooling to reduce spatial dimensions to 1x1\n",
" self.gap = nn.AdaptiveAvgPool2d((1, 1))\n",
"\n",
" def forward(self, x):\n",
" x = self.conv1(x)\n",
" x = self.bn1(x)\n",
" x = self.relu1(x)\n",
"\n",
" x = self.conv2(x)\n",
" x = self.bn2(x)\n",
" x = self.relu2(x)\n",
"\n",
" x = self.conv3(x)\n",
" x = self.bn3(x)\n",
" x = self.relu3(x)\n",
"\n",
" x = self.gap(x) # Output shape: (batch_size, 256, 1, 1)\n",
" return x"
],
"metadata": {
"id": "6xh0jpExloAM"
},
"execution_count": 21,
"outputs": []
},
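{
"cell_type": "markdown",
"source": [
"A quick sanity check (illustrative, not part of the original notebook): each stride-2 convolution halves the spatial size ($128 \\rightarrow 64 \\rightarrow 32 \\rightarrow 16$), and global average pooling collapses what remains, so every image becomes a single 256-d latent."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# sketch: verify the encoder's output shape on the placeholder batch\n",
"enc = ImageEncoder(in_channels=C)\n",
"print(enc(images).shape)  # expected: torch.Size([32, 256, 1, 1])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},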
{
"cell_type": "code",
"source": [
"class ImageDecoder(nn.Module):\n",
" def __init__(self, out_channels, initial_height, initial_width):\n",
" super(ImageDecoder, self).__init__()\n",
" # Initialize transpose convolutional layers to upscale feature maps\n",
" self.conv_transpose1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)\n",
" self.bn1 = nn.BatchNorm2d(128)\n",
" self.relu1 = nn.ReLU()\n",
"\n",
" self.conv_transpose2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\n",
" self.bn2 = nn.BatchNorm2d(64)\n",
" self.relu2 = nn.ReLU()\n",
"\n",
" self.conv_transpose3 = nn.ConvTranspose2d(64, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)\n",
" self.bn3 = nn.BatchNorm2d(out_channels)\n",
" self.relu3 = nn.ReLU()\n",
"\n",
" # Additional layer to ensure correct output dimensions\n",
" # This layer is only needed if the initial size cannot be exactly achieved through the strides and paddings chosen\n",
" self.final_resize = nn.AdaptiveAvgPool2d((initial_height, initial_width))\n",
"\n",
" def forward(self, x):\n",
" x = self.conv_transpose1(x)\n",
" x = self.bn1(x)\n",
" x = self.relu1(x)\n",
"\n",
" x = self.conv_transpose2(x)\n",
" x = self.bn2(x)\n",
" x = self.relu2(x)\n",
"\n",
" x = self.conv_transpose3(x)\n",
" x = self.bn3(x)\n",
" x = self.relu3(x)\n",
"\n",
" x = self.final_resize(x) # Ensure the output has the same HxW dimensions as the original input\n",
" return x"
],
"metadata": {
"id": "QNLRXdTAl0aC"
},
"execution_count": 22,
"outputs": []
},
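{
"cell_type": "markdown",
"source": [
"Worth noting (an observation, not part of the original notebook): the three stride-2 transpose convolutions only grow the $1 \\times 1$ latent to $8 \\times 8$, so the final `AdaptiveAvgPool2d` does most of the resizing. With an output size larger than its input it effectively replicates values, behaving like nearest-neighbour upsampling to reach $128 \\times 128$."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# sketch: check that the decoder restores the original spatial dimensions\n",
"dec = ImageDecoder(out_channels=C, initial_height=H, initial_width=W)\n",
"z = torch.randn(4, 256, 1, 1)  # dummy latents\n",
"print(dec(z).shape)  # expected: torch.Size([4, 3, 128, 128])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},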
{
"cell_type": "code",
"source": [
"class CustomRecurrence(nn.Module):\n",
" def __init__(self, input_dim, hidden_dim, num_layers=1):\n",
" super().__init__()\n",
" self.rnn = nn.RNN(\n",
" input_size=input_dim,\n",
" hidden_size=hidden_dim,\n",
" num_layers=num_layers,\n",
" batch_first=True\n",
" )\n",
" self.fc = nn.Linear(hidden_dim, input_dim) # Output layer to match input dimensions\n",
"\n",
" def forward(self, x):\n",
" # x shape: (batch_size, t, dim)\n",
" out, _ = self.rnn(x) # out shape: (batch_size, t, hidden_dim)\n",
" out = self.fc(out) # Final output shape: (batch_size, t, dim)\n",
" return out"
],
"metadata": {
"id": "rhZG2kk_m3qs"
},
"execution_count": 23,
"outputs": []
},
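{
"cell_type": "markdown",
"source": [
"Before wiring everything together, a quick shape check of the recurrence on dummy latents (illustrative; the 256-d size matches the encoder's output):"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# sketch: the RNN maps a (batch, t, 256) latent sequence to the same shape\n",
"rec = CustomRecurrence(input_dim=256, hidden_dim=256)\n",
"dummy_latents = torch.randn(2, t, 256)\n",
"print(rec(dummy_latents).shape)  # expected: torch.Size([2, 10, 256])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},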
{
"cell_type": "code",
"source": [
"noisy_images_reshaped = noisy_images.reshape(batch_size * t, C, H, W)\n",
"\n",
"# encode noises to latent\n",
"encoder = ImageEncoder(in_channels=C)\n",
"latent_vectors = encoder(noisy_images_reshaped)\n",
"\n",
"# reshape the latent to align it to RNN inputs\n",
"latent_vectors = latent_vectors.reshape(batch_size, t, -1)\n",
"\n",
"custom_recurrent_model = CustomRecurrence(input_dim=256, hidden_dim=256)\n",
"output_tensors = custom_recurrent_model(latent_vectors)\n",
"\n",
"# reshape output tensor to align to upsample\n",
"output_tensors = output_tensors.reshape(batch_size * t, 256, 1, 1)\n",
"\n",
"# decode latent to noise\n",
"decoder = ImageDecoder(out_channels=C, initial_height=H, initial_width=W)\n",
"reconstructed_noises = decoder(output_tensors)\n",
"\n",
"# reshape for loss\n",
"reconstructed_noises = reconstructed_noises.reshape(batch_size, t, C, H, W)"
],
"metadata": {
"id": "_YiBv7Ukj2jH"
},
"execution_count": 32,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Loss"
],
"metadata": {
"id": "lHCymFkYr314"
}
},
{
"cell_type": "code",
"source": [
"print(f\"{noisy_images.shape=}\")\n",
"print(f\"{reconstructed_noises.shape=}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "vgtTplYQsWXk",
"outputId": "7a25b7c6-d50d-4a98-e1ab-1286caeebdf7"
},
"execution_count": 33,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"noisy_images.shape=torch.Size([32, 10, 3, 128, 128])\n",
"reconstructed_noises.shape=torch.Size([32, 10, 3, 128, 128])\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"loss = nn.MSELoss()\n",
"output = loss(reconstructed_noises[:, 1:, ...], noisy_images[:, :t-1, ...])\n",
"\n",
"print(output)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Nqnrc3PUr2z9",
"outputId": "77b49aa1-25fa-428c-ed93-209eb3e3fcd6"
},
"execution_count": 36,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tensor(1.5229, grad_fn=<MseLossBackward0>)\n"
]
}
]
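},
{
"cell_type": "markdown",
"source": [
"A minimal sketch of how one training step could look, assuming the encoder, RNN, and decoder are optimized jointly (the optimizer choice and learning rate are illustrative, not from the original notebook):"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"params = (\n",
"    list(encoder.parameters())\n",
"    + list(custom_recurrent_model.parameters())\n",
"    + list(decoder.parameters())\n",
")\n",
"optimizer = torch.optim.Adam(params, lr=1e-4)\n",
"\n",
"optimizer.zero_grad()\n",
"output.backward()  # `output` is the MSE loss computed above\n",
"optimizer.step()"
],
"metadata": {},
"execution_count": null,
"outputs": []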
}
]
}