Created
January 21, 2019 16:24
-
-
Save martinsotir/59731545e0a9a390b561dbf99e2b2d7b to your computer and use it in GitHub Desktop.
ImageZipDataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import mmap
import os
import tarfile
import zipfile
from pathlib import Path

import numpy as np
import torch
import torch.multiprocessing as mp
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
# mp.set_start_method('forkserver', force=True)  # On Linux, you may want to use the fork method (this is the default with PyTorch 0.4)
class _SeekableMmap(mmap.mmap):
    """``mmap.mmap`` that always reports itself as seekable.

    ``zipfile`` looks up ``seekable`` on the underlying file object when a
    member is opened, and ``mmap.mmap`` only gained that method in
    Python 3.13. Defining it here keeps the memory-mapped path usable on
    older interpreters; on newer ones it is a harmless override.
    """

    def seekable(self):
        """Return True: a memory map supports ``seek``/``tell``."""
        return True


class ImageZipDataset(Dataset):
    """Dataset that decodes images straight out of a zip archive.

    The list of member names is built once in ``__init__``. The archive
    itself is opened lazily on first item access and then reused.

    Args:
        path: Location of the zip archive.
        extension: Only members whose name ends with this suffix are indexed.
        cache: When truthy, the archive is read through a read-only memory
            map instead of a regular file object.
    """

    def __init__(self, path, extension=".jpg", cache=False):
        # Let ZipFile own the file so the context manager really closes it;
        # only the member names are needed at construction time.
        with zipfile.ZipFile(path, mode="r") as archive:
            self._files = [
                name for name in archive.namelist() if name.endswith(extension)
            ]
        self._path = path
        # Normalise so truthy non-bool flags (e.g. numpy.bool_) are honoured.
        self._cache = bool(cache)
        self._archive = None
        self._pil_to_tensor = transforms.ToTensor()
        self._memmap_handle = None
        # With the 'fork' start method the map is created eagerly and kept
        # on the instance.
        if mp.get_start_method() == 'fork' and self._cache:
            self._memmap_handle = self.get_memmap_handle()

    def get_memmap_handle(self):
        """Return a read-only memory map of the whole archive.

        The map stored on the instance is returned when there is one;
        otherwise a new map is created on every call and NOT stored, so the
        caller is responsible for keeping it alive.
        """
        if self._memmap_handle is not None:
            return self._memmap_handle
        # mmap keeps its own duplicate of the descriptor, so the file object
        # can be closed as soon as the map exists.
        with open(self._path, "rb") as handle:
            if os.name == 'nt':
                # On Windows the path is also passed as the mapping's tagname.
                return _SeekableMmap(
                    handle.fileno(), 0, str(self._path), access=mmap.ACCESS_READ
                )
            return _SeekableMmap(handle.fileno(), 0, access=mmap.ACCESS_READ)

    def get_zip_handle(self):
        """Return the ``ZipFile`` for this instance, opening it on first use."""
        if self._archive is None:
            if self._cache:
                file_handle = self.get_memmap_handle()
            else:
                file_handle = open(self._path, "rb")
            self._archive = zipfile.ZipFile(file_handle, mode="r")
        return self._archive

    def __getitem__(self, index):
        """Decode the ``index``-th indexed member and return it as a tensor."""
        # To benchmark without JPEG decoding, return the raw bytes instead:
        #   torch.Tensor(list(self.get_zip_handle().read(self._files[index])))
        # Convert inside the ``with`` block: PIL decodes lazily, so the member
        # stream must still be open when the pixel data is read.
        with self.get_zip_handle().open(self._files[index]) as member:
            img = Image.open(member, mode="r")  # use of jpeg-turbo is recommended
            return self._pil_to_tensor(img)

    def __len__(self):
        """Return the number of indexed archive members."""
        return len(self._files)
def test(cache=True, shuffle=True, num_workers=2, path="myfiles.zip"):
    """Iterate an ``ImageZipDataset`` three times through a ``DataLoader``.

    Args:
        cache: Forwarded to ``ImageZipDataset`` (memory-mapped reads).
        shuffle: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.
        path: Zip archive to read.

    Returns:
        The running total (starting at 1) of ``sum()`` over every batch,
        accumulated so that each batch is actually consumed.
    """
    dataset = ImageZipDataset(path, cache=cache)
    if mp.get_start_method() != 'fork' and cache:
        # Force having at least one handle of the memory-mapped file in the
        # parent when not forking (original author's note: not sure this is
        # done right). The name is bound only to keep the map alive.
        _keepalive = dataset.get_memmap_handle()
    loader = DataLoader(
        dataset, batch_size=1, shuffle=shuffle, num_workers=num_workers
    )
    res = 1
    for _ in range(3):
        for batch in tqdm(loader):
            res += batch.sum()
    return res
# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment