Created
April 12, 2023 05:54
-
-
Save jihunchoi/d85ac30ba9f8ab20bf076d84507fc592 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
import os
from typing import List, Optional

import boto3
from hydra.core.object_type import ObjectType
from hydra.plugins.config_source import ConfigResult, ConfigSource
from omegaconf import OmegaConf
from smart_open import open, parse_uri
class S3ConfigSource(ConfigSource):
    """
    Hydra ConfigSource plugin that adds support for loading
    config files from S3.

    Every object key under the source prefix is listed once, at construction
    time, and cached in ``self._paths`` as a key relative to the prefix.
    ``available``, ``is_group``, ``is_config`` and ``list`` answer from that
    cache; ``load_config`` reads the requested object from S3 on demand.
    """

    _log = logging.getLogger(__name__)

    def __init__(self, provider: str, path: str) -> None:
        # Force a trailing "/" so ``full_path()`` is a usable key prefix and
        # relative keys can be appended to it directly.
        if not path.endswith("/"):
            path = path + "/"
        super().__init__(provider=provider, path=path)
        self._paths: List[str] = self._list_s3_directory(s3_uri=self.full_path())

    @staticmethod
    def scheme() -> str:
        return "s3"

    def load_config(self, config_path: str) -> ConfigResult:
        """
        Load one config object from S3.

        Args:
            config_path: Path relative to the source root; the extension is
                optional and is added by ``_normalize_file_name``.
        Returns:
            ConfigResult holding the parsed config and its header dict.
        """
        normalized_config_path = self._normalize_file_name(config_path)
        # full_path() ends with "/" (see __init__); plain concatenation keeps
        # "/" separators on every OS, unlike os.path.join.
        s3_uri = f"{self.full_path()}{normalized_config_path}"
        with open(s3_uri, "r", encoding="utf-8") as f:
            # Only the leading text is inspected for the header dict, then the
            # stream is rewound so OmegaConf parses the whole document.
            header_text = f.read(512)
            header = self._get_header_dict(header_text)
            f.seek(0)
            cfg = OmegaConf.load(f)
        return ConfigResult(
            config=cfg,
            path=f"{self.scheme()}://{self.path}",
            provider=self.provider,
            header=header,
        )

    @staticmethod
    def _list_s3_directory(s3_uri: str) -> List[str]:
        """
        List relative object keys whose prefix is ``s3_uri``.

        Args:
            s3_uri: S3 URI starting with s3://.
        Returns:
            Relative object keys; suffixes after ``s3_uri``. Empty when no
            object exists under the prefix.
        """
        if not s3_uri.endswith("/"):
            s3_uri = s3_uri + "/"
        parsed = parse_uri(s3_uri)
        prefix = parsed.key_id
        paginator = boto3.client("s3").get_paginator("list_objects_v2")
        paths: List[str] = []
        for page in paginator.paginate(Bucket=parsed.bucket_id, Prefix=prefix):
            # A page that matched no objects has no "Contents" entry at all;
            # indexing it directly raised KeyError for empty prefixes.
            for obj in page.get("Contents", []):
                paths.append(obj["Key"][len(prefix):])
        return paths

    def available(self) -> bool:
        """Return False when no object exists under the configured S3 path."""
        return bool(self._paths)

    def is_group(self, config_path: str) -> bool:
        """
        Return True if ``config_path`` is a config group ("directory").

        S3 has no real directories: a group exists whenever at least one
        cached key lies under ``<config_path>/`` -- either an explicit
        ``<config_path>/`` marker object or any object nested below it.
        """
        group = config_path.rstrip("/")
        if group == "":
            # The source root is always treated as a group.
            self._log.debug("is_group: %s", config_path)
            return True
        prefix = f"{group}/"
        found = any(path.startswith(prefix) for path in self._paths)
        self._log.debug("is_group(%s) -> %s", config_path, found)
        return found

    def is_config(self, config_path: str) -> bool:
        """
        Return True if ``config_path`` names an existing config object.

        The path is normalized exactly as in ``load_config`` so that callers
        may omit the file extension.
        """
        key = self._normalize_file_name(config_path)
        found = not key.endswith("/") and key in self._paths
        self._log.debug("is_config(%s) -> %s", config_path, found)
        return found

    def list(self, config_path: str, results_filter: Optional[ObjectType]) -> List[str]:
        """
        List the immediate children (groups and configs) of ``config_path``.

        Args:
            config_path: Group path relative to the source root ("" = root).
            results_filter: Restrict to groups or configs; None keeps both.
        Returns:
            Sorted, de-duplicated child names (configs without extension).
        """
        group = config_path.rstrip("/")
        prefix = f"{group}/" if group else ""
        children = set()
        for path in self._paths:
            if not path.startswith(prefix):
                continue
            # Keep only the first path component below the group: a nested
            # key "db/mysql.yaml" under the root contributes the child "db".
            child = path[len(prefix):].split("/", 1)[0]
            if child:  # skip the group's own "<group>/" marker object
                children.add(child)

        files: List[str] = []
        for child in sorted(children):
            # is_group/is_config work on source-relative paths, so pass a
            # relative path here rather than the full s3:// URI.
            self._list_add_result(
                files=files,
                file_path=f"{prefix}{child}",
                file_name=child,
                results_filter=results_filter,
            )
        return sorted(set(files))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from hydra.core.config_search_path import ConfigSearchPath | |
from hydra.plugins.search_path_plugin import SearchPathPlugin | |
class S3SearchPathPlugin(SearchPathPlugin):
    """
    Rewrite S3 entries in Hydra's config search path to canonical form.

    Any search-path element whose path contains ``s3:/`` is rewritten to
    ``s3://<rest>``, where ``<rest>`` is everything after that marker with
    leading slashes removed. This repairs paths in which the ``//`` after
    ``s3:`` was collapsed or something was prefixed to the URI (presumably
    the case this plugin targets -- confirm against how the search path is
    populated), and leaves an already well-formed ``s3://bucket/...`` intact.
    """

    _MARKER = "s3:/"

    def manipulate_search_path(self, search_path: ConfigSearchPath) -> None:
        for element in search_path.get_path():
            idx = element.path.find(self._MARKER)
            if idx == -1:
                continue
            # Strip the slashes that remain after the marker; without this an
            # input of "s3://bucket/x" became the broken "s3:///bucket/x".
            rest = element.path[idx + len(self._MARKER):].lstrip("/")
            element.path = f"s3://{rest}"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment