Skip to content

Instantly share code, notes, and snippets.

@NISH1001
Created November 24, 2018 14:45
Show Gist options
  • Save NISH1001/8c3f5de2e7757a983cc1bf0ce0c57366 to your computer and use it in GitHub Desktop.
Save NISH1001/8c3f5de2e7757a983cc1bf0ce0c57366 to your computer and use it in GitHub Desktop.
Create train-test directories
#!/usr/bin/env python3
import glob
from collections import defaultdict
import shutil
import os
def create_dir(dirname):
    """Create directory *dirname* (and any missing parents) if it does not exist.

    Args:
        dirname: Path of the directory to create.

    Raises:
        FileExistsError: if *dirname* exists but is not a directory.
        OSError: for other filesystem failures (e.g. permissions).
    """
    # makedirs(exist_ok=True) avoids the check-then-create race of
    # isdir()+mkdir() and also creates missing parent directories,
    # which plain os.mkdir() cannot do.
    os.makedirs(dirname, exist_ok=True)
def get_pathmap(path, width=256, height=256):
    """Group the files under *path* by their immediate sub-directory.

    Each sub-directory of *path* is treated as a class label, and every
    regular file directly inside it as one sample of that label.

    Args:
        path: Root directory containing one sub-directory per label.
        width: Unused; kept for backward compatibility with existing callers.
        height: Unused; kept for backward compatibility with existing callers.

    Returns:
        A ``defaultdict(list)`` mapping label name -> sorted list of absolute
        file paths. Labels whose directory contains no files are omitted.
    """
    root = os.path.abspath(path)
    pathmap = defaultdict(list)
    for dr in sorted(glob.glob(os.path.join(root, '*'))):
        # Stray files at the root level are not labels.
        if not os.path.isdir(dr):
            continue
        label = os.path.basename(dr)
        print(label)
        print("Loading images from :: {}".format(dr))
        # Sort so the downstream train/test split is reproducible (glob order
        # is filesystem-dependent); skip nested directories, which
        # shutil.copy cannot copy.
        for imgpath in sorted(glob.glob(os.path.join(dr, '*'))):
            if os.path.isfile(imgpath):
                pathmap[label].append(imgpath)
    return pathmap
def create_train_test(pathmap, train_dir, test_dir, split_ratio=0.8, random=False):
    """Copy each label's files into per-label train and test directories.

    For every label in *pathmap*, the first ``int(n * split_ratio)`` paths
    are copied to ``train_dir/<label>`` and the remainder to
    ``test_dir/<label>``. Destination directories are created as needed.

    Args:
        pathmap: Mapping of label -> list of source file paths.
        train_dir: Destination root for the training split.
        test_dir: Destination root for the test split.
        split_ratio: Fraction in [0, 1] of each label's files that go to train.
        random: If true, shuffle each label's paths (on a copy, so *pathmap*
            is not mutated) before splitting; otherwise the given order is used.

    Raises:
        ValueError: if *split_ratio* is outside [0, 1].
    """
    # The ``random`` parameter shadows the module name, so import under an alias.
    import random as _random

    if not 0.0 <= split_ratio <= 1.0:
        raise ValueError("split_ratio must be in [0, 1], got {!r}".format(split_ratio))
    for label in pathmap:
        print(label)
        paths = list(pathmap[label])
        if random:
            _random.shuffle(paths)
        curr_dir_train = os.path.join(train_dir, label)
        curr_dir_test = os.path.join(test_dir, label)
        # makedirs also creates train_dir/test_dir themselves if they are missing.
        os.makedirs(curr_dir_train, exist_ok=True)
        os.makedirs(curr_dir_test, exist_ok=True)
        j = int(len(paths) * split_ratio)
        print("Creating train...")
        for path in paths[:j]:
            print(path)
            shutil.copy(path, curr_dir_train)
        print("Creating test...")
        for path in paths[j:]:
            print(path)
            shutil.copy(path, curr_dir_test)
def main(argv=None):
    """Build train/test copies of a labelled image dataset.

    Args:
        argv: Command-line arguments excluding the program name; ``None``
            means ``sys.argv[1:]``. With no arguments, the original
            hard-coded source and destination paths are used.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Create train-test directories")
    parser.add_argument("source", nargs="?",
                        default="/home/paradox/data/keydataset/custom/train/",
                        help="root directory with one sub-directory per label")
    parser.add_argument("train_dir", nargs="?",
                        default='/home/paradox/data/keys/train',
                        help="destination root for the training split")
    parser.add_argument("test_dir", nargs="?",
                        default='/home/paradox/data/keys/test',
                        help="destination root for the test split")
    parser.add_argument("--split-ratio", type=float, default=0.8,
                        help="fraction of each label's files copied to train")
    parser.add_argument("--random", action="store_true",
                        help="shuffle each label's files before splitting")
    args = parser.parse_args(argv)

    pathmap = get_pathmap(args.source)
    create_train_test(pathmap, args.train_dir, args.test_dir,
                      split_ratio=args.split_ratio, random=args.random)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment