Last active
September 4, 2021 18:10
-
-
Save harpone/1ce4c775ff63e22bc5228c4c77b48604 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class GCSDataset(Dataset):
    """Generic PyTorch dataset for GCS.

    Streams images from a Google Cloud Storage bucket and (optionally) caches
    the RGB-converted image on local disk, so that later reads of the same
    example come from disk instead of the network.

    ``__getitem__`` returns ``(None, None)`` for any example that cannot be
    loaded; the ``collate_fn`` used with this dataset is expected to drop
    those entries.
    """

    # File extensions (lower-case, without the dot) that are opened with PIL.
    _IMAGE_EXTENSIONS = frozenset({"jpeg", "jpg", "png", "gif"})

    def __init__(self,
                 bucketname=None,
                 path_list=None,  # TODO: list bucket/path contents if None
                 target_list=None,
                 transform=None,
                 target_transform=None,
                 cache_data=True,
                 cache_path="/data"):
        """Custom dataset for GCS.

        :param bucketname: name of the GCS bucket the blobs are read from
        :param path_list: list or numpy array of `path`s where path is of form
            'container/.../image_filename.jpg' (``str`` or utf-8 ``bytes``)
        :param target_list: list, numpy array or dict of targets corresponding
            to each path in `path_list`; indexed with the same `index` that is
            passed to ``__getitem__``. ``None`` yields ``None`` targets.
        :param transform: optional callable applied to the loaded RGB image
        :param target_transform: optional callable applied to the target
        :param cache_data: whether or not cache data locally to enable faster
            read access once entire dataset has been cached.
        :param cache_path: local root directory of the cache; created here
            when `cache_data` is true
        """
        self.path_list = path_list
        self.target_list = target_list
        self.transform = transform
        self.target_transform = target_transform
        self.cache_data = cache_data
        self.cache_path = cache_path
        if self.cache_data:
            os.makedirs(self.cache_path, exist_ok=True)
        self.bucketname = bucketname
        # Created lazily by `start_bundle` (see there for the reason).
        self.store = None
        self.bucket = None

    def __getstate__(self):
        """Return the picklable state, without the GCS client handles.

        `start_bundle` may already have run in this process (e.g. an example
        was fetched before the dataset is handed to worker processes). The
        client objects cannot be pickled, so they are dropped here; the copy
        re-creates them on its first ``__getitem__`` call.
        """
        state = self.__dict__.copy()
        state["store"] = None
        state["bucket"] = None
        return state

    def start_bundle(self):
        """Create the storage client and bucket handle.

        This is required for multiprocess distributed training since
        multiprocessing can't pickle storage.Client() objects, see here:
        https://github.com/googleapis/google-cloud-python/issues/3191
        Also here: https://stackoverflow.com/a/59043240/742616
        The method will be run the first time __getitem__ is called.

        :return: None
        """
        self.store = storage.Client()
        self.bucket = self.store.bucket(self.bucketname)  # TODO retry policy etc?

    def __len__(self):
        """Return the number of entries in `path_list`."""
        return len(self.path_list)

    @staticmethod
    def _write_cache(img, cache_folder, filename):
        """Save `img` as ``cache_folder/filename`` without exposing partial files.

        The image is written to a temporary sibling file first and then moved
        into place with ``os.replace``, so a concurrent reader (another worker
        process) or a failed/interrupted save never leaves a truncated file at
        the final cache path. The temporary name keeps the original file name
        as its tail so the file extension seen by ``img.save`` is unchanged.

        :raises: whatever ``img.save`` or ``os.replace`` raises; the temporary
            file is removed in that case.
        """
        # The pid makes the name unique per worker process.
        tmp_path = join(cache_folder, f".tmp-{os.getpid()}-{filename}")
        try:
            img.save(tmp_path)
            os.replace(tmp_path, join(cache_folder, filename))
        finally:
            # Still present only if the save or the move failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def __getitem__(self, index):
        """Load example `index` from the local cache or from GCS.

        :param index: position in `path_list` (and key into `target_list`)
        :return: ``(img, target)`` with `img` an RGB image (after `transform`,
            if given) and `target` the matching target (after
            `target_transform`, if given), or ``(None, None)`` when the blob
            is missing, has an unsupported extension, or fails to load.
        """
        if self.store is None:  # instantiate storage clients
            self.start_bundle()

        blob_path = self.path_list[index]  # e.g. 'imagenet/train/xxxx.jpg'
        if isinstance(blob_path, bytes):
            blob_path = str(blob_path, "utf-8")
        local_path, filename = utils.split_by_at(blob_path, "/", -1)
        cache_folder = join(self.cache_path, local_path)
        cache_file_path = join(cache_folder, filename)
        if self.cache_data:
            os.makedirs(cache_folder, exist_ok=True)
        ext = filename.split(".")[-1].lower()
        target = self.target_list[index] if self.target_list is not None else None

        # Reject unsupported types before touching the network.
        if ext not in self._IMAGE_EXTENSIONS:  # TODO: support other file extensions/ types
            print("**** CANNOT HANDLE FILETYPE: " + blob_path)
            return None, None  # collate_fn will handle missing examples

        stream = None  # stays None when the image is served from the cache
        if not os.path.exists(cache_file_path):  # from blob storage if no local file
            try:
                blob = self.bucket.blob(blob_path)
                # NOTE(review): newer google-cloud-storage releases deprecate
                # download_as_string in favour of download_as_bytes; kept for
                # compatibility with the installed version -- confirm.
                stream = io.BytesIO(blob.download_as_string())
            except Exception as e:  # TODO: handle exceptions
                print(
                    "cache_file_path:",
                    cache_file_path,
                    "blob_path:",
                    blob_path,
                    "local_path:",
                    local_path,
                    "filename:",
                    filename,
                )
                print("**** MISSING BLOB: " + blob_path)
                print(e)
                return None, None  # collate_fn will handle missing examples

        try:
            source = stream if stream is not None else cache_file_path
            with Image.open(source) as raw:
                img = raw.convert("RGB")  # ensure RGB *BEFORE* saving
            if self.cache_data and stream is not None:  # save if using cache
                self._write_cache(img, cache_folder, filename)
            if self.transform is not None:
                img = self.transform(img)
            if self.target_transform is not None:
                target = self.target_transform(target)
        except Exception as e:  # TODO: catch correct exceptions
            print(f"Exception while attempting to read image: {e}")
            return None, None  # collate_fn will handle missing examples
        finally:
            # Release the downloaded bytes on every path, including failures.
            if stream is not None:
                stream.close()
        return img, target
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment