Last active
June 14, 2022 16:45
-
-
Save jogardi/b7e5e820d0e041b51bda077476c81046 to your computer and use it in GitHub Desktop.
Example of managing resources and datasets for ML
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import joblib
import torch
import torchvision.datasets
import torchvision.transforms as transforms
from functools import cached_property
from google.cloud import storage
from google.oauth2 import service_account
def prep_dir(path: str) -> str:
    """Ensure *path* exists as a directory and return it unchanged.

    Uses ``os.makedirs`` so intermediate directories are created too
    (the original ``os.mkdir`` raised FileNotFoundError for nested
    paths), and ``exist_ok=True`` removes the race between the
    existence check and the creation call.

    Args:
        path: Directory path to create if missing.

    Returns:
        The same *path*, so calls can be chained inline.
    """
    os.makedirs(path, exist_ok=True)
    return path
class Resources:
    """Lazily-initialized handles to local cache directories and a GCP
    Storage bucket, plus helpers that mirror bucket objects to disk.

    Class attributes (directories, joblib cache) are created once at
    import time; cloud credentials/clients are built on first access
    via ``cached_property`` so importing this module needs no network.
    """

    # Local working directory for this project.
    # NOTE: "<your project name>" is a template placeholder — replace it.
    project_dir = prep_dir(os.path.expanduser("~/.<your project name>"))
    # General-purpose cache root for downloaded/derived artifacts.
    cached_files = prep_dir(f"{project_dir}/cached_files")
    # Local mirror of objects downloaded from the GCP bucket.
    # NOTE(review): this attribute was referenced by from_gcp/from_gcp_dir
    # but never defined in the original — location chosen here; confirm.
    bucket_files_dir = prep_dir(f"{cached_files}/bucket_files")
    # Disk-backed memoization cache (use as @project_res.memcache.cache).
    memcache = joblib.Memory(f"{cached_files}/joblibcache")

    @cached_property
    def gstorage_client(self) -> storage.Client:
        """Authenticated GCP Storage client (built once, then cached)."""
        return storage.Client(credentials=self.gcreds, project=self.gcreds.project_id)

    @cached_property
    def gcreds(self):
        """Service-account credentials loaded from a JSON key file
        expected at <project_dir>/gcloud_service_credentials.json."""
        return service_account.Credentials.from_service_account_file(
            filename=f"{self.project_dir}/gcloud_service_credentials.json"
        )

    @cached_property
    def res_bucket(self) -> storage.Bucket:
        """The project's resource bucket (name "res1" is hard-coded)."""
        return self.gstorage_client.get_bucket("res1")

    def from_gcp(self, key: str) -> str:
        """Download object *key* from the resource bucket (if not already
        cached locally) and return the local file path.

        NOTE(review): if the blob does not exist remotely, the returned
        path points to a file that was never created — callers should
        check os.path.exists; original behavior preserved.
        """
        parts = key.split("/")
        if len(parts) > 1:
            # Make sure the local parent directories mirror the key's prefix.
            prep_dir(f"{self.bucket_files_dir}/{'/'.join(parts[:-1])}")
        file_path = f"{self.bucket_files_dir}/{key}"
        if not os.path.exists(file_path):
            source_blob = self.res_bucket.blob(key)
            if source_blob.exists():
                source_blob.download_to_filename(file_path)
        return file_path

    def from_gcp_dir(self, key: str, bucket=None):
        """Download every blob under prefix *key* from *bucket*
        (default: the resource bucket) and return the local root path.

        The root path flattens the prefix ("a/b" -> "a_b"); individual
        blobs keep their full key layout under bucket_files_dir.
        """
        if bucket is None:
            bucket = self.res_bucket
        file_path = f"{self.bucket_files_dir}/{'_'.join(key.split('/'))}"
        if not os.path.exists(file_path):
            # exist_ok avoids a crash if another process created it
            # between the check above and this call.
            os.makedirs(file_path, exist_ok=True)
        blobs = bucket.list_blobs(prefix=key)
        for blob in blobs:
            subpath = f"{self.bucket_files_dir}/{blob.name}"
            subpath_dir = "/".join(subpath.split("/")[:-1])
            os.makedirs(subpath_dir, exist_ok=True)
            blob.download_to_filename(subpath)
        return file_path

    def load_transformed_ex_dataset(self) -> torch.utils.data.Dataset:
        """Example dataset as an ImageFolder with tensor conversion only."""
        return torchvision.datasets.ImageFolder(
            self.ex_dataset_path,
            transform=transforms.Compose([
                transforms.ToTensor()
            ])
        )

    def load_transformed_ex_dataset_cropped(self) -> torch.utils.data.Dataset:
        """Example dataset with tensor conversion plus a 256px center crop."""
        return torchvision.datasets.ImageFolder(
            self.ex_dataset_path,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.CenterCrop(256)
            ])
        )

    @cached_property
    def ex_dataset_path(self) -> str:
        """Local path of the example dataset, downloaded on first access."""
        return self.from_gcp("path to the file in your resources bucket")
# Module-level singleton: importers share one Resources instance, so the
# cached_property clients/credentials are built at most once per process.
project_res = Resources()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment