Skip to content

Instantly share code, notes, and snippets.

@bkj
Created March 14, 2018 15:44
Show Gist options
  • Save bkj/f531dcb8e911a7c864663e316fd431b8 to your computer and use it in GitHub Desktop.
Save bkj/f531dcb8e911a7c864663e316fd431b8 to your computer and use it in GitHub Desktop.
from time import perf_counter, time

from torch.utils.data import DataLoader
from torchvision import datasets, transforms
def do_time(loader, epochs=3):
    """Return the seconds needed to iterate over `loader` `epochs` times.

    Every item produced by `loader` is fetched and immediately discarded,
    so the result reflects only the cost of producing the items.

    Args:
        loader: A re-iterable object (e.g. a DataLoader); it is iterated
            from the start once per epoch.
        epochs: Number of complete passes to make over `loader`.

    Returns:
        Elapsed time in seconds, as a float.
    """
    # perf_counter is monotonic and high resolution, so the measured
    # interval cannot be skewed by system clock adjustments the way a
    # difference of wall-clock time() readings can.
    start = perf_counter()
    for _epoch in range(epochs):
        for _batch in loader:
            pass
    return perf_counter() - start
# --
# MNIST (train=True)
# load_time spans the dataset constructor call, including whatever
# download=True triggers.
t0 = time()
dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
load_time = time() - t0

# run_time spans do_time's default number of passes, in batches of 128.
loader = DataLoader(dataset, batch_size=128)
run_time = do_time(loader)

print("load_time=%f | run_time=%f" % (load_time, run_time))
# Recorded results ("new"/"old" presumably label the two dataset
# implementations being compared -- confirm):
# new: load_time=2.166751 | run_time=9.074668
# old: load_time=0.025438 | run_time=15.250080
# --
# CIFAR10 (train=True)
# load_time spans the dataset constructor call, including whatever
# download=True triggers.
t0 = time()
dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
load_time = time() - t0

# run_time spans do_time's default number of passes, in batches of 128.
loader = DataLoader(dataset, batch_size=128)
run_time = do_time(loader)

print("load_time=%f | run_time=%f" % (load_time, run_time))
# Recorded result:
# new: load_time=4.028385 | run_time=9.271593
# --
# SVHN
# split='train'
# load_time spans the dataset constructor call, including whatever
# download=True triggers.
t0 = time()
dataset = datasets.SVHN(
    root='./data',
    split='train',
    download=True,
    transform=transforms.ToTensor(),
)
load_time = time() - t0

# run_time spans do_time's default number of passes, in batches of 128.
loader = DataLoader(dataset, batch_size=128)
run_time = do_time(loader)

print("load_time=%f | run_time=%f" % (load_time, run_time))
# Recorded result:
# load_time=7.682482 | run_time=14.907479
# SVHN, split='extra'
# load_time spans the dataset constructor call, including whatever
# download=True triggers.
t0 = time()
dataset = datasets.SVHN(
    root='./data',
    split='extra',
    download=True,
    transform=transforms.ToTensor(),
)
load_time = time() - t0

# Only one pass is timed for this split (epochs=1), in batches of 128.
loader = DataLoader(dataset, batch_size=128)
run_time = do_time(loader, epochs=1)

print("load_time=%f | run_time=%f" % (load_time, run_time))
# Recorded result:
# load_time=55.454198 | run_time=37.794980
# --
# STL10
# split='train+unlabeled'
# load_time spans the dataset constructor call, including whatever
# download=True triggers.
t0 = time()
# The split name must be spelled 'train+unlabeled'; the earlier spelling
# 'train+unlabled' is not a split torchvision's STL10 recognises and is
# rejected by the constructor.
dataset = datasets.STL10(
    root='./data',
    split='train+unlabeled',
    download=True,
    transform=transforms.ToTensor(),
)
load_time = time() - t0

# Only one pass is timed for this split (epochs=1), in batches of 128.
loader = DataLoader(dataset, batch_size=128)
run_time = do_time(loader, epochs=1)

print("load_time=%f | run_time=%f" % (load_time, run_time))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment