import io
import os
from os.path import join

from google.cloud import storage
from PIL import Image
from torch.utils.data import Dataset

import utils  # author's helper module; provides `split_by_at` (see sketch below)


class GCSDataset(Dataset):
    """Generic PyTorch dataset for GCS. Streams data from GCS and (optionally) caches to local disk."""

    def __init__(self,
                 bucketname=None,
                 path_list=None,  # TODO: list bucket/path contents if None
                 target_list=None,
                 transform=None,
                 target_transform=None,
                 cache_data=True,
                 cache_path="/data"):
        """Custom dataset for GCS.

        :param bucketname: name of the GCS bucket to read from
        :param path_list: list or numpy array of `path`s, where each path is of the form
            'container/.../image_filename.jpg'
        :param target_list: list, numpy array or dict of targets corresponding to each path in `path_list`
        :param transform: transform applied to each image
        :param target_transform: transform applied to each target
        :param cache_data: whether or not to cache data locally, enabling faster read access once the
            entire dataset has been cached
        :param cache_path: local directory used for the cache
        """
        self.path_list = path_list
        self.target_list = target_list
        self.transform = transform
        self.target_transform = target_transform
        self.cache_data = cache_data
        self.cache_path = cache_path
        if self.cache_data:
            os.makedirs(self.cache_path, exist_ok=True)
        self.bucketname = bucketname
        self.store = None
        self.bucket = None

    def start_bundle(self):
        """This is required for multiprocess distributed training, since multiprocessing can't pickle
        storage.Client() objects; see here:
        https://github.com/googleapis/google-cloud-python/issues/3191
        and here: https://stackoverflow.com/a/59043240/742616

        The method is run the first time __getitem__ is called, i.e. once per worker process.
        """
        self.store = storage.Client()
        self.bucket = self.store.bucket(self.bucketname)  # TODO: retry policy etc?

    def __len__(self):
        return len(self.path_list)

    def __getitem__(self, index):
        if self.store is None:  # instantiate storage clients lazily, once per worker
            self.start_bundle()
        blob_path = self.path_list[index]  # e.g. 'imagenet/train/xxxx.jpg'
        if isinstance(blob_path, bytes):
            blob_path = str(blob_path, "utf-8")
        local_path, filename = utils.split_by_at(blob_path, "/", -1)
        cache_folder = join(self.cache_path, local_path)
        cache_file_path = join(cache_folder, filename)
        if self.cache_data:
            os.makedirs(cache_folder, exist_ok=True)
        ext = filename.split(".")[-1].lower()
        target = self.target_list[index] if self.target_list is not None else None
        if not os.path.exists(cache_file_path):  # read from blob storage if there is no local file
            try:
                blob = self.bucket.blob(blob_path)
                img_bytes = blob.download_as_string()
                stream = io.BytesIO(img_bytes)
            except Exception as e:  # TODO: handle specific exceptions
                print(
                    "cache_file_path:",
                    cache_file_path,
                    "blob_path:",
                    blob_path,
                    "local_path:",
                    local_path,
                    "filename:",
                    filename,
                )
                print("**** MISSING BLOB: " + blob_path)
                print(e)
                return None, None  # collate_fn will handle missing examples
        else:
            stream = None
        try:
            if ext in ["jpeg", "jpg", "png", "gif"]:  # TODO: not cool
                if stream is not None:
                    img = Image.open(stream)
                else:
                    img = Image.open(cache_file_path)
                img = img.convert("RGB")  # ensure RGB *BEFORE* saving
                if self.cache_data and stream is not None:  # save to cache on first read
                    img.save(cache_file_path)
            else:  # TODO: support other file extensions/types
                print("**** CANNOT HANDLE FILETYPE: " + blob_path)
                return None, None  # collate_fn will handle missing examples
            if stream is not None:
                stream.close()
            if self.transform is not None:
                img = self.transform(img)
            if self.target_transform is not None:
                target = self.target_transform(target)
        except Exception as e:  # TODO: catch the correct exceptions
            print(f"Exception while attempting to read image: {e}")
            return None, None  # collate_fn will handle missing examples
        return img, target
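
The `utils.split_by_at` helper is not included in the gist; judging from its call site, it splits a string around the occurrence of a separator at a given index, so that `'imagenet/train/x.jpg'` becomes `('imagenet/train', 'x.jpg')`. A minimal sketch consistent with that usage (a hypothetical reconstruction, not the author's implementation):

def split_by_at(s, sep, idx):
    # Hypothetical reconstruction of utils.split_by_at: split `s` around the
    # occurrence of `sep` at position `idx` (idx=-1 means the last occurrence),
    # e.g. split_by_at('imagenet/train/x.jpg', '/', -1) -> ('imagenet/train', 'x.jpg').
    parts = s.split(sep)
    return sep.join(parts[:idx]), sep.join(parts[idx:])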
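
The comments in `__getitem__` rely on a `collate_fn` that tolerates the `(None, None)` pairs returned for missing or unreadable blobs. That function is not shown in the gist; a minimal sketch that filters them out before falling back to PyTorch's default collation (the name `skip_none_collate` is mine):

from torch.utils.data.dataloader import default_collate

def skip_none_collate(batch):
    # Drop the (None, None) pairs emitted for missing/unreadable blobs,
    # then collate the surviving examples as usual.
    batch = [(img, target) for img, target in batch if img is not None]
    if len(batch) == 0:
        return None  # caller must skip empty batches
    return default_collate(batch)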
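
Putting it together, usage with a multi-worker DataLoader might look like the following; the bucket name, paths, and labels are placeholders, and `skip_none_collate` is the sketch above:

from torch.utils.data import DataLoader

dataset = GCSDataset(
    bucketname="my-bucket",               # placeholder bucket name
    path_list=["imagenet/train/x.jpg"],   # placeholder blob paths inside the bucket
    target_list=[0],                      # placeholder targets
    cache_data=True,
    cache_path="/data",
)

# Each worker process lazily creates its own storage.Client() on its first
# __getitem__ call, which is what start_bundle() is for.
loader = DataLoader(dataset, batch_size=32, num_workers=4,
                    collate_fn=skip_none_collate)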