Created September 20, 2020 08:14
Save mayankgrwl97/34e6ef1091881cb2045bf9aa7dbbf382 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import glob | |
import multiprocessing | |
import os | |
from functools import partial | |
import cv2 | |
from tqdm import tqdm | |
def png_to_jpg(img_png_path, jpg_dir):
    """Convert one PNG file to a JPEG written into ``jpg_dir``.

    The output keeps the input's base name with a ``.jpg`` extension.
    ``cv2.imread`` is called with its default flag (IMREAD_COLOR), so any
    alpha channel is dropped.

    Args:
        img_png_path: Path to the source PNG.
        jpg_dir: Existing directory that receives the JPEG.

    Returns:
        Path of the written JPEG.

    Raises:
        OSError: If the PNG cannot be decoded or the JPEG cannot be written.
    """
    img = cv2.imread(img_png_path)
    # cv2.imread signals failure by returning None instead of raising; catch it
    # here rather than letting imwrite fail with an unrelated-looking error.
    if img is None:
        raise OSError(f'Could not read image: {img_png_path}')
    img_basename = os.path.splitext(os.path.basename(img_png_path))[0]
    img_jpg_path = os.path.join(jpg_dir, img_basename + '.jpg')
    # cv2.imwrite likewise reports failure only through its boolean result.
    if not cv2.imwrite(img_jpg_path, img):
        raise OSError(f'Could not write image: {img_jpg_path}')
    return img_jpg_path
def process(img_png_paths, jpg_dir, n_workers):
    """Convert every PNG in ``img_png_paths`` to JPEG using a process pool.

    Args:
        img_png_paths: Iterable of PNG file paths.
        jpg_dir: Existing output directory passed to ``png_to_jpg``.
        n_workers: Number of worker processes.

    Returns:
        List of written JPEG paths, in completion order (not input order,
        because results come from ``imap_unordered``).
    """
    # Materialize once so tqdm can be given a total even for generator input.
    img_png_paths = list(img_png_paths)
    png_to_jpg_fn = partial(png_to_jpg, jpg_dir=jpg_dir)
    img_jpg_paths = []
    with multiprocessing.Pool(n_workers) as pool:
        # imap_unordered returns an iterator without len(), so pass the total
        # explicitly to get a real progress bar instead of a bare counter.
        results = pool.imap_unordered(png_to_jpg_fn, img_png_paths)
        for img_jpg_path in tqdm(results, total=len(img_png_paths)):
            img_jpg_paths.append(img_jpg_path)
    return img_jpg_paths
def get_args():
    """Parse and validate the command-line arguments.

    Creates ``--jpg_dir`` if it does not exist yet.

    Returns:
        argparse.Namespace with ``png_dir``, ``jpg_dir`` and ``n_workers``.

    Raises:
        SystemExit: (via ``parser.error``) if ``--png_dir`` is not an
            existing directory or ``--n_workers`` is below 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--png_dir', type=str, required=True)
    parser.add_argument('--jpg_dir', type=str, required=True)
    parser.add_argument('--n_workers', type=int, default=1)
    args = parser.parse_args()
    # Validate with parser.error rather than assert: asserts are stripped under
    # `python -O`, and parser.error prints the usage plus a readable message.
    # (isdir already implies exists.)
    if not os.path.isdir(args.png_dir):
        parser.error(f'--png_dir {args.png_dir!r} is not an existing directory')
    # multiprocessing.Pool rejects a worker count below 1.
    if args.n_workers < 1:
        parser.error('--n_workers must be at least 1')
    os.makedirs(args.jpg_dir, exist_ok=True)
    return args
if __name__ == '__main__':
    args = get_args()
    # Only lowercase "*.png" names are matched; glob is case-sensitive on
    # case-sensitive filesystems.
    img_png_paths = sorted(glob.glob(os.path.join(args.png_dir, '*.png')))
    # Fail loudly instead of silently starting a pool with nothing to do
    # (e.g. wrong directory, or files named *.PNG).
    if not img_png_paths:
        raise SystemExit(f'No *.png files found in {args.png_dir}')
    process(img_png_paths, args.jpg_dir, args.n_workers)
Here's my version of a similar script using `process_map`, along with some bells and whistles like resume, debug, and profiling capabilities.
import os
import sys
import time
import argparse
from PIL import Image
from tqdm import tqdm
from itertools import repeat
from tqdm.contrib.concurrent import process_map
def convert(src_path, dst_path, index=None, debug=False, ext='png'):
    """Re-encode a single image file with Pillow.

    Both an unreadable source and a failed save are reported with
    ``tqdm.write`` and swallowed, so one bad file does not abort a whole
    ``process_map`` batch.

    Args:
        src_path: Path of the image to read.
        dst_path: Path to write the converted image to.
        index: Label printed in log lines (the __main__ block passes the
            file's 1-based position).
        debug: If True, log load and save timings for this file.
        ext: Pillow format name handed to ``Image.save`` (e.g. 'png', 'jpeg').

    Returns:
        None; the result is the file written at ``dst_path``.
    """
    start = time.perf_counter()
    try:
        im = Image.open(src_path)
    except Exception as e:
        # Report open failures the same way as save failures, instead of
        # raising out of the worker and aborting the whole process_map run.
        tqdm.write(f'Failed \t {index}: \t {src_path} \t {str(e)}')
        return
    # Image.open keeps the source file open for lazy loading; the context
    # manager guarantees the handle is released.
    with im:
        diff = time.perf_counter() - start
        if debug:
            # Image.open is lazy, so this times header parsing only; pixel
            # decoding happens inside im.save() below.
            tqdm.write(f'Loaded \t {index}: \t {diff:04.2f}s \t{src_path}')
        start = time.perf_counter()
        try:
            im.save(dst_path, ext)
        except Exception as e:
            tqdm.write(f'Failed \t {index}: \t {src_path} \t {str(e)}')
            return
        diff = time.perf_counter() - start
        if debug:
            tqdm.write(f'Saved \t {index}: \t {diff:04.2f}s \t{dst_path}')
def _has_valid_output(dst_path, debug):
    """Return True if ``dst_path`` already exists and passes ``Image.verify()``."""
    if not os.path.isfile(dst_path):
        return False
    if debug:
        print(f'Verifying: \t {dst_path}')
    try:
        with Image.open(dst_path) as img:
            img.verify()
    except Exception:
        # Unreadable or corrupt output from an earlier run: convert it again.
        # Catching Exception (not a bare except) lets Ctrl-C still interrupt.
        return False
    return True


def preprocess(src_dir, dst_dir, debug=None, fmt=None):
    """Collect the TIFF files under ``src_dir`` that still need converting.

    Walks ``src_dir`` recursively. For every ``.tif``/``.tiff`` file
    (extension matched case-insensitively), it computes the mirrored output
    path under ``dst_dir`` with the extension replaced by ``fmt``. Files whose
    output already exists and passes ``Image.verify()`` are skipped, which
    lets an interrupted run resume. Missing output directories are created.

    Args:
        src_dir: Root directory to scan.
        dst_dir: Root directory that mirrors ``src_dir``'s layout.
        debug: Verbose logging. ``None`` falls back to the module-level
            ``args.debug`` that the ``__main__`` block parses.
        fmt: Output extension / format name (e.g. 'jpeg'). ``None`` falls
            back to ``args.format``.

    Returns:
        ``(src_paths, dst_paths)``: parallel lists of files to convert.
    """
    # Backward compatibility: the original read these from the global CLI
    # namespace, and the __main__ block still calls with two arguments.
    if debug is None:
        debug = args.debug
    if fmt is None:
        fmt = args.format

    src_paths = []
    dst_paths = []
    for src_root, _dirs, files in os.walk(src_dir, topdown=False):
        if debug:
            print(f'Scanning: \t {src_root}')
        # Mirror the layout via relpath. A plain str.replace(src_dir, dst_dir)
        # rewrites *every* occurrence, e.g. 'img/img_set' -> 'out/out_set'.
        rel_root = os.path.relpath(src_root, src_dir)
        dst_root = dst_dir if rel_root == os.curdir else os.path.join(dst_dir, rel_root)
        for idx, src_file in enumerate(files):
            if not debug:
                # "\x1b[1K\r" clears the line and returns the cursor, so the
                # progress message overwrites itself in place.
                print(f'Scanning: \t {idx + 1}/{len(files)} \t {src_root}', end='\x1b[1K\r')
            name, ext = os.path.splitext(src_file)
            if ext.lower() not in ('.tif', '.tiff'):
                continue
            dst_path = os.path.join(dst_root, name + '.' + fmt)
            if _has_valid_output(dst_path, debug):
                continue
            os.makedirs(dst_root, exist_ok=True)
            src_paths.append(os.path.join(src_root, src_file))
            dst_paths.append(dst_path)
    if debug:
        print(f'Found {len(src_paths)} image files. Processing . . .')
    return src_paths, dst_paths
if __name__ == '__main__':
    # --- Command-line interface ---------------------------------------------
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument('--src-dir', required=True, help='source directory to look for images')
    cli_parser.add_argument('--dst-dir', required=True, help='destination directory to store images')
    cli_parser.add_argument('--format', default='jpeg', choices=['jpeg', 'png'], help='image format to save')
    cli_parser.add_argument('--debug', action='store_true', help='run the script in debug mode')
    # Keep the name `args`: preprocess() reads it as a module-level global.
    args = cli_parser.parse_args()

    # --- Work list -----------------------------------------------------------
    src_paths, dst_paths = preprocess(args.src_dir, args.dst_dir)
    # 1-based labels, used only in convert()'s log lines.
    indices = list(range(1, len(src_paths) + 1))

    if not args.debug:
        # Parallel: tqdm's process_map runs convert() across worker processes
        # behind a single progress bar.
        process_map(convert, src_paths, dst_paths, indices, repeat(args.debug),
                    repeat(args.format), desc='Processing: ', chunksize=1)
    else:
        # Sequential in debug mode, so per-file log lines appear in order.
        for index, (src_path, dst_path) in enumerate(zip(src_paths, dst_paths), start=1):
            convert(src_path, dst_path, index, args.debug, args.format)
To use multiprocessing with a PyTorch (torch) model, replace
`from multiprocessing import Process, Pool`
with:
from torch.multiprocessing import Pool, Process, set_start_method

# Use the 'spawn' start method for worker processes (the linked reference
# recommends it for torch models).
try:
    set_start_method('spawn')
except RuntimeError:
    # set_start_method raises RuntimeError when the start method has already
    # been fixed (e.g. if this module is imported again); keep the existing one.
    pass
Reference: https://stackoverflow.com/questions/48822463/how-to-use-pytorch-multiprocessing
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
You can practically get rid of the `process` function by using `process_map` from `tqdm.contrib.concurrent`. As an added bonus, it also coordinates a nice progress bar for you.