rnn-diffusion.ipynb
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "authorship_tag": "ABX9TyMrSehNpQdJyMcli3XY+QIi",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/ariG23498/8939e8878fb981dbecd7b25615c3f77e/rnn-diffusion.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
    {
      "cell_type": "markdown",
      "source": [
        "# Imports"
      ],
      "metadata": {
        "id": "X_gRMLvAjLAt"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from torch import nn"
      ],
      "metadata": {
        "id": "isxlqVBrivp7"
      },
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Get the images"
      ],
      "metadata": {
        "id": "dzOB-2QhjMHl"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "id": "4SctFzFHiq3C"
      },
      "outputs": [],
      "source": [
        "batch_size = 32\n",
        "H = 128\n",
        "W = 128\n",
        "C = 3\n",
        "\n",
        "dim = 32  # unused below; the encoder's latent dimension works out to 256\n",
        "\n",
        "# placeholder for a batch of images\n",
        "images = torch.ones(batch_size, C, H, W)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Create noisy images according to a schedule\n",
        "\n",
        "We start from the original image $\\mathbf{x}_0$. Thanks to the \"nice property\" of the forward diffusion process, $q(\\mathbf{x}_t \\mid \\mathbf{x}_0)$ has a closed form, so given a variance schedule we can sample the noisy image $\\mathbf{x}_t$ at any timestep directly from $\\mathbf{x}_0$."
      ],
      "metadata": {
        "id": "cwcZP7TVjOOF"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "t = 10  # number of diffusion timesteps\n",
        "\n",
        "# placeholder for noisy images sampled with the variance schedule\n",
        "# (https://huggingface.co/blog/annotated-diffusion#defining-the-forward-diffusion-process)\n",
        "# (batch_size, t, C, H, W)\n",
        "noisy_images = torch.randn(batch_size, t, C, H, W)"
      ],
      "metadata": {
        "id": "VtWEDMlqjDx4"
      },
      "execution_count": 20,
      "outputs": []
    },
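    {
      "cell_type": "markdown",
      "source": [
        "The random tensor above is only a placeholder for samples from $q(\\mathbf{x}_t \\mid \\mathbf{x}_0)$. A minimal sketch of that sampling via the closed form $\\mathbf{x}_t = \\sqrt{\\bar{\\alpha}_t}\\,\\mathbf{x}_0 + \\sqrt{1 - \\bar{\\alpha}_t}\\,\\boldsymbol{\\epsilon}$ follows; the linear $\\beta$ schedule and its endpoints are assumptions in the style of the linked post, not part of the original notebook."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Sketch: sample x_t for every timestep directly from x_0\n",
        "# (the beta endpoints 1e-4 and 0.02 are an assumed DDPM-style linear schedule)\n",
        "betas = torch.linspace(1e-4, 0.02, t)\n",
        "alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # alpha_bar_t, shape (t,)\n",
        "\n",
        "eps = torch.randn(batch_size, t, C, H, W)\n",
        "sqrt_ac = alphas_cumprod.sqrt().view(1, t, 1, 1, 1)\n",
        "sqrt_one_minus_ac = (1.0 - alphas_cumprod).sqrt().view(1, t, 1, 1, 1)\n",
        "\n",
        "# x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps\n",
        "noisy_images = sqrt_ac * images.unsqueeze(1) + sqrt_one_minus_ac * eps"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },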
    {
      "cell_type": "code",
      "source": [
        "class ImageEncoder(nn.Module):\n",
        "    def __init__(self, in_channels):\n",
        "        super().__init__()\n",
        "        # Initialize convolutional layers with appropriate kernel size, stride, and padding\n",
        "        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1)\n",
        "        self.bn1 = nn.BatchNorm2d(64)\n",
        "        self.relu1 = nn.ReLU()\n",
        "\n",
        "        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)\n",
        "        self.bn2 = nn.BatchNorm2d(128)\n",
        "        self.relu2 = nn.ReLU()\n",
        "\n",
        "        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)\n",
        "        self.bn3 = nn.BatchNorm2d(256)\n",
        "        self.relu3 = nn.ReLU()\n",
        "\n",
        "        # Global Average Pooling to reduce spatial dimensions to 1x1\n",
        "        self.gap = nn.AdaptiveAvgPool2d((1, 1))\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.conv1(x)\n",
        "        x = self.bn1(x)\n",
        "        x = self.relu1(x)\n",
        "\n",
        "        x = self.conv2(x)\n",
        "        x = self.bn2(x)\n",
        "        x = self.relu2(x)\n",
        "\n",
        "        x = self.conv3(x)\n",
        "        x = self.bn3(x)\n",
        "        x = self.relu3(x)\n",
        "\n",
        "        x = self.gap(x)  # Output shape: (batch_size, 256, 1, 1)\n",
        "        return x"
      ],
      "metadata": {
        "id": "6xh0jpExloAM"
      },
      "execution_count": 21,
      "outputs": []
    },
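    {
      "cell_type": "markdown",
      "source": [
        "A quick shape check, assumed here rather than taken from the original run: the three stride-2 convolutions take $128 \\times 128$ down to $16 \\times 16$, and global average pooling collapses that to $1 \\times 1$."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Sanity check (assumed): encoder maps (B, C, H, W) -> (B, 256, 1, 1)\n",
        "with torch.no_grad():\n",
        "    demo_latent = ImageEncoder(in_channels=C)(torch.randn(2, C, H, W))\n",
        "print(demo_latent.shape)  # expected: torch.Size([2, 256, 1, 1])"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },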
    {
      "cell_type": "code",
      "source": [
        "class ImageDecoder(nn.Module):\n",
        "    def __init__(self, out_channels, initial_height, initial_width):\n",
        "        super().__init__()\n",
        "        # Transpose convolutional layers to upscale the feature maps\n",
        "        self.conv_transpose1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)\n",
        "        self.bn1 = nn.BatchNorm2d(128)\n",
        "        self.relu1 = nn.ReLU()\n",
        "\n",
        "        self.conv_transpose2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\n",
        "        self.bn2 = nn.BatchNorm2d(64)\n",
        "        self.relu2 = nn.ReLU()\n",
        "\n",
        "        self.conv_transpose3 = nn.ConvTranspose2d(64, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)\n",
        "        self.bn3 = nn.BatchNorm2d(out_channels)\n",
        "        # NOTE: ending with BatchNorm + ReLU clamps the outputs to be non-negative,\n",
        "        # which zero-mean Gaussian noise targets are not\n",
        "        self.relu3 = nn.ReLU()\n",
        "\n",
        "        # Resize to the exact (initial_height, initial_width); from the 1x1 latent\n",
        "        # the transposed convs above only reach 8x8, so this layer does the rest\n",
        "        self.final_resize = nn.AdaptiveAvgPool2d((initial_height, initial_width))\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.conv_transpose1(x)\n",
        "        x = self.bn1(x)\n",
        "        x = self.relu1(x)\n",
        "\n",
        "        x = self.conv_transpose2(x)\n",
        "        x = self.bn2(x)\n",
        "        x = self.relu2(x)\n",
        "\n",
        "        x = self.conv_transpose3(x)\n",
        "        x = self.bn3(x)\n",
        "        x = self.relu3(x)\n",
        "\n",
        "        x = self.final_resize(x)  # Ensure the output has the same HxW dimensions as the original input\n",
        "        return x"
      ],
      "metadata": {
        "id": "QNLRXdTAl0aC"
      },
      "execution_count": 22,
      "outputs": []
    },
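    {
      "cell_type": "markdown",
      "source": [
        "The matching check for the decoder, again assumed rather than from the original run: the transposed convolutions take the $1 \\times 1$ latent to $8 \\times 8$, and final_resize stretches that to the full $128 \\times 128$."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Sanity check (assumed): decoder maps (B, 256, 1, 1) -> (B, C, H, W)\n",
        "with torch.no_grad():\n",
        "    demo_img = ImageDecoder(out_channels=C, initial_height=H, initial_width=W)(\n",
        "        torch.randn(2, 256, 1, 1)\n",
        "    )\n",
        "print(demo_img.shape)  # expected: torch.Size([2, 3, 128, 128])"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },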
    {
      "cell_type": "code",
      "source": [
        "class CustomRecurrence(nn.Module):\n",
        "    def __init__(self, input_dim, hidden_dim, num_layers=1):\n",
        "        super().__init__()\n",
        "        self.rnn = nn.RNN(\n",
        "            input_size=input_dim,\n",
        "            hidden_size=hidden_dim,\n",
        "            num_layers=num_layers,\n",
        "            batch_first=True\n",
        "        )\n",
        "        self.fc = nn.Linear(hidden_dim, input_dim)  # Output layer to match input dimensions\n",
        "\n",
        "    def forward(self, x):\n",
        "        # x shape: (batch_size, t, dim)\n",
        "        out, _ = self.rnn(x)  # out shape: (batch_size, t, hidden_dim)\n",
        "        out = self.fc(out)  # Final output shape: (batch_size, t, dim)\n",
        "        return out"
      ],
      "metadata": {
        "id": "rhZG2kk_m3qs"
      },
      "execution_count": 23,
      "outputs": []
    },
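    {
      "cell_type": "markdown",
      "source": [
        "The recurrence treats the $t$ noise latents of each image as a sequence. A quick shape check, assumed here rather than taken from the original run:"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Sanity check (assumed): the RNN preserves the (batch, t, dim) layout\n",
        "with torch.no_grad():\n",
        "    demo_out = CustomRecurrence(input_dim=256, hidden_dim=256)(torch.randn(2, t, 256))\n",
        "print(demo_out.shape)  # expected: torch.Size([2, 10, 256])"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },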
    {
      "cell_type": "code",
      "source": [
        "# fold the time axis into the batch so the CNN sees 4-D tensors\n",
        "noisy_images_reshaped = noisy_images.reshape(batch_size * t, C, H, W)\n",
        "\n",
        "# encode the noisy images into latent vectors\n",
        "encoder = ImageEncoder(in_channels=C)\n",
        "latent_vectors = encoder(noisy_images_reshaped)  # (batch_size * t, 256, 1, 1)\n",
        "\n",
        "# unfold back to a sequence per image to align with the RNN inputs\n",
        "latent_vectors = latent_vectors.reshape(batch_size, t, -1)\n",
        "\n",
        "custom_recurrent_model = CustomRecurrence(input_dim=256, hidden_dim=256)\n",
        "output_tensors = custom_recurrent_model(latent_vectors)\n",
        "\n",
        "# fold again so the decoder can upsample each step\n",
        "output_tensors = output_tensors.reshape(batch_size * t, 256, 1, 1)\n",
        "\n",
        "# decode the latents back to noise predictions\n",
        "decoder = ImageDecoder(out_channels=C, initial_height=H, initial_width=W)\n",
        "reconstructed_noises = decoder(output_tensors)\n",
        "\n",
        "# unfold for the loss\n",
        "reconstructed_noises = reconstructed_noises.reshape(batch_size, t, C, H, W)"
      ],
      "metadata": {
        "id": "_YiBv7Ukj2jH"
      },
      "execution_count": 32,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Loss"
      ],
      "metadata": {
        "id": "lHCymFkYr314"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "print(f\"{noisy_images.shape=}\")\n",
        "print(f\"{reconstructed_noises.shape=}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "vgtTplYQsWXk",
        "outputId": "7a25b7c6-d50d-4a98-e1ab-1286caeebdf7"
      },
      "execution_count": 33,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "noisy_images.shape=torch.Size([32, 10, 3, 128, 128])\n",
            "reconstructed_noises.shape=torch.Size([32, 10, 3, 128, 128])\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "loss = nn.MSELoss()\n",
        "# shift by one step: the prediction at step k is matched against the\n",
        "# (less noisy) target at step k-1\n",
        "output = loss(reconstructed_noises[:, 1:, ...], noisy_images[:, :t-1, ...])\n",
        "\n",
        "print(output)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Nqnrc3PUr2z9",
        "outputId": "77b49aa1-25fa-428c-ed93-209eb3e3fcd6"
      },
      "execution_count": 36,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "tensor(1.5229, grad_fn=<MseLossBackward0>)\n"
          ]
        }
      ]
    },
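    {
      "cell_type": "markdown",
      "source": [
        "A minimal sketch of how these pieces could be wired into a single training step. The optimizer choice (Adam) and learning rate are assumptions for illustration, not part of the original notebook."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Sketch of one training step; Adam and lr=1e-4 are assumed placeholders\n",
        "params = (\n",
        "    list(encoder.parameters())\n",
        "    + list(custom_recurrent_model.parameters())\n",
        "    + list(decoder.parameters())\n",
        ")\n",
        "optimizer = torch.optim.Adam(params, lr=1e-4)\n",
        "\n",
        "optimizer.zero_grad()\n",
        "step_loss = loss(reconstructed_noises[:, 1:, ...], noisy_images[:, :t-1, ...])\n",
        "step_loss.backward()\n",
        "optimizer.step()"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    }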
  ]
}