|
import argparse |
|
import glob |
|
import os |
|
import time |
|
import numpy as np |
|
from torch.utils.data import DataLoader |
|
from pfio.cache import MultiprocessFileCache |
|
from pfio.cache import FileCache |
|
|
|
|
|
class CachedDataset:
    """Adapter exposing a pfio cache through the ``__len__`` /
    ``__getitem__`` protocol expected by ``torch.utils.data.DataLoader``.
    """

    def __init__(self, cache):
        # Keep a reference to the backing cache; no data is copied.
        self.cache = cache

    def __len__(self):
        return len(self.cache)

    def __getitem__(self, index):
        # Delegate straight to the cache's get() accessor.
        return self.cache.get(index)
|
|
|
|
|
def _build_cache(N, l, cache_dir):
    """Fill a single-process FileCache with ``N`` random samples of ``l``
    bytes each and preserve it on disk under the 'cache_data' tag so that
    a MultiprocessFileCache can later preload it."""
    cache = FileCache(N, do_pickle=False, dir=cache_dir)
    for j in range(N):
        cache.put(j, np.random.bytes(l))
    cache.preserve('cache_data')


def _measure(ds, N, l, n_worker, n_trials):
    """Time ``n_trials`` full passes over ``ds`` with a DataLoader using
    ``n_worker`` workers; return (mean, stddev) seconds per sample."""
    times = []
    for _ in range(n_trials):
        # NOTE(review): the lambda collate_fn is only picklable with the
        # fork start method — presumably this benchmark targets Linux.
        loader = DataLoader(ds, collate_fn=lambda x: x,
                            batch_size=128,
                            num_workers=n_worker, shuffle=True)
        # perf_counter() is monotonic and higher-resolution than
        # time.time(), which can jump on wall-clock adjustments.
        before = time.perf_counter()
        for samples in loader:
            # Sanity check: every sample in the batch has the expected size.
            assert all(len(s) == l for s in samples)
        after = time.perf_counter()
        times.append((after - before) / N)
    return np.mean(times), np.std(times)


def main():
    """Benchmark MultiprocessFileCache read throughput.

    For each (num-samples, sample-size) setting, builds and preserves a
    file cache, then times full DataLoader passes for several worker
    counts, printing the results as a markdown table.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--cache-dir', default='/tmp')
    args = parser.parse_args()

    cache_dir = args.cache_dir

    num_workers = [16, 32, 64, 128]
    n_trials = 5
    # (number of samples, bytes per sample) settings to benchmark.
    all_N_l = [
        (32768, 1024 ** 2),
        (1024 ** 2, 32768),
    ]

    print('| # samples | sample size | # workers | mean time per sample (us) | stddev (us) |')
    print('|:---|:---|:---|:---|:---|')
    for N, l in all_N_l:
        # Build and preserve the cache once per setting.
        _build_cache(N, l, cache_dir)

        # Reload the preserved cache and benchmark each worker count.
        for n_worker in num_workers:
            cache = MultiprocessFileCache(N, do_pickle=False,
                                          dir=cache_dir)
            cache.preload('cache_data')
            ds = CachedDataset(cache)

            mean, std = _measure(ds, N, l, n_worker, n_trials)
            print('| {} | {} | {} | {:.2f} | {:.2f} |'
                  .format(N, l, n_worker, 1e+6 * mean, 1e+6 * std))

        # Remove the preserved cache files before the next setting.
        for f in glob.glob('{}/cache_data*'.format(cache_dir)):
            os.remove(f)
|
|
|
|
|
# Entry point: run the benchmark only when executed as a script.
if __name__ == '__main__':

    main()