Created
July 16, 2017 12:18
-
-
Save alantian/e374122d4dbbd7e35dca8cdb70657099 to your computer and use it in GitHub Desktop.
Convert dataset to lmdb in parallel
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import argparse | |
import os | |
import sys | |
from os.path import basename | |
import cv2 | |
import lmdb | |
import base64 | |
import glob | |
import time | |
import numpy as np | |
import io | |
import logging | |
from types import SimpleNamespace | |
from joblib import Parallel, delayed | |
def iter_filename(root_dir):
    """Lazily yield every path located exactly three levels below *root_dir*.

    The pattern is ``<root_dir>/*/*/*``, so entries directly under the root
    or one or two levels deep are not produced.  Whatever glob matches at
    the third level is yielded as-is; no file-versus-directory filtering is
    done here.

    Args:
        root_dir: Directory to scan.  It is treated as a literal path.

    Yields:
        Matching paths as strings, in the order glob produces them.
    """
    # Escape the root so that glob metacharacters in the directory name
    # itself ('[', ']', '*', '?') are matched literally; only the three
    # trailing '*' components act as wildcards.
    pattern = '%s/*/*/*' % glob.escape(root_dir)
    yield from glob.iglob(pattern)
def convert_single_file(filepath, base=512):
    """Load one image, resize its shorter side to *base* pixels, re-encode as JPEG.

    Any ordinary failure (unreadable file, undecodable image, non-ASCII
    name, ...) is captured and reported through the returned object instead
    of being raised, so one bad file does not abort the caller's batch.

    Args:
        filepath: Path of the image file to convert.
        base: Target length in pixels of the shorter image side; the longer
            side is scaled by the same ratio (integer division).

    Returns:
        A ``SimpleNamespace`` with the fields:
            byte_encimg: JPEG bytes of the resized image, or ``None`` on failure.
            byte_filename: ASCII-encoded key, or ``None`` on failure.
            success: ``True`` when conversion succeeded.
            filename: Base name of *filepath* up to its first ``'.'``.
            explain: ``''`` on success, otherwise ``str()`` of the
                exception's ``args``.
    """
    # The key is everything before the FIRST dot of the base name.
    filename = basename(filepath).split('.')[0]
    try:
        if not filename:
            # e.g. a dotfile such as ".hidden"; LMDB rejects zero-length keys,
            # so report it here rather than failing at write time.
            raise ValueError('empty key derived from file name')

        # Context manager so the descriptor is released immediately.
        with open(filepath, 'rb') as stream:
            data = stream.read()

        # IMREAD_COLOR always produces a 3-channel BGR image, so no separate
        # grayscale / alpha conversion is required afterwards.
        img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            raise ValueError('cannot decode image data')

        # ndarray shape is (rows, cols, channels) == (height, width, channels).
        height, width = img.shape[:2]
        if height > width:
            new_width = base
            new_height = height * base // width
        else:
            new_width = width * base // height
            new_height = base

        # cv2.resize expects dsize as (width, height).
        img = cv2.resize(
            img, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)

        encoded_ok, encimg = cv2.imencode(".jpg", img)
        if not encoded_ok:
            raise ValueError('JPEG encoding failed')

        byte_encimg = encimg.tobytes()
        byte_filename = filename.encode('ascii')
        success = True
        explain = ''
    except Exception as e:
        # Deliberately broad: this is the per-file boundary, and the caller
        # logs `explain` for every file that could not be converted.
        byte_encimg = None
        byte_filename = None
        success = False
        explain = str(e.args)
    return SimpleNamespace(
        byte_encimg=byte_encimg,
        byte_filename=byte_filename,
        success=success,
        filename=filename,
        explain=explain)
def main():
    """Command-line entry point: pack a tree of images into an LMDB database.

    Paths produced by ``iter_filename(--root_dir)`` are processed in batches
    of ``batch_size``.  Each batch is converted in parallel with joblib via
    ``convert_single_file`` and the successful results are stored in the
    LMDB environment at ``--lmdb`` (key: encoded file name, value: JPEG
    bytes).  Failures and progress are written to ``--logging_file``.
    """
    batch_size = 4096

    parser = argparse.ArgumentParser(description='Convert lmdb')
    parser.add_argument(
        '--root_dir', '-r', default='', help='Path to input root dir.')
    parser.add_argument(
        '--lmdb', '-l', default='', help='Path to output lmdb file.')
    parser.add_argument(
        '--logging_file',
        '-g',
        default='convert_lmdb.log',
        help='Path to logging file.')
    args = parser.parse_args()
    print(args)

    logger = logging.getLogger('convert_lmdb')
    fh = logging.FileHandler(args.logging_file, 'a+')
    logger.setLevel(logging.DEBUG)
    logger.addHandler(fh)

    # map_size is the upper bound on the database size (here 1e12 bytes).
    lmdb_env = lmdb.open(args.lmdb, map_size=int(1e12))
    count = {'total': 0, 'success': 0}

    def deal_batch(batch):
        """Convert one batch in parallel and store the successful results."""
        res = Parallel(n_jobs=-1)(
            delayed(convert_single_file)(filepath) for filepath in batch)
        # One write transaction per batch rather than one per record; the
        # context manager commits on normal exit and aborts on an exception.
        with lmdb_env.begin(write=True) as lmdb_txn:
            for converted in res:
                if converted.success:
                    lmdb_txn.put(
                        converted.byte_filename, converted.byte_encimg)
                    count['success'] += 1
                else:
                    logger.error('"%s" failed with reason %s',
                                 converted.filename, converted.explain)
                count['total'] += 1
        logger.info('Progress: %d files tried. %d successed.',
                    count['total'], count['success'])

    try:
        logger.info('start')
        batch = []
        for filepath in iter_filename(args.root_dir):
            batch.append(filepath)
            if len(batch) == batch_size:
                deal_batch(batch)
                batch = []
        # Flush the final, possibly short, batch.
        if batch:
            deal_batch(batch)
        logger.info('finish')
    finally:
        # Release the environment and the log file even if a batch failed.
        lmdb_env.close()
        logger.removeHandler(fh)
        fh.close()
    return
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment