Skip to content

Instantly share code, notes, and snippets.

@ariG23498
Created April 18, 2024 07:51
Show Gist options
  • Save ariG23498/73c1d4c27d75268fd9973930dc80a2ad to your computer and use it in GitHub Desktop.
Save ariG23498/73c1d4c27d75268fd9973930dc80a2ad to your computer and use it in GitHub Desktop.
autoregressive-diffusion-gru.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ariG23498/73c1d4c27d75268fd9973930dc80a2ad/autoregressive-diffusion-gru.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "02vydzVb6VWO"
},
"source": [
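"This notebook trains a GRU-based autoregressive model on Fashion-MNIST to undo a short forward-diffusion process: every image is noised along a fixed schedule, and the model learns to predict each less-noisy frame from the noisier frames before it. It is heavily inspired by: https://huggingface.co/blog/annotated-diffusion"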
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-CFKktKV2qqo"
},
"source": [
"## Setup and Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "BmiGcVlQ64dD",
"outputId": "6e9edf89-62e7-4533-8fb9-96dd0f6558a2"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m510.5/510.5 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m6.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m11.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m8.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h"
]
}
],
"source": [
"%pip install --upgrade -qq datasets"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "hzUHzDmy2kOy"
},
"outputs": [],
"source": [
"import random\n",
"import numpy as np\n",
"from matplotlib import pyplot as plt\n",
"\n",
"from datasets import load_dataset\n",
"\n",
"import torch\n",
"from torch import nn\n",
"from torch.nn import functional as F\n",
"from torch.utils.data import DataLoader\n",
"\n",
"from torchvision import transforms"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rlNouEPu7ART"
},
"source": [
"## Configurations"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3HNObPXp7Bff",
"outputId": "d6bc5730-1833-488c-9964-140476d5b092"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"device='cuda'\n"
]
}
],
"source": [
"# Batch size, image height/width, channels, and number of diffusion steps\n",
"B = 1024\n",
"H = W = 28\n",
"C = 1\n",
"T = 20\n",
"\n",
"hidden_dim = 128  # GRU hidden size; the encoder's latent size must match it\n",
"epochs = 10\n",
"lr = 1e-4\n",
"\n",
"# Seed every RNG used below (data shuffling, random flips, diffusion noise,\n",
"# weight init, and the random sample picked for visualization)\n",
"SEED = 42\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"print(f\"{device=}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JGEzhe8s7B1p"
},
"source": [
"## Dataset and Loaders"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 145,
"referenced_widgets": [
"aabc8f376da24c189f7e254f0e10463f",
"b0ade1942f494b65b7c0e4766da73d6e",
"84001a8771e84ae4ae8490e109c10d12",
"5a7b1a406d3b4304bd09275f2f1cea0d",
"81f27436ace945cebc703f20c907af7c",
"194358c3c5624917809abc4c96a88258",
"586da1075f2e466fb7e2c0889797525a",
"f2fc2b4f120d4e6d91c20d86e9a7200e",
"7afa5f2062754840b1768aa5c16c6b1b",
"d0e8a5f96fe64169b8f57cafec5a3712",
"6f72f0a7f8354fb093a7cf5b0b26bc40",
"13335f706b33407aad4fc17dc775c442",
"d2074f1151724105bea216cbcbdf2b1a",
"475a7b917d5f4498a8994f70a698778a",
"1b1285ff9f1b4e47be202a49ef298b7f",
"7dd69b0868d84ceca97e7cd9ced274ef",
"1c7901dd4dad459392c76ddef53826fb",
"4a7bcaecfa214392996cdc625d035407",
"f0dd4a8852544be1a5419783d933d3c2",
"0866e4ffae86463b9ed5119625bde4f5",
"4aa286bebafa4899b74863e320d3cf64",
"f00a7f09029944169ecd25447f857f04",
"da168da36c3d4c23893fa8075cf11147",
"9b3e2e698ff74613a18eea9bb5ddd86b",
"871aafcc0b5440f8a7ca2a9b547e317c",
"ee51dc1bece8423e8d9cbe23dafe6b3c",
"67386296cc604ddabacaa6e3c98813b9",
"3394bbd34bfb4026873b36b0f7681290",
"73fa173613294ec4ad18a8f4b2f3f82e",
"f6b863493fb04233a83ca107ece30b55",
"a2dd7b3497a143b582ff876a45c42e0c",
"e3b70cabe29743658054c86fc830003b",
"709d4a0bad844d7d9e512d08dffc954e",
"c6674878ba944e77b376a6f69f09e469",
"ec4ae507bb4c4c8288cbf3c165204bb8",
"2c47342b29fc493db35eaed7d2c41738",
"9a804266219c4171845f56faac9b0ed6",
"caf7a1f51a2745ac8b71588c6a7b027f",
"b03b06120b5141fda54a3536bcc495da",
"baa5227cc3204deda66619ad1e96f466",
"78a82c4f84414ad48c1f94fae6cb94e5",
"701a70a97018426d998cb607d8ac0a17",
"40f38886d99947eb9f3d02e0335cf82a",
"8f2663f99d0c4c1fa5898830f9e2c13a"
]
},
"id": "PdIqVYuj6axU",
"outputId": "634b9086-555c-473b-a2e1-ceceb44992dd"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading data: 0%| | 0.00/30.9M [00:00<?, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "aabc8f376da24c189f7e254f0e10463f"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading data: 0%| | 0.00/5.18M [00:00<?, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "13335f706b33407aad4fc17dc775c442"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Generating train split: 0%| | 0/60000 [00:00<?, ? examples/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "da168da36c3d4c23893fa8075cf11147"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Generating test split: 0%| | 0/10000 [00:00<?, ? examples/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "c6674878ba944e77b376a6f69f09e469"
}
},
"metadata": {}
}
],
"source": [
"# Fetch Fashion-MNIST (train and test splits) from the Hugging Face Hub\n",
"dataset = load_dataset(\"fashion_mnist\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "SbnfSuvR6eel"
},
"outputs": [],
"source": [
"# Image transformations: random horizontal flip, PIL -> float tensor in [0, 1],\n",
"# then rescale to [-1, 1]\n",
"transform = transforms.Compose([\n",
"    transforms.RandomHorizontalFlip(),\n",
"    transforms.ToTensor(),\n",
"    transforms.Lambda(lambda x: (x * 2) - 1)\n",
"])\n",
"\n",
"# NOTE: deliberately not named `transforms` -- that would shadow the\n",
"# `torchvision.transforms` module and make re-running this cell fail on\n",
"# `transforms.Compose`.\n",
"def apply_transforms(examples):\n",
"    \"\"\"Batched `datasets` transform: add `pixel_values` tensors and drop the raw PIL `image` column.\"\"\"\n",
"    examples[\"pixel_values\"] = [transform(image.convert(\"L\")) for image in examples[\"image\"]]\n",
"    del examples[\"image\"]\n",
"    return examples\n",
"\n",
"transformed_dataset = (\n",
"    dataset\n",
"    .with_transform(apply_transforms)\n",
"    .remove_columns(\"label\")\n",
")\n",
"\n",
"# Shuffled training loader; each item is a dict holding a `pixel_values` batch\n",
"dataloader = DataLoader(\n",
"    transformed_dataset[\"train\"],\n",
"    batch_size=B,\n",
"    shuffle=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GneDUV5J74KJ",
"outputId": "13a628f5-7988-4cce-f145-e92ff2bfd47c"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"dict_keys(['pixel_values'])\n",
"torch.Size([1024, 1, 28, 28])\n"
]
}
],
"source": [
"# Get a batch of data and check its keys, shape and value range\n",
"batch = next(iter(dataloader))\n",
"print(batch.keys())\n",
"print(batch[\"pixel_values\"].shape)\n",
"\n",
"# Sanity checks: images are (C, H, W) and were rescaled to [-1, 1]\n",
"assert batch[\"pixel_values\"].shape[1:] == (C, H, W)\n",
"assert batch[\"pixel_values\"].min() >= -1 and batch[\"pixel_values\"].max() <= 1"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ggg6okjO8NFS"
},
"source": [
"## Forward Diffusion Process"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "9_7ipOKVIbFn"
},
"outputs": [],
"source": [
"# Define a linear schedule\n",
"def linear_beta_schedule(T, beta_start=0.0001, beta_end=0.02, reference_T=1000):\n",
"    \"\"\"Linear DDPM beta schedule, rescaled to `T` steps.\n",
"\n",
"    The endpoints (1e-4, 0.02) were tuned for `reference_T` = 1000 steps. Used as-is\n",
"    with few steps they barely corrupt the data (for T=20, sqrt(alpha_bar) at the last\n",
"    step is still ~0.9), so the noisiest training frame looks nothing like the pure\n",
"    N(0, I) noise that sampling starts from. Both endpoints are therefore scaled by\n",
"    `reference_T / T` (the Improved-DDPM convention) and `beta_end` is capped at 0.999\n",
"    to keep every alpha positive. With T == reference_T this is the original schedule.\n",
"    \"\"\"\n",
"    scale = reference_T / T\n",
"    return torch.linspace(beta_start * scale, min(beta_end * scale, 0.999), T)\n",
"\n",
"# Precompute the schedule-derived coefficients of the closed-form forward process\n",
"betas = linear_beta_schedule(T=T)  # (T,)\n",
"alphas = 1.0 - betas  # alpha_t = 1 - beta_t\n",
"alphas_cumprod = alphas.cumprod(dim=0)  # alpha_bar_t = prod_{s <= t} alpha_s\n",
"\n",
"# Signal and noise scales used by q_sample\n",
"sqrt_alphas_cumprod = alphas_cumprod.sqrt()\n",
"sqrt_one_minus_alphas_cumprod = (1. - alphas_cumprod).sqrt()\n",
"\n",
"def extract(a, t, x_shape):\n",
"    \"\"\"Gather per-timestep coefficients `a[..., t]` and reshape them to broadcast against `x_shape`.\n",
"\n",
"    `a` and `t` must have the same number of dims, e.g. `a`: (B, T) with `t`: (B, T),\n",
"    or `a`: (T,) with `t`: (B,). The result has shape `t.shape` followed by enough\n",
"    singleton dims to reach `len(x_shape)`, and is placed on `t.device`.\n",
"    \"\"\"\n",
"    out = a.gather(dim=-1, index=t.to(a.device))\n",
"    trailing = (1,) * (len(x_shape) - t.dim())\n",
"    return out.reshape(*t.shape, *trailing).to(t.device)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "pmHZGKa8Lhbz"
},
"outputs": [],
"source": [
"def q_sample(x_start, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, noise=None):\n",
"    \"\"\"Forward diffusion: sample x_t ~ q(x_t | x_0) in closed form.\n",
"\n",
"    Returns sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise, with both\n",
"    coefficients looked up per `t` via `extract`. When `noise` is omitted, a fresh\n",
"    standard-normal tensor shaped like `x_start` is drawn. If `x_start` has a\n",
"    singleton time axis that broadcasts against `t` (as in this notebook), that one\n",
"    noise sample is shared by every timestep of the trajectory.\n",
"    \"\"\"\n",
"    eps = noise if noise is not None else torch.randn_like(x_start)\n",
"    shape = x_start.shape\n",
"    signal_coef = extract(sqrt_alphas_cumprod, t, shape)\n",
"    noise_coef = extract(sqrt_one_minus_alphas_cumprod, t, shape)\n",
"    return signal_coef * x_start + noise_coef * eps"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "uWgxpZPj8sY6"
},
"outputs": [],
"source": [
"# Noise the sample batch along the whole schedule, ordered noisiest (t = T-1) -> cleanest (t = 0)\n",
"t = torch.arange(T - 1, -1, -1).repeat(B, 1)  # (B, T)\n",
"input_images = q_sample(\n",
"    x_start=batch[\"pixel_values\"].unsqueeze(1),  # (B, 1, C, H, W), broadcast over time\n",
"    t=t,\n",
"    sqrt_alphas_cumprod=sqrt_alphas_cumprod.repeat(B, 1),  # (B, T)\n",
"    sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod.repeat(B, 1),  # (B, T)\n",
")  # (B, T, C, H, W)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aQ0H8DVxEJqF"
},
"source": [
"### Visualize the forward process"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 122
},
"id": "I-hoBcj3DNSf",
"outputId": "dd05fa4f-0eca-4619-a7ba-137f4c40d8a4"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 5 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
],
"source": [
"# Show evenly spaced frames of one random trajectory, titled with their diffusion timestep\n",
"idx = random.randint(0, B-1)\n",
"n_panels = 5\n",
"fig, axes = plt.subplots(1, n_panels, figsize=(2 * n_panels, 2))\n",
"for i, ax in enumerate(axes, start=1):\n",
"    # Position along the time axis; the last panel is always the final frame\n",
"    frame = max(T * i // n_panels - 1, 0)\n",
"    ax.imshow(input_images[idx, frame].permute(1, 2, 0), cmap=\"gray\")\n",
"    ax.set_title(f\"t={T - 1 - frame}\")  # frames are stored noisiest-first\n",
"    ax.axis(\"off\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-CeUWnzpEoqb"
},
"source": [
"## Define the recurrent model (encoder, GRU, decoder)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "z81MMxklEqhO"
},
"outputs": [],
"source": [
"class ImageEncoder(nn.Module):\n",
"    \"\"\"Convolutional encoder that maps each image to a single latent vector.\n",
"\n",
"    Three Conv(stride 2) -> BatchNorm -> ReLU stages (in_channels -> 32 -> 64 ->\n",
"    out_channels), each downsampling H and W by 2 (rounding up), followed by global\n",
"    average pooling. Output shape: (N, out_channels, 1, 1).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels=1, out_channels=128):\n",
"        super().__init__()\n",
"        # Stage 1: in_channels -> 32\n",
"        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1)\n",
"        self.bn1 = nn.BatchNorm2d(32)\n",
"        self.relu1 = nn.ReLU()\n",
"        # Stage 2: 32 -> 64\n",
"        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)\n",
"        self.bn2 = nn.BatchNorm2d(64)\n",
"        self.relu2 = nn.ReLU()\n",
"        # Stage 3: 64 -> out_channels\n",
"        self.conv3 = nn.Conv2d(64, out_channels, kernel_size=3, stride=2, padding=1)\n",
"        self.bn3 = nn.BatchNorm2d(out_channels)\n",
"        self.relu3 = nn.ReLU()\n",
"        # Collapse whatever spatial size remains to 1x1\n",
"        self.gap = nn.AdaptiveAvgPool2d((1, 1))\n",
"\n",
"    def _stages(self):\n",
"        # (conv, norm, activation) triples in application order\n",
"        return (\n",
"            (self.conv1, self.bn1, self.relu1),\n",
"            (self.conv2, self.bn2, self.relu2),\n",
"            (self.conv3, self.bn3, self.relu3),\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        for conv, norm, act in self._stages():\n",
"            x = act(norm(conv(x)))\n",
"        return self.gap(x)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "gzQ35KxeE2Cz"
},
"outputs": [],
"source": [
"class ImageDecoder(nn.Module):\n",
"    \"\"\"Transposed-conv decoder from a latent feature map to an (N, out_channels, H, W) image.\n",
"\n",
"    Three stride-2 ConvTranspose stages upsample a 1x1 input to 2x2 -> 4x4 -> 8x8, then\n",
"    an adaptive average pool resizes to (initial_height, initial_width).\n",
"\n",
"    The output head is linear by default. The decoder's targets are noisy images that\n",
"    regularly take negative values (pixels in [-1, 1] plus Gaussian noise), and a final\n",
"    BatchNorm + ReLU forces every output pixel to be >= 0, which makes that part of\n",
"    the target range unreachable. Pass `final_activation=True` to restore the old\n",
"    BatchNorm + ReLU head.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels=128, out_channels=1, initial_height=H, initial_width=W, final_activation=False):\n",
"        super().__init__()\n",
"        self.conv_transpose1 = nn.ConvTranspose2d(in_channels, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\n",
"        self.bn1 = nn.BatchNorm2d(64)\n",
"        self.relu1 = nn.ReLU()\n",
"\n",
"        self.conv_transpose2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)\n",
"        self.bn2 = nn.BatchNorm2d(32)\n",
"        self.relu2 = nn.ReLU()\n",
"\n",
"        self.conv_transpose3 = nn.ConvTranspose2d(32, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)\n",
"        # Output head: identity by default so predictions can be negative\n",
"        self.bn3 = nn.BatchNorm2d(out_channels) if final_activation else nn.Identity()\n",
"        self.relu3 = nn.ReLU() if final_activation else nn.Identity()\n",
"\n",
"        # From a 1x1 input the transposed convs only reach 8x8, so this resize is what\n",
"        # actually produces the (initial_height, initial_width) output.\n",
"        self.final_resize = nn.AdaptiveAvgPool2d((initial_height, initial_width))\n",
"\n",
"    def forward(self, x):\n",
"        x = self.relu1(self.bn1(self.conv_transpose1(x)))\n",
"        x = self.relu2(self.bn2(self.conv_transpose2(x)))\n",
"        x = self.relu3(self.bn3(self.conv_transpose3(x)))  # identity unless final_activation=True\n",
"        return self.final_resize(x)  # same HxW as the original input"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "chqbXzZaE5Ok"
},
"outputs": [],
"source": [
"class CustomRecurrence(nn.Module):\n",
"    \"\"\"Encoder -> GRU -> decoder sequence model used for autoregressive denoising.\n",
"\n",
"    Every frame of a (B, T, C, H, W) input is encoded to a `hidden_dim` vector, a\n",
"    multi-layer GRU runs over the time axis, and each GRU output is decoded back to a\n",
"    (C, H, W) image (trained in this notebook to predict the next, less noisy frame).\n",
"\n",
"    `training` sets the initial mode through `nn.Module.train`, so the BatchNorm layers\n",
"    inside the encoder/decoder are switched too. Afterwards use `model.train()` /\n",
"    `model.eval()`. `forward` returns only the images in training mode and an\n",
"    `(images, hidden_states)` tuple in eval mode.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        in_channels=1,\n",
"        initial_height=H,\n",
"        initial_width=W,\n",
"        hidden_dim=hidden_dim,\n",
"        num_layers=2,\n",
"        training=False\n",
"    ):\n",
"        super().__init__()\n",
"        self.hidden_dim = hidden_dim\n",
"        # The encoder's latent size must equal the GRU input size. Passing it explicitly\n",
"        # keeps the model working when hidden_dim differs from ImageEncoder's default (128).\n",
"        self.image_encoder = ImageEncoder(in_channels=in_channels, out_channels=hidden_dim)\n",
"        self.rnn = nn.GRU(\n",
"            input_size=hidden_dim,\n",
"            hidden_size=hidden_dim,\n",
"            num_layers=num_layers,\n",
"            batch_first=True\n",
"        )\n",
"        self.image_decoder = ImageDecoder(\n",
"            in_channels=hidden_dim,\n",
"            out_channels=in_channels,\n",
"            initial_height=initial_height,\n",
"            initial_width=initial_width\n",
"        )\n",
"        # Recursive mode switch. Assigning `self.training` directly only changed this\n",
"        # module's flag and left the encoder/decoder BatchNorm layers in training mode.\n",
"        self.train(bool(training))\n",
"\n",
"    def forward(self, x, hidden_states=None):\n",
"        \"\"\"Run the model over a (B, T, C, H, W) sequence.\n",
"\n",
"        `hidden_states` is the initial GRU state, shaped (num_layers, B, hidden_dim);\n",
"        None starts from zeros.\n",
"        \"\"\"\n",
"        batch_size, timesteps, channels, height, width = x.shape\n",
"        x = x.reshape(batch_size * timesteps, channels, height, width)  # (b*t, c, h, w)\n",
"\n",
"        # Encode each frame independently\n",
"        latent_vectors = self.image_encoder(x)  # (b*t, hidden_dim, 1, 1)\n",
"        latent_vectors = latent_vectors.reshape(batch_size, timesteps, -1)  # (b, t, hidden_dim)\n",
"\n",
"        # Recurrence over the time axis\n",
"        rnn_outputs, hidden_states = self.rnn(latent_vectors, hidden_states)  # (b, t, hidden_dim)\n",
"\n",
"        # Decode every GRU output back to an image\n",
"        rnn_outputs = rnn_outputs.reshape(batch_size * timesteps, self.hidden_dim, 1, 1)  # (b*t, hidden_dim, 1, 1)\n",
"        reconstructed_x = self.image_decoder(rnn_outputs)\n",
"        reconstructed_x = reconstructed_x.reshape(batch_size, timesteps, channels, height, width)\n",
"\n",
"        if self.training:\n",
"            return reconstructed_x\n",
"        return reconstructed_x, hidden_states"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KD_os31TMhEN"
},
"source": [
"## Training Loop"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "TYo28dSj49K_"
},
"outputs": [],
"source": [
"# Build the model in training mode on the selected device\n",
"model = CustomRecurrence(training=True).to(device)\n",
"\n",
"# Adam over all parameters; frames are regressed with mean-squared error\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
"loss_fn = nn.MSELoss()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ImaLvHwjMiDf",
"outputId": "daa51f31-9bdd-45c4-b6f3-57c2eb54a6fc"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch:0\n",
"\tLoss:1.45117\n",
"Epoch:1\n",
"\tLoss:0.71674\n",
"Epoch:2\n",
"\tLoss:0.68546\n",
"Epoch:3\n",
"\tLoss:0.66598\n",
"Epoch:4\n",
"\tLoss:0.66175\n",
"Epoch:5\n",
"\tLoss:0.65562\n",
"Epoch:6\n",
"\tLoss:0.65633\n",
"Epoch:7\n",
"\tLoss:0.65157\n",
"Epoch:8\n",
"\tLoss:0.65766\n",
"Epoch:9\n",
"\tLoss:0.65651\n"
]
}
],
"source": [
"model.train()  # in case the inference cells below already switched the model to eval mode\n",
"for epoch in range(epochs):\n",
"    print(f\"Epoch:{epoch}\")\n",
"    for step, train_batch in enumerate(dataloader):\n",
"        optimizer.zero_grad()\n",
"\n",
"        # Local names on purpose: reusing `batch` / `input_images` here would clobber the\n",
"        # globals that the earlier noise-batch and visualization cells rely on.\n",
"        x0 = train_batch[\"pixel_values\"].to(device)\n",
"        batch_size = x0.shape[0]  # the last batch can be smaller than B\n",
"\n",
"        # Full noising trajectory, ordered noisiest (t = T-1) -> cleanest (t = 0)\n",
"        t = torch.arange(0, T).flip((-1,)).repeat(batch_size, 1).to(device)\n",
"        trajectory = q_sample(\n",
"            x_start=x0.unsqueeze(1),  # (B, 1, C, H, W)\n",
"            t=t,  # (B, T)\n",
"            sqrt_alphas_cumprod=sqrt_alphas_cumprod.repeat(batch_size, 1),\n",
"            sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod.repeat(batch_size, 1),\n",
"        )  # (B, T, C, H, W)\n",
"\n",
"        # Offset by one step (autoregressive): frame k is used to predict frame k+1\n",
"        reconstructed_images = model(trajectory[:, :T-1, ...])\n",
"        loss = loss_fn(reconstructed_images, trajectory[:, 1:, ...])\n",
"\n",
"        if step % 100 == 0:\n",
"            print(f\"\\tLoss:{loss.item():0.5f}\")\n",
"\n",
"        loss.backward()\n",
"        optimizer.step()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PP6CFN4vMidh"
},
"source": [
"## Inference Loop"
]
},
{
"cell_type": "code",
"source": [
"def plot_images(image_list, rows=50, cols=10):\n",
"    \"\"\"Plot a rows x cols grid of frames (about 1x1 inch each) with axes hidden.\n",
"\n",
"    Each element of `image_list` is indexed as `image[0, 0, 0]`, so it is expected to\n",
"    be a (1, 1, C, H, W) array (batch 1, one timestep); its first channel is shown.\n",
"    Subplots beyond `len(image_list)` are left blank.\n",
"    \"\"\"\n",
"    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so `.flat` always works\n",
"    fig, axes = plt.subplots(rows, cols, figsize=(cols, rows), squeeze=False)\n",
"    for i, ax in enumerate(axes.flat):\n",
"        if i < len(image_list):\n",
"            ax.imshow(image_list[i][0, 0, 0], cmap='gray', aspect='auto')\n",
"        ax.axis('off')  # hide ticks on used and empty subplots alike\n",
"\n",
"    plt.subplots_adjust(wspace=0.05, hspace=0.05)  # tight spacing between images\n",
"    plt.show()"
],
"metadata": {
"id": "cUWGc5paBq5z"
},
"execution_count": 19,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "iNwii9gJ-Utu"
},
"outputs": [],
"source": [
"# Switch to inference: eval() clears `training` on the model and all its submodules\n",
"model.eval();"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"id": "8NxLub4ZMjpH"
},
"outputs": [],
"source": [
"# Sample one trajectory autoregressively, starting from pure Gaussian noise.\n",
"# Training used T-1 transitions (frame k -> frame k+1), so run exactly T-1 steps and\n",
"# keep the starting noise as the first frame: T frames, mirroring the training data.\n",
"model.eval()  # the model returns (output, hidden_state) only in eval mode\n",
"generated_outputs = []\n",
"with torch.no_grad():\n",
"    current_input = torch.randn(1, 1, C, H, W, device=device)  # (batch=1, time=1, C, H, W)\n",
"    generated_outputs.append(current_input.cpu().numpy())\n",
"    # GRU state is (num_layers, batch, hidden_dim)\n",
"    hidden_state = torch.zeros(model.rnn.num_layers, 1, model.hidden_dim, device=device)\n",
"\n",
"    for _ in range(T - 1):\n",
"        # Predict the next (less noisy) frame and carry the GRU state forward\n",
"        output, hidden_state = model(current_input, hidden_state)\n",
"        current_input = output\n",
"        generated_outputs.append(output.cpu().numpy())"
]
},
{
"cell_type": "code",
"source": [
"# Rows of at most 10 frames; ceil-divide so no frame is dropped and n < 10 still works\n",
"num_gen_images = len(generated_outputs)\n",
"num_cols = max(1, min(10, num_gen_images))\n",
"num_rows = max(1, -(-num_gen_images // num_cols))\n",
"plot_images(generated_outputs, rows=num_rows, cols=num_cols)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 191
},
"id": "nadW4XaaBDXk",
"outputId": "38df5f7f-01a9-4ced-937a-539576e35a5f"
},
"execution_count": 28,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1000x200 with 20 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": [],
"authorship_tag": "ABX9TyP/P4AFUVVMmS88DTTrynJo",
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"aabc8f376da24c189f7e254f0e10463f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_b0ade1942f494b65b7c0e4766da73d6e",
"IPY_MODEL_84001a8771e84ae4ae8490e109c10d12",
"IPY_MODEL_5a7b1a406d3b4304bd09275f2f1cea0d"
],
"layout": "IPY_MODEL_81f27436ace945cebc703f20c907af7c"
}
},
"b0ade1942f494b65b7c0e4766da73d6e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_194358c3c5624917809abc4c96a88258",
"placeholder": "​",
"style": "IPY_MODEL_586da1075f2e466fb7e2c0889797525a",
"value": "Downloading data: 100%"
}
},
"84001a8771e84ae4ae8490e109c10d12": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_f2fc2b4f120d4e6d91c20d86e9a7200e",
"max": 30931277,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_7afa5f2062754840b1768aa5c16c6b1b",
"value": 30931277
}
},
"5a7b1a406d3b4304bd09275f2f1cea0d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_d0e8a5f96fe64169b8f57cafec5a3712",
"placeholder": "​",
"style": "IPY_MODEL_6f72f0a7f8354fb093a7cf5b0b26bc40",
"value": " 30.9M/30.9M [00:00&lt;00:00, 42.3MB/s]"
}
},
"81f27436ace945cebc703f20c907af7c": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"194358c3c5624917809abc4c96a88258": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"586da1075f2e466fb7e2c0889797525a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"f2fc2b4f120d4e6d91c20d86e9a7200e": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"7afa5f2062754840b1768aa5c16c6b1b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"d0e8a5f96fe64169b8f57cafec5a3712": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"6f72f0a7f8354fb093a7cf5b0b26bc40": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"13335f706b33407aad4fc17dc775c442": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_d2074f1151724105bea216cbcbdf2b1a",
"IPY_MODEL_475a7b917d5f4498a8994f70a698778a",
"IPY_MODEL_1b1285ff9f1b4e47be202a49ef298b7f"
],
"layout": "IPY_MODEL_7dd69b0868d84ceca97e7cd9ced274ef"
}
},
"d2074f1151724105bea216cbcbdf2b1a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_1c7901dd4dad459392c76ddef53826fb",
"placeholder": "​",
"style": "IPY_MODEL_4a7bcaecfa214392996cdc625d035407",
"value": "Downloading data: 100%"
}
},
"475a7b917d5f4498a8994f70a698778a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_f0dd4a8852544be1a5419783d933d3c2",
"max": 5175617,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_0866e4ffae86463b9ed5119625bde4f5",
"value": 5175617
}
},
"1b1285ff9f1b4e47be202a49ef298b7f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_4aa286bebafa4899b74863e320d3cf64",
"placeholder": "​",
"style": "IPY_MODEL_f00a7f09029944169ecd25447f857f04",
"value": " 5.18M/5.18M [00:00&lt;00:00, 36.4MB/s]"
}
},
"7dd69b0868d84ceca97e7cd9ced274ef": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"1c7901dd4dad459392c76ddef53826fb": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"4a7bcaecfa214392996cdc625d035407": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"f0dd4a8852544be1a5419783d933d3c2": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"0866e4ffae86463b9ed5119625bde4f5": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"4aa286bebafa4899b74863e320d3cf64": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"f00a7f09029944169ecd25447f857f04": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"da168da36c3d4c23893fa8075cf11147": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_9b3e2e698ff74613a18eea9bb5ddd86b",
"IPY_MODEL_871aafcc0b5440f8a7ca2a9b547e317c",
"IPY_MODEL_ee51dc1bece8423e8d9cbe23dafe6b3c"
],
"layout": "IPY_MODEL_67386296cc604ddabacaa6e3c98813b9"
}
},
"9b3e2e698ff74613a18eea9bb5ddd86b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_3394bbd34bfb4026873b36b0f7681290",
"placeholder": "​",
"style": "IPY_MODEL_73fa173613294ec4ad18a8f4b2f3f82e",
"value": "Generating train split: 100%"
}
},
"871aafcc0b5440f8a7ca2a9b547e317c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_f6b863493fb04233a83ca107ece30b55",
"max": 60000,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_a2dd7b3497a143b582ff876a45c42e0c",
"value": 60000
}
},
"ee51dc1bece8423e8d9cbe23dafe6b3c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_e3b70cabe29743658054c86fc830003b",
"placeholder": "​",
"style": "IPY_MODEL_709d4a0bad844d7d9e512d08dffc954e",
"value": " 60000/60000 [00:00&lt;00:00, 161425.71 examples/s]"
}
},
"67386296cc604ddabacaa6e3c98813b9": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"3394bbd34bfb4026873b36b0f7681290": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"73fa173613294ec4ad18a8f4b2f3f82e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"f6b863493fb04233a83ca107ece30b55": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"a2dd7b3497a143b582ff876a45c42e0c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"e3b70cabe29743658054c86fc830003b": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"709d4a0bad844d7d9e512d08dffc954e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"c6674878ba944e77b376a6f69f09e469": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_ec4ae507bb4c4c8288cbf3c165204bb8",
"IPY_MODEL_2c47342b29fc493db35eaed7d2c41738",
"IPY_MODEL_9a804266219c4171845f56faac9b0ed6"
],
"layout": "IPY_MODEL_caf7a1f51a2745ac8b71588c6a7b027f"
}
},
"ec4ae507bb4c4c8288cbf3c165204bb8": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_b03b06120b5141fda54a3536bcc495da",
"placeholder": "​",
"style": "IPY_MODEL_baa5227cc3204deda66619ad1e96f466",
"value": "Generating test split: 100%"
}
},
"2c47342b29fc493db35eaed7d2c41738": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_78a82c4f84414ad48c1f94fae6cb94e5",
"max": 10000,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_701a70a97018426d998cb607d8ac0a17",
"value": 10000
}
},
"9a804266219c4171845f56faac9b0ed6": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_40f38886d99947eb9f3d02e0335cf82a",
"placeholder": "​",
"style": "IPY_MODEL_8f2663f99d0c4c1fa5898830f9e2c13a",
"value": " 10000/10000 [00:00&lt;00:00, 105401.47 examples/s]"
}
},
"caf7a1f51a2745ac8b71588c6a7b027f": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"b03b06120b5141fda54a3536bcc495da": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"baa5227cc3204deda66619ad1e96f466": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"78a82c4f84414ad48c1f94fae6cb94e5": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"701a70a97018426d998cb607d8ac0a17": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"40f38886d99947eb9f3d02e0335cf82a": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"8f2663f99d0c4c1fa5898830f9e2c13a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
}
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment