Last active
July 3, 2017 11:17
-
-
Save PatWie/a838b55436d146ad6e409e09b3e5c338 to your computer and use it in GitHub Desktop.
decode.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python | |
| # -*- coding: UTF-8 -*- | |
| import cv2 | |
| import argparse | |
| import numpy as np | |
| from tensorpack import * | |
| """ | |
| python lmdb_speed.py --create --lmdb /scratch_shared/test.lmdb --size 10000 | |
| python lmdb_speed.py --benchmark --lmdb /scratch_shared/test.lmdb --proc 2 --size 10000 | |
| """ | |
class FakeData(RNGDataFlow):
    """Generate fake datapoints whose components have the given shapes.

    Every datapoint is a list with one array per entry of ``shapes``; the
    values are drawn with ``self.rng.rand`` scaled by 255 and cast to
    ``dtype``.

    Args:
        shapes: a list of shapes (each a list/tuple of ints), one per
            component of a datapoint.
        size: number of datapoints yielded by one pass of ``get_data``.
        random: if True (default), draw fresh values for every datapoint.
            If False, draw a single datapoint once per pass and yield
            copies of it.
        dtype: dtype the generated values are cast to.
    """

    def __init__(self, shapes, size=1000, random=True, dtype='float32'):
        super(FakeData, self).__init__()
        self.shapes = shapes
        self._size = int(size)
        # Previously this flag was accepted but silently ignored.
        self.random = random
        self.dtype = dtype

    def size(self):
        """Return the number of datapoints produced per pass."""
        return self._size

    def _make_datapoint(self):
        """Draw one datapoint: a list with one array per configured shape."""
        # NOTE(review): `self.rng` is not set in this class; presumably
        # RNGDataFlow provides it once reset_state() has run -- confirm.
        return [(self.rng.rand(*shape) * 255.).astype(self.dtype)
                for shape in self.shapes]

    def get_data(self):
        """Yield ``size`` datapoints, honouring the ``random`` flag."""
        if self.random:
            for _ in range(self._size):
                yield self._make_datapoint()
        else:
            template = self._make_datapoint()
            for _ in range(self._size):
                # Copy so a consumer mutating one datapoint cannot affect
                # the ones yielded afterwards.
                yield [component.copy() for component in template]
class ImageEncode(MapDataComponent):
    """Replace one datapoint component by its ``cv2.imencode`` byte buffer.

    Args:
        ds: the input dataflow.
        mode: file extension handed to ``cv2.imencode`` (e.g. ``'.jpg'``).
        dtype: dtype of the 1-D array holding the encoded bytes.
        index: index of the datapoint component to encode.
    """

    def __init__(self, ds, mode='.jpg', dtype=np.uint8, index=0):
        def func(img):
            # cv2.imencode returns (success_flag, buffer); only the buffer
            # is used here, exactly as before.
            encoded = cv2.imencode(mode, img)[1]
            # ndarray.tostring() is a deprecated alias that newer NumPy
            # releases drop; tobytes() returns the same raw bytes.
            return np.asarray(bytearray(encoded.tobytes()), dtype=dtype)
        super(ImageEncode, self).__init__(ds, func, index=index)
class ImageDecode(MapDataComponent):
    """Replace one datapoint component by the image ``cv2.imdecode`` yields.

    Args:
        ds: the input dataflow.
        mode: accepted for symmetry with ``ImageEncode``; not used here.
        dtype: dtype of the intermediate byte array given to ``cv2.imdecode``.
        index: index of the datapoint component to decode.
    """

    def __init__(self, ds, mode='.jpg', dtype=np.uint8, index=0):
        def _decode(raw):
            # Reinterpret the stored component as a flat byte buffer first.
            buf = np.asarray(bytearray(raw), dtype=dtype)
            # IMREAD_COLOR requests a colour image from the decoder.
            return cv2.imdecode(buf, cv2.IMREAD_COLOR)

        super(ImageDecode, self).__init__(ds, _decode, index=index)
if __name__ == '__main__':
    # Command-line interface: --create writes an LMDB of encoded fake
    # images, --benchmark reads it back through the decoder and times it.
    cli = argparse.ArgumentParser()
    for flag, text in (('--create', 'create lmdb'),
                       ('--benchmark', 'start benchmark')):
        cli.add_argument(flag, action='store_true', help=text)
    cli.add_argument('--lmdb', type=str, help='path to lmdb', required=True)
    for flag, text, fallback in (('--size', 'number of images', 1000),
                                 ('--proc', 'number of processes', 1)):
        cli.add_argument(flag, type=int, help=text, default=fallback)
    args = cli.parse_args()

    if args.create:
        # Each datapoint: a 224x224x3 uint8 "image" plus a 64-element
        # vector; only component 0 (the image) is encoded before dumping.
        fake = FakeData([[224, 224, 3], [64]], size=args.size,
                        random=True, dtype='uint8')
        encoded = ImageEncode(fake, index=0)
        dftools.dump_dataflow_to_lmdb(encoded, args.lmdb)

    if args.benchmark:
        reader = LMDBDataPoint(args.lmdb, shuffle=False)
        reader = ImageDecode(reader, index=0)
        # Only wrap in the ZMQ prefetcher when more than one process is
        # requested.
        if args.proc > 1:
            reader = PrefetchDataZMQ(reader, args.proc)
        TestDataSpeed(reader, args.size).start_test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment