Last active
May 4, 2021 12:35
-
-
Save ariG23498/af76b2b0b2c59cb6eb9aaf90fa75793d to your computer and use it in GitHub Desktop.
ConcatDataset
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "scratchpad", | |
"provenance": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"name": "python3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/ariG23498/af76b2b0b2c59cb6eb9aaf90fa75793d/scratchpad.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "lIYdn1woOS1n" | |
}, | |
"source": [ | |
"%%bash\n", | |
"mkdir A B C\n", | |
"mkdir A/sub_1 A/sub_2 A/sub_3\n", | |
"mkdir B/sub_1 B/sub_2 B/sub_3\n", | |
"mkdir C/sub_1 C/sub_2 C/sub_3" | |
], | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "hMENWXMNOoIs" | |
}, | |
"source": [ | |
"import os\n", | |
"\n", | |
"import numpy as np\n", | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"import torch\n", | |
"import torchvision" | |
], | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "PhGdo-IUOlTZ" | |
}, | |
"source": [ | |
"for folder in [\"A\", \"B\", \"C\"]:\n", | |
" for sub_folder in os.listdir(folder):\n", | |
" for i in range(2):\n", | |
" img = np.random.random((20,20))\n", | |
" plt.imsave(arr=img, fname=f\"{folder}/{sub_folder}/img_{i}.png\")" | |
], | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "nEKVFXFhOwYD" | |
}, | |
"source": [ | |
"A_dataset = torchvision.datasets.ImageFolder(root = \"A\" , transform = torchvision.transforms.ToTensor())\n", | |
"B_dataset = torchvision.datasets.ImageFolder(root = \"B\" , transform = torchvision.transforms.ToTensor())\n", | |
"C_dataset = torchvision.datasets.ImageFolder(root = \"C\" , transform = torchvision.transforms.ToTensor())\n", | |
"\n", | |
"all_datasets = []\n", | |
"all_datasets.append(A_dataset)\n", | |
"all_datasets.append(B_dataset)\n", | |
"all_datasets.append(C_dataset)\n", | |
"\n", | |
"final_training_dataset = torch.utils.data.ConcatDataset(all_datasets)" | |
], | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "iuoV2AAKVQSC" | |
}, | |
"source": [ | |
"for ind, c in enumerate(A_dataset.classes):\n", | |
" A_dataset.classes[ind] = f\"A_{c}\"\n", | |
"\n", | |
"for ind, c in enumerate(B_dataset.classes):\n", | |
" B_dataset.classes[ind] = f\"B_{c}\"\n", | |
"\n", | |
"for ind, c in enumerate(C_dataset.classes):\n", | |
" C_dataset.classes[ind] = f\"C_{c}\"" | |
], | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "eOFGZN60VzvT", | |
"outputId": "6b8cc188-96b9-4823-9a16-c9bf63babe31" | |
}, | |
"source": [ | |
"A_dataset.classes, B_dataset.classes, C_dataset.classes, " | |
], | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(['A_sub_1', 'A_sub_2', 'A_sub_3'],\n", | |
" ['B_sub_1', 'B_sub_2', 'B_sub_3'],\n", | |
" ['C_sub_1', 'C_sub_2', 'C_sub_3'])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 6 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "3OaidV76T_3L" | |
}, | |
"source": [ | |
"full_dl = torch.utils.data.DataLoader(final_training_dataset, batch_size = 1, shuffle = False) " | |
], | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "r9pI7wKmO-5s", | |
"outputId": "29084755-8320-4456-e220-90556dad8353" | |
}, | |
"source": [ | |
"for idx, element in enumerate(full_dl):\n", | |
" img, l = element\n", | |
" if len(A_dataset) - idx >=0:\n", | |
" print(A_dataset.classes[l])\n", | |
" elif len(A_dataset)+len(B_dataset) - idx >=0:\n", | |
" print(B_dataset.classes[l])\n", | |
" else:\n", | |
" print(C_dataset.classes[l])" | |
], | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"A_sub_1\n", | |
"A_sub_1\n", | |
"A_sub_2\n", | |
"A_sub_2\n", | |
"A_sub_3\n", | |
"A_sub_3\n", | |
"A_sub_1\n", | |
"B_sub_1\n", | |
"B_sub_2\n", | |
"B_sub_2\n", | |
"B_sub_3\n", | |
"B_sub_3\n", | |
"B_sub_1\n", | |
"C_sub_1\n", | |
"C_sub_2\n", | |
"C_sub_2\n", | |
"C_sub_3\n", | |
"C_sub_3\n" | |
], | |
"name": "stdout" | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment