Skip to content

Instantly share code, notes, and snippets.

@ankitshekhawat
Created September 18, 2019 09:25
Show Gist options
  • Save ankitshekhawat/69c6759b6ac3347c5be5b2360e9efa6d to your computer and use it in GitHub Desktop.
Save ankitshekhawat/69c6759b6ac3347c5be5b2360e9efa6d to your computer and use it in GitHub Desktop.
train val split
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"from glob import glob\n",
"import os\n",
"import shutil\n",
"import random\n",
"random.seed(42)\n",
"from tqdm.auto import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"dataset_dir = './combined/'\n",
"ext = '*'\n",
"minimum = 15000\n",
"split = 0.8\n",
"destination_dir = './bottom_type_dataset/'\n",
"train_dir = 'train'\n",
"val_dir = 'val'\n",
"delete_old = True\n",
"\n",
"if delete_old:\n",
" shutil.rmtree(os.path.join(destination_dir, train_dir), ignore_errors=True)\n",
" shutil.rmtree(os.path.join(destination_dir, val_dir), ignore_errors=True)\n",
" \n",
"os.makedirs(os.path.join(destination_dir, train_dir), exist_ok=True)\n",
"os.makedirs(os.path.join(destination_dir, val_dir), exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"classes = [os.path.dirname(d).split('/')[-1] for d in sorted(glob(dataset_dir+ '*/'))]\n",
"if minimum == -1:\n",
" minimum = min([len(glob(os.path.join(dataset_dir, c)+ \"/*.\"+ ext)) for c in classes])\n"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f36bf2b1c6824d919ce97d6b11ad7f11",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Pants', max=15000, style=ProgressStyle(description_width='ini…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "57f9461942b14e10805435d6cb9a9688",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Shorts', max=13264, style=ProgressStyle(description_width='in…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bd83bfcebd8a477d8d43c7ccccec701c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Skirt', max=15000, style=ProgressStyle(description_width='ini…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "dff6b7f77c824af887b55e85a49578af",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Swimwear', max=8046, style=ProgressStyle(description_width='i…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"\n",
"for cls in classes:\n",
" files = glob(os.path.join(dataset_dir, cls)+ \"/*.\"+ ext)\n",
" random.shuffle(files)\n",
" end_point = minimum if len(files)> minimum else len(files)\n",
" split_point = int(end_point * split)\n",
" \n",
" train_files = files[:split_point]\n",
" val_files = files[split_point:end_point]\n",
" os.makedirs(os.path.join(destination_dir, train_dir, cls), exist_ok=True)\n",
" with tqdm(total=end_point, desc=cls) as pbar:\n",
" for train_file in train_files:\n",
" shutil.copyfile(train_file, os.path.join(destination_dir, train_dir, cls, os.path.basename(train_file)))\n",
" pbar.update()\n",
"\n",
" os.makedirs(os.path.join(destination_dir, val_dir, cls), exist_ok=True) \n",
" for val_file in val_files:\n",
" shutil.copyfile(val_file, os.path.join(destination_dir,val_dir, cls, os.path.basename(val_file)))\n",
" pbar.update()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "python 3 tf2",
"language": "python",
"name": "venv"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment