tar dataset and imagefolder: PyTorch Dataset classes that read images directly from uncompressed Tar archives, with a per-archive TarDataset and an ImageFolder-style TarImageFolder over a directory of per-class archives.
import glob
import tarfile
import json
import os, os.path as osp
from io import BytesIO
from PIL import Image, ImageFile
import hashlib
from torch.utils.data import Dataset, get_worker_info, ConcatDataset
from multiprocessing.pool import ThreadPool as Pool

try:  # make torchvision optional
    from torchvision.transforms.functional import to_tensor
except ImportError:
    to_tensor = None

ImageFile.LOAD_TRUNCATED_IMAGES = True
class UnexpectedEOFTarFile(tarfile.TarFile):
    """TarFile variant that tolerates archives truncated mid-stream."""

    def _load(self):
        """Read through the entire archive file and look for readable
        members.
        """
        try:
            while True:
                tarinfo = self.next()
                if tarinfo is None:
                    break
        except tarfile.ReadError as e:
            # only swallow truncation errors; re-raise anything else
            if e.args[0] != "unexpected end of data":
                raise
        self._loaded = True
class TarDataset(Dataset):
    """Dataset that reads images from a single Tar archive (uncompressed)."""

    def __init__(
        self,
        archive,
        transform=None,
        is_valid_file=lambda m: m.isfile()
        and m.name.lower().endswith((".png", ".jpg", ".jpeg")),
        ignore_unexpected_eof=False,
        cache_dir="~/.cache/tardataset",
    ):
        self.transform = transform
        self.archive = osp.realpath(osp.expanduser(archive))
        # every sample of this archive shares one label: the resolved archive path
        self.default_label = self.archive
        self.ignore_unexpected_eof = ignore_unexpected_eof

        # open tar file. in a multiprocessing setting (e.g. DataLoader workers), we
        # have to open one file handle per worker (stored as the tar_obj dict), since
        # when the multiprocessing method is 'fork', the workers share this TarDataset.
        # we want one file handle per worker because TarFile is not thread-safe.
        worker = get_worker_info()
        worker = worker.id if worker else None
        self.tar_obj = {}  # lazy init

        # TODO: add a function hash
        # create a hash (from the archive's mtime) to cache the member list
        mtime = osp.getmtime(self.archive)
        # fn_hash = hash(is_valid_file)
        m = hashlib.sha256()
        m.update(str(mtime).encode("utf-8"))
        # m.update(str(fn_hash).encode("utf-8"))
        uuid = m.hexdigest()[:7]

        tar_info = self.archive.replace("/", "-") + f"-{uuid}.json"
        fpath = osp.join(osp.expanduser(cache_dir), tar_info)
        if not osp.exists(fpath):
            self.tar_obj = {
                worker: tarfile.open(self.archive)
                if ignore_unexpected_eof is False
                else UnexpectedEOFTarFile.open(self.archive)
            }
            print(f"{osp.basename(self.archive)} preparing tar.getnames() ...")
            self.all_members = self.tar_obj[worker].getmembers()
            # also store references to the iterated samples (a subset of the above)
            self.samples = [m.name for m in self.all_members if is_valid_file(m)]
            os.makedirs(osp.dirname(fpath), exist_ok=True)
            json.dump(self.samples, open(fpath, "w"), indent=2)
        else:
            print(f"loading cached tarinfo from {fpath}")
            self.samples = json.load(open(fpath, "r"))
    def __getitem__(self, index):
        image = self.get_image(self.samples[index], pil=True)
        image = image.convert("RGB")  # if it's grayscale, convert to RGB
        if self.transform:  # apply any custom transforms
            image = self.transform(image)
        return image, self.default_label

    def __len__(self):
        return len(self.samples)

    def get_image(self, name, pil=False):
        image = Image.open(BytesIO(self.get_file(name).read()))
        if pil:
            return image
        if to_tensor is None:
            raise RuntimeError("torchvision is required when pil=False")
        return to_tensor(image)
    def get_text_file(self, name, encoding="utf-8"):
        """Read a text file from the Tar archive, returned as a string.

        Args:
            name (str): File name to retrieve.
            encoding (str): Encoding of file, default is utf-8.

        Returns:
            str: Content of text file.
        """
        return self.get_file(name).read().decode(encoding)

    def get_file(self, name):
        """Read an arbitrary file from the Tar archive.

        Args:
            name (str): File name to retrieve.

        Returns:
            io.BufferedReader: Object used to read the file's content.
        """
        # ensure a unique file handle per worker, in multiprocessing settings
        worker = get_worker_info()
        worker = worker.id if worker else None
        if worker not in self.tar_obj:
            # reopen with the same truncation tolerance chosen at init time
            open_fn = (
                UnexpectedEOFTarFile.open
                if self.ignore_unexpected_eof
                else tarfile.open
            )
            self.tar_obj[worker] = open_fn(self.archive)
        return self.tar_obj[worker].extractfile(name)

    def __del__(self):
        """Close the TarFile file handles on exit."""
        for o in self.tar_obj.values():
            o.close()

    def __getstate__(self):
        """Serialize without the TarFile references, for multiprocessing compatibility."""
        state = dict(self.__dict__)
        state["tar_obj"] = {}
        return state
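
# A minimal usage sketch for TarDataset on its own (the paths below are
# illustrative, not from the original gist). Every sample carries the resolved
# archive path as its label, which TarImageFolder later remaps to a class index:
#
#   ds = TarDataset("~/datasets/images.tar")
#   img, label = ds[0]                    # PIL.Image in RGB, archive path
#   raw = ds.get_text_file("meta.json")   # any member can be read directly
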
class TarImageFolder(Dataset):
    """Dataset over a directory of uncompressed Tar archives, one archive per class.

    Similar to torchvision.datasets.ImageFolder, except that each top-level
    class folder is replaced by a Tar archive:

        root/
            dog.tar
            cat.tar
            ...
            bird.tar

    where each archive holds that class's images (possibly nested):

        dog.tar/
            xxx.png
            xxy.png
            [...]/xxz.png
        cat.tar/
            123.png
            nsdf3.png
            [...]/asd932_.png
    """
    def __init__(
        self,
        root,
        transform=None,
        max_loads=None,
        is_valid_file=lambda m: m.isfile()
        and m.name.lower().endswith((".png", ".jpg", ".jpeg")),
        pool_size=16,
    ):
        # load the archive meta information, and filter the samples
        super().__init__()
        root = osp.expanduser(root)
        self.transform = transform

        tarfs = sorted(glob.glob(osp.join(root, "*.tar")))
        if max_loads is not None:
            tarfs = tarfs[:max_loads]
        print(tarfs)

        # assign a label index to each archive; keys are realpath'd so they match
        # TarDataset.default_label, which is resolved the same way
        self.class2idx = {}
        for tar_fpath in tarfs:
            self.class2idx[osp.realpath(tar_fpath)] = len(self.class2idx)
        print(self.class2idx)

        if len(self.class2idx) == 0:
            raise IOError(
                "No class archives (*.tar) were found under the given root. Check\n"
                "that the path is correct and is_valid_file is not too strict."
            )

        # load the per-archive member lists in parallel; each TarDataset either
        # scans its archive or reads its cached JSON index
        def worker(tar_path):
            return TarDataset(tar_path, is_valid_file=is_valid_file)

        pool = Pool(pool_size)
        jobs = [pool.apply_async(worker, (tar_path,)) for tar_path in tarfs]
        pool.close()
        pool.join()
        tar_dst_list = [job.get() for job in jobs]
        self.dataset = ConcatDataset(tar_dst_list)
        print("TarImageFolder dataset init finish")

        # the inverse mapping is often useful
        self.idx2class = {v: k for k, v in self.class2idx.items()}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, _label = self.dataset[index]
        if self.transform:  # apply any custom transforms
            image = self.transform(image)
        label = self.class2idx[_label]
        return (image, label)
if __name__ == "__main__":
    # img = TarDataset(
    #     "~/datasets/sam-00-50/full-sam/sa_000999.tar"
    # )
    # print("init finish, try to fetch images")
    # print(img[0])
    dst = TarImageFolder(
        "~/datasets/sam-00-50/full-sam",
        max_loads=20,
    )
    print("init finish, try to fetch images")
    for idx, (image, label) in enumerate(dst):
        print(image, label)
        if idx > 100:
            break
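
For batched training, the dataset drops its open TarFile handles when pickled (`__getstate__`), so it works with multiple DataLoader workers; each worker lazily reopens its own handle inside `get_file`. A minimal sketch, assuming torch/torchvision are installed, the root path is illustrative, and all images share one size so the default collate can stack them:

    from torch.utils.data import DataLoader
    from torchvision.transforms.functional import to_tensor

    dst = TarImageFolder("~/datasets/my-tars", transform=to_tensor)
    loader = DataLoader(dst, batch_size=32, num_workers=4, shuffle=True)
    for images, labels in loader:
        pass  # images: (32, 3, H, W) tensor; labels: class indices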