Skip to content

Instantly share code, notes, and snippets.

@PatWie
Last active July 3, 2017 11:17
Show Gist options
  • Select an option

  • Save PatWie/a838b55436d146ad6e409e09b3e5c338 to your computer and use it in GitHub Desktop.

Select an option

Save PatWie/a838b55436d146ad6e409e09b3e5c338 to your computer and use it in GitHub Desktop.
decode.py
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
import cv2
import argparse
import numpy as np
from tensorpack import *
"""
python lmdb_speed.py --create --lmdb /scratch_shared/test.lmdb --size 10000
python lmdb_speed.py --benchmark --lmdb /scratch_shared/test.lmdb --proc 2 --size 10000
"""
class FakeData(RNGDataFlow):
    """Generate fake data of the given shapes.

    Each datapoint is a list with one array per entry in ``shapes``. Every
    array holds ``self.rng.rand(*shape) * 255.`` cast to ``dtype``.

    Args:
        shapes: list of shapes, one per component of a datapoint.
        size: number of datapoints yielded per call to ``get_data``.
        random: if True, draw fresh values for every datapoint. If False,
            draw a single datapoint once and yield copies of it ``size`` times.
        dtype: dtype the generated arrays are cast to.
    """

    def __init__(self, shapes, size=1000, random=True, dtype='float32'):
        super(FakeData, self).__init__()
        self.shapes = shapes
        self._size = int(size)
        # Previously accepted but ignored; now selects fresh vs. repeated data.
        self.random = random
        self.dtype = dtype

    def size(self):
        return self._size

    def _make_datapoint(self):
        # self.rng comes from RNGDataFlow; presumably a numpy RandomState
        # created in reset_state() — TODO confirm against tensorpack.
        return [(self.rng.rand(*shape) * 255.).astype(self.dtype)
                for shape in self.shapes]

    def get_data(self):
        if self._size <= 0:
            return
        if self.random:
            for _ in range(self._size):
                yield self._make_datapoint()
        else:
            template = self._make_datapoint()
            for _ in range(self._size):
                # Yield copies so in-place edits downstream cannot leak
                # into later datapoints.
                yield [arr.copy() for arr in template]
class ImageEncode(MapDataComponent):
    """Encode one image component of each datapoint into a compressed buffer.

    Args:
        ds: the upstream dataflow.
        mode: file extension passed to ``cv2.imencode`` to pick the codec
            (e.g. '.jpg', '.png').
        dtype: dtype of the returned 1-D buffer array.
        index: position of the image within each datapoint.

    Raises:
        ValueError: during iteration, if OpenCV reports an encoding failure.
    """

    def __init__(self, ds, mode='.jpg', dtype=np.uint8, index=0):
        def func(img):
            ok, buf = cv2.imencode(mode, img)
            if not ok:
                raise ValueError(
                    "cv2.imencode failed for mode {!r}".format(mode))
            # imencode already returns a uint8 byte buffer (shape (N, 1));
            # flatten it directly instead of the old
            # bytearray(buf.tostring()) round trip (tostring is deprecated).
            return np.asarray(buf.reshape(-1), dtype=dtype)
        super(ImageEncode, self).__init__(ds, func, index=index)
class ImageDecode(MapDataComponent):
    """Decode one encoded-image component of each datapoint back to pixels.

    Args:
        ds: the upstream dataflow.
        mode: unused; ``cv2.imdecode`` detects the codec from the data. Kept
            so the signature mirrors ImageEncode.
        dtype: dtype used when converting the raw component into the buffer
            passed to ``cv2.imdecode``.
        index: position of the encoded image within each datapoint.
        flags: imread flags forwarded to ``cv2.imdecode``
            (default ``cv2.IMREAD_COLOR``, the previous hard-coded value).

    Raises:
        ValueError: during iteration, if a buffer cannot be decoded.
    """

    def __init__(self, ds, mode='.jpg', dtype=np.uint8, index=0,
                 flags=cv2.IMREAD_COLOR):
        def func(im_data):
            raw = bytearray(im_data)
            img = cv2.imdecode(np.asarray(raw, dtype=dtype), flags)
            if img is None:
                # imdecode signals failure by returning None. Per tensorpack's
                # docs, MapDataComponent discards a datapoint when func returns
                # None, so corrupt records would otherwise vanish silently
                # from the benchmark.
                raise ValueError(
                    "cv2.imdecode failed on a {}-byte buffer".format(len(raw)))
            return img
        super(ImageDecode, self).__init__(ds, func, index=index)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--create', action='store_true', help='create lmdb')
    parser.add_argument('--benchmark', action='store_true', help='start benchmark')
    parser.add_argument('--lmdb', type=str, help='path to lmdb', required=True)
    parser.add_argument('--size', type=int, help='number of images', default=1000)
    parser.add_argument('--proc', type=int, help='number of processes', default=1)
    args = parser.parse_args()

    # Without either action flag the script previously exited silently
    # having done nothing; report it as a usage error instead.
    if not (args.create or args.benchmark):
        parser.error('nothing to do: pass --create and/or --benchmark')
    if args.size < 1:
        parser.error('--size must be a positive integer')
    if args.proc < 1:
        parser.error('--proc must be a positive integer')

    if args.create:
        # Each fake datapoint is [224x224x3 uint8 image, 64-element uint8
        # vector]; only the image (index 0) is JPEG-encoded before dumping.
        ds = FakeData([[224, 224, 3], [64]], size=args.size, random=True, dtype='uint8')
        ds = ImageEncode(ds, index=0)
        dftools.dump_dataflow_to_lmdb(ds, args.lmdb)

    if args.benchmark:
        ds = LMDBDataPoint(args.lmdb, shuffle=False)
        ds = ImageDecode(ds, index=0)
        if args.proc > 1:
            # NOTE(review): each ZMQ worker presumably runs its own copy of
            # the unshuffled LMDB flow, so records repeat across workers.
            # That is acceptable for a throughput measurement; confirm
            # against the PrefetchDataZMQ docs.
            ds = PrefetchDataZMQ(ds, args.proc)
        TestDataSpeed(ds, args.size).start_test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment