Skip to content

Instantly share code, notes, and snippets.

@PatWie
Created June 19, 2018 19:01
Show Gist options
Save PatWie/e4f29f2820b3aa60a8b6cf0a561758d6 to your computer and use it in GitHub Desktop.
Database benchmark for tensorpack (LMDB, TFRecord, NumPy, HDF5)
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Patrick Wieschollek <[email protected]>
from tensorpack import *
from tensorpack.dataflow.base import DataFlow
from tensorpack.dataflow.dftools import LMDBDataWriter, TFRecordDataWriter, NumpyDataWriter, HDF5DataWriter
from tensorpack.dataflow.format import LMDBDataReader, TFRecordDataReader, NumpyDataReader, HDF5DataReader
import os
import numpy as np
import time
def delete_file_if_exists(fn):
    """Remove the file ``fn``; do nothing if it does not exist.

    Only a missing file is ignored. Any other ``OSError`` is re-raised, for
    example permission denied or ``fn`` being a directory. This keeps a stale
    file that could not actually be removed from being silently left behind.

    Args:
        fn: path of the file to remove.

    Raises:
        OSError: if removal fails for any reason other than "no such file".
    """
    import errno  # local import keeps this helper self-contained

    try:
        os.remove(fn)
    except OSError as e:
        # ENOENT ("no such file or directory") is the only case "if exists"
        # is meant to cover; everything else is a real failure.
        if e.errno != errno.ENOENT:
            raise
class SeededFakeDataFlow(DataFlow):
    """Deterministic in-memory DataFlow of random ``[label, img]`` datapoints.

    ``reset_state()`` (re)generates ``size`` datapoints from a NumPy
    ``RandomState`` seeded with ``seed``. In each datapoint, ``label`` is an
    int in ``[0, 10)`` and ``img`` is an array of shape ``shape`` drawn from
    a standard normal distribution. ``get_data()`` yields the generated
    datapoints in order, so the same seed always produces the same stream.

    Args:
        seed: seed for the random generator.
        size: number of datapoints produced.
        shape: shape of each image array (default ``(256, 256, 3)``).
    """

    def __init__(self, seed=42, size=32, shape=(256, 256, 3)):
        super(SeededFakeDataFlow, self).__init__()
        self.seed = seed
        self._size = size
        self.shape = tuple(shape)
        self.cache = []

    def reset_state(self):
        """Regenerate the cached datapoints from ``self.seed``."""
        # Rebuild instead of appending. reset_state() may be called more than
        # once on the same instance, and appending would make get_data()
        # yield more datapoints than size() reports.
        # A private RandomState yields the same numbers as np.random.seed(seed)
        # followed by np.random.* calls, without reseeding the global NumPy RNG.
        rng = np.random.RandomState(self.seed)
        self.cache = []
        for _ in range(self._size):
            label = rng.randint(low=0, high=10)
            img = rng.randn(*self.shape)
            self.cache.append([label, img])

    def size(self):
        """Return the number of datapoints, as given at construction."""
        return self._size

    def get_data(self):
        """Yield the cached ``[label, img]`` datapoints in generation order."""
        for dp in self.cache:
            yield dp
def write_databases(size=1000):
    """Serialize ``size`` fake datapoints into LMDB, TFRecord, NPZ and HDF5 files.

    Existing ``tmp.*`` outputs are removed first. Each writer gets its own
    freshly constructed SeededFakeDataFlow, so every file is built from the
    same seeded data. This holds even if a writer resets or re-iterates its
    input dataflow.

    Args:
        size: number of datapoints written to each file.
    """
    for fn in ('tmp.lmdb', 'tmp.lmdb-lock', 'tmp.tfrecord', 'tmp.npz', 'tmp.h5'):
        delete_file_if_exists(fn)
    LMDBDataWriter(SeededFakeDataFlow(size=size), 'tmp.lmdb').serialize()
    TFRecordDataWriter(SeededFakeDataFlow(size=size), 'tmp.tfrecord').serialize()
    NumpyDataWriter(SeededFakeDataFlow(size=size), 'tmp.npz').serialize()
    HDF5DataWriter(SeededFakeDataFlow(size=size), 'tmp.h5',
                   ['label', 'images']).serialize()
    # File sizes recorded with the original script (size=1000, a single
    # dataflow instance shared by all four writers):
    #   1,5G tmp.h5
    #   1,5G tmp.lmdb
    #   8,0K tmp.lmdb-lock
    #   4,3G tmp.npz
    #   3,0G tmp.tfrecord
    # NOTE(review): TFRecord/NPZ being ~2x/~3x the LMDB size suggests the
    # shared dataflow grew between writers; re-measure with fresh dataflows.


def benchmark_readers(size=1000):
    """Measure read throughput of the LMDB and TFRecord files with TestDataSpeed.

    Args:
        size: number of datapoints expected in each file; it also bounds the
            number of steps TestDataSpeed runs.
    """
    # TestDataSpeed otherwise uses its own default step count (5000 according
    # to the recorded output below), which exceeds the dataset size and makes
    # the progress bar misleading.
    ds = LMDBDataReader('tmp.lmdb', shuffle=False)
    TestDataSpeed(ds, size=size).start()
    print('.............')
    ds = TFRecordDataReader('tmp.tfrecord', size)
    TestDataSpeed(ds, size=size).start()
    # Output recorded with the original script (default TestDataSpeed size):
    #   20%|##############2 |1000/5000[00:00<00:00,4488.33it/s]
    #   .............
    #   40%|############################4 |2000/5000[00:01<00:01,1672.89it/s]


def main(argv=None):
    """Optionally (re)write the databases, then benchmark reading them.

    Args:
        argv: command-line arguments (defaults to ``sys.argv[1:]``).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='Benchmark tensorpack database readers.')
    parser.add_argument('--write', action='store_true',
                        help='(re)create the tmp.* databases before reading')
    parser.add_argument('--size', type=int, default=1000,
                        help='number of datapoints written / read')
    args = parser.parse_args(argv)
    if args.write:
        write_databases(args.size)
    benchmark_readers(args.size)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment