Skip to content

Instantly share code, notes, and snippets.

@martinsotir
Created January 21, 2019 16:24
Show Gist options
  • Save martinsotir/59731545e0a9a390b561dbf99e2b2d7b to your computer and use it in GitHub Desktop.
Save martinsotir/59731545e0a9a390b561dbf99e2b2d7b to your computer and use it in GitHub Desktop.
ImageZipDataset
import torch
from torch.utils.data import Dataset, DataLoader
import tarfile
import zipfile
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
import mmap
import torch.multiprocessing as mp
import numpy as np
import os
# mp.set_start_method('forkserver', force=True)  # On Linux, you may want to use the fork method (this is the default with PyTorch 0.4)
class _SeekableMmap(mmap.mmap):
    """``mmap.mmap`` that always exposes a ``seekable()`` method.

    ``zipfile.ZipFile.open`` looks up ``seekable`` on the underlying file
    object, but ``mmap.mmap`` only provides that method from Python 3.13
    onwards. Defining it here lets the mapping back a ``ZipFile`` on older
    interpreters as well.
    """

    def seekable(self):
        return True


class ImageZipDataset(Dataset):
    """Dataset serving images stored in a zip archive as tensors.

    The archive is opened lazily on first item access and the resulting
    ``ZipFile`` is kept on the instance for subsequent reads. With
    ``cache`` enabled the archive is read through a read-only memory map
    of the file instead of a regular file object.

    Args:
        path: Location of the zip archive.
        extension: Only archive members whose name ends with this suffix
            are exposed as items.
        cache: When truthy, read the archive through a memory map.
    """

    def __init__(self, path, extension=".jpg", cache=False):
        # Only the member listing is needed here, so both the file and the
        # ZipFile are closed again right away instead of being leaked.
        with open(path, mode="rb") as raw, zipfile.ZipFile(raw) as archive:
            self._files = [m for m in archive.namelist() if m.endswith(extension)]
        self._path = path
        self._cache = bool(cache)
        self._archive = None
        self._pil_to_tensor = transforms.ToTensor()
        self._memmap_handle = None
        # With the 'fork' start method the mapping is created up front and
        # stored on the instance.
        if mp.get_start_method() == 'fork' and self._cache:
            self._memmap_handle = self.get_memmap_handle()

    def get_memmap_handle(self):
        """Return a read-only memory map of the archive file.

        Returns the mapping created in ``__init__`` when there is one;
        otherwise a new mapping is created on every call and is *not*
        stored on the instance.
        """
        if self._memmap_handle is not None:
            return self._memmap_handle
        # mmap duplicates the file descriptor, so the file object can be
        # closed as soon as the mapping exists.
        with open(self._path, "rb") as handle:
            if os.name == 'nt':
                # On Windows the mapping is additionally given a tag name.
                return _SeekableMmap(handle.fileno(), 0, self._path, access=mmap.ACCESS_READ)
            return _SeekableMmap(handle.fileno(), 0, access=mmap.ACCESS_READ)

    def get_zip_handle(self):
        """Return the ``ZipFile`` for this instance, opening it on first use."""
        if self._archive is None:
            if self._cache:
                file_handle = self.get_memmap_handle()
            else:
                # Kept open for as long as the ZipFile stored below is used.
                file_handle = open(self._path, "rb")
            self._archive = zipfile.ZipFile(file_handle, mode="r")
        return self._archive

    def __getitem__(self, index):
        """Decode the ``index``-th matching archive member into a tensor."""
        # To measure throughput without JPEG decoding, return the raw bytes
        # instead: torch.Tensor(list(self.get_zip_handle().read(self._files[index])))
        member = self.get_zip_handle().open(self._files[index])
        img = Image.open(member, mode="r")  # use of jpeg-turbo is recommended
        return self._pil_to_tensor(img)

    def __len__(self):
        """Return the number of archive members matching the extension."""
        return len(self._files)
def test(cache=True, shuffle=True, num_workers=2):
    """Iterate three times over ``myfiles.zip`` through a ``DataLoader``.

    Args:
        cache: Forwarded to ``ImageZipDataset``; enables the memory-mapped
            read path.
        shuffle: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.

    Returns:
        The running total (starting at 1) of ``x.sum()`` over every batch
        of all three passes.
    """
    dataset = ImageZipDataset("myfiles.zip", cache=cache)
    if mp.get_start_method() != 'fork' and cache:
        # Force having at least one handle of the memmapped file in spawn
        # mode (original author's note: not sure if done right?). The local
        # reference keeps the mapping alive for the duration of the run.
        keepalive_handle = dataset.get_memmap_handle()  # noqa: F841
    loader = DataLoader(dataset, batch_size=1, shuffle=shuffle, num_workers=num_workers)
    res = 1
    for _ in range(3):
        for x in tqdm(loader):
            res += x.sum()
    return res
# Script entry point: run test() with its default arguments when this file
# is executed directly; importing the module does not trigger it.
if __name__ == "__main__":
    test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment