# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: LicenseRef-.amazon.com.-AmznSL-1.0
# Licensed under the Amazon Software License http://aws.amazon.com/asl/
"""S3 Client

A helper class which wraps boto3's s3 service
"""
__author__ = "Liutong Zhou"
__version__ = "v2.6"
try:
    import boto3
except ImportError:  # a failed import raises ImportError (ModuleNotFoundError is its subclass)
    r = input("boto3 is not found. Install boto3 from conda? (Y/N)")
    if r.lower() == "y":
        import os

        os.system("conda install boto3 -y")
        import boto3
    else:
        import sys

        sys.exit()
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from multiprocessing import cpu_count
from pathlib import Path, PurePosixPath
from typing import Generator, List, Optional, Set, Tuple
from urllib.parse import urlparse

import botocore
from tqdm import tqdm


def s3_path_join(*args) -> str:
    """Return the arguments joined by a slash ("/"), similarly to ``os.path.join()`` (on Unix).

    If the first argument starts with "s3://", the scheme is preserved.

    Parameters
    ----------
    *args : str
        The strings to join with a slash.

    Returns
    -------
    s3_uri : str
        The joined s3 uri.
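
    Examples
    --------
    Illustrative doctests; the bucket and key names are made up:

    >>> s3_path_join("s3://my-bucket", "folder", "file.txt")
    's3://my-bucket/folder/file.txt'
    >>> s3_path_join("my-bucket", "folder/file.txt")
    'my-bucket/folder/file.txt'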
""" | |
if args[0].startswith("s3://"): | |
path = str(PurePosixPath(*args[1:])).lstrip("/") | |
return str(PurePosixPath(args[0], path)).replace("s3:/", "s3://") | |
return str(PurePosixPath(*args)).lstrip("/") | |


def get_bucket_name_key(s3_uri: str) -> Tuple[str, str]:
    """Parse the s3_uri and return the bucket name and object key

    Parameters
    ----------
    s3_uri : str

    Returns
    -------
    bucket_name : str
    key : str
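
    Examples
    --------
    Illustrative doctest; the uri is made up:

    >>> get_bucket_name_key("s3://my-bucket/folder/file.txt")
    ('my-bucket', 'folder/file.txt')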
""" | |
s3_uri_parsed = urlparse(s3_uri) | |
bucket_name = s3_uri_parsed.netloc | |
key = s3_uri.partition(f"s3://{bucket_name}")[-1].lstrip("/") | |
return bucket_name, key | |


class S3Client:
    """S3 client

    Parameters
    ----------
    endpoint_url : str, optional
    access_key_id : str, optional
    secret_access_key : str, optional
    kwargs : dict
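
    Examples
    --------
    Illustrative construction; credentials fall back to boto3's default
    credential chain when omitted, and the endpoint url is a placeholder:

    >>> client = S3Client()
    >>> client = S3Client(endpoint_url="https://s3.us-east-1.amazonaws.com")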
""" | |
def __init__( | |
self, endpoint_url=None, access_key_id=None, secret_access_key=None, **kwargs | |
): | |
client = boto3.client( | |
service_name="s3", | |
endpoint_url=endpoint_url, | |
aws_access_key_id=access_key_id, | |
aws_secret_access_key=secret_access_key, | |
config=botocore.client.Config(max_pool_connections=2 * cpu_count()), | |
**kwargs, | |
) | |
self.__client = client | |
    def create_bucket(self, bucket_name: str):
        """Create a bucket"""
        self.__client.create_bucket(Bucket=bucket_name)

    def create_folder(self, bucket_name: str, folder: str):
        """Create a folder under a bucket. Do nothing if the folder already exists"""
        folder = str(PurePosixPath(folder)) + "/"
        self.__client.put_object(Bucket=bucket_name, Key=folder)

    def list_buckets(self) -> List[str]:
        """Return a list of bucket names"""
        response = self.__client.list_buckets()
        return [bucket["Name"] for bucket in response["Buckets"]]
    def list_objects(self, bucket_name: str, folder: Optional[str] = None) -> List[str]:
        """Return a list of object keys that start with the key of the folder"""
        objects_list = []
        paginator_list_objects_v2 = self.__client.get_paginator("list_objects_v2")
        if folder:
            response_iter = paginator_list_objects_v2.paginate(
                Bucket=bucket_name, Prefix=folder
            )
        else:
            response_iter = paginator_list_objects_v2.paginate(Bucket=bucket_name)
        for response in response_iter:
            contents = response.get("Contents", [])
            if contents:
                objects_list.extend(map(lambda content: content["Key"], contents))
        return objects_list
    def yield_object_keys(
        self,
        bucket_name: str,
        folder: Optional[str] = None,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
    ) -> Generator[str, None, None]:
        """Yield object keys that start with the key of the folder

        Parameters
        ----------
        bucket_name : str
            S3 bucket name
        folder : str, optional
            key to the folder
        prefix : str, optional
            If not None, only yield object keys that start with the specified prefix.
            Default None
        suffix : str, optional
            If not None, only yield object keys that end with the specified suffix.
            Default None

        Yields
        ------
        key : str
            object key
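
        Examples
        --------
        Illustrative usage; the bucket and folder names are made up:

        >>> client = S3Client()
        >>> for key in client.yield_object_keys("my_bucket", "logs", suffix=".json"):
        ...     print(key)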
""" | |
paginator_list_objects_v2 = self.__client.get_paginator("list_objects_v2") | |
if folder: | |
response_iter = paginator_list_objects_v2.paginate( | |
Bucket=bucket_name, Prefix=folder | |
) | |
else: | |
response_iter = paginator_list_objects_v2.paginate(Bucket=bucket_name) | |
for response in response_iter: | |
contents = response.get("Contents", []) | |
for content in contents: | |
key = content["Key"] | |
if prefix and not key.startswith(prefix): | |
continue | |
if suffix and not key.endswith(suffix): | |
continue | |
yield key | |
    def yield_batch_uris(
        self,
        source_folder: str,
        batch_size: int = 1024,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
        exclude_stems: Set[str] = frozenset(),
    ) -> Generator[List[str], None, None]:
        """Yield lists of s3 uris in batches of at most ``batch_size``

        Parameters
        ----------
        source_folder : str
            an s3 uri to a folder
        batch_size : int
        prefix : str, optional
            If not None, only include s3 uris whose object keys start with the
            specified prefix. Default None
        suffix : str, optional
            If not None, only include s3 uris whose object keys end with the
            specified suffix. Default None
        exclude_stems : Set[str]
            If not empty, s3 uris whose file stems are in the exclude_stems set
            won't appear in the yielded lists

        Yields
        ------
        s3_uris : List[str]
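
        Examples
        --------
        Illustrative usage; the bucket and folder names are made up:

        >>> client = S3Client()
        >>> for batch in client.yield_batch_uris("s3://my_bucket/images", batch_size=256):
        ...     print(len(batch))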
""" | |
assert source_folder.startswith( | |
"s3://" | |
), f"{source_folder} is not a valid s3 uri" | |
bucket_name, folder_key = get_bucket_name_key(source_folder) | |
key_gen = self.yield_object_keys( | |
bucket_name, folder_key, prefix=prefix, suffix=suffix | |
) | |
batch_uris = [] | |
for key in key_gen: | |
s3_uri = f"s3://{bucket_name}/" + key | |
if len(batch_uris) < batch_size: | |
if Path(s3_uri).stem not in exclude_stems: | |
batch_uris.append(s3_uri) | |
else: | |
yield batch_uris | |
batch_uris = [s3_uri] | |
if batch_uris: | |
yield batch_uris | |
    def upload_file(self, file_name: str, bucket_name: str, destination: str, **kwargs):
        """Upload a local file to an S3 bucket

        Parameters
        ----------
        file_name : str
        bucket_name : str
        destination : str
            the relative file path in the bucket
        kwargs : dict

        Examples
        --------
        >>> client = S3Client()
        >>> client.upload_file("./hello.txt", "my_bucket", "test/hello.txt")
        """
        self.__client.upload_file(file_name, bucket_name, destination, **kwargs)
    def upload_folder(
        self, folder: str, bucket_name: str, destination_folder: str, **kwargs
    ):
        """Upload all items under a local folder to the destination folder in an s3 bucket

        Parameters
        ----------
        folder : str
        bucket_name : str
        destination_folder : str
        kwargs : dict

        Examples
        --------
        >>> client = S3Client()
        >>> client.upload_folder("./local_folder", "my_bucket", "documents")
        """
        object_paths = list(Path(folder).rglob("*"))

        def _upload_object(object_path: Path):
            """Upload a single local object to the destination"""
            relative_path = object_path.relative_to(folder)
            # use posix separators so paths from Windows map to valid s3 keys
            key = str(PurePosixPath(destination_folder, relative_path.as_posix()))
            if object_path.is_dir():
                self.create_folder(bucket_name, key)
            else:
                self.upload_file(str(object_path), bucket_name, key, **kwargs)

        n_objects = len(object_paths)
        with ThreadPoolExecutor(max_workers=2 * cpu_count()) as executor:
            jobs = executor.map(_upload_object, object_paths)
            list(tqdm(jobs, desc="Uploading...", total=n_objects))
    def download_file(self, bucket_name: str, object_name: str, local_file_name: str):
        """Download a file from an S3 bucket to local disk

        Examples
        --------
        >>> client = S3Client()
        >>> client.download_file("my_bucket", "remote_folder/file.txt", "local_folder/file.txt")
        """
        self.__client.download_file(bucket_name, object_name, local_file_name)
    def download_s3_uris(self, s3_uris: List[str], to_folder="download"):
        """Download a list of s3 objects to the target folder

        Parameters
        ----------
        s3_uris : List[str]
            full s3 uris, e.g. s3://bucket_name/object_key
        to_folder : str
            folder under which objects will be downloaded
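
        Examples
        --------
        Illustrative usage; the uris are made up:

        >>> client = S3Client()
        >>> client.download_s3_uris(
        ...     ["s3://my_bucket/a.txt", "s3://my_bucket/b.txt"], to_folder="data"
        ... )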
""" | |
local_path = Path(to_folder) | |
local_path.mkdir(exist_ok=True, parents=True) | |
def download_uri(s3_uri: str): | |
assert s3_uri.lower().startswith("s3"), f"{s3_uri} is not a valid s3 uri" | |
bucket_name, key = get_bucket_name_key(s3_uri) | |
self.download_file(bucket_name, key, str(local_path / Path(s3_uri).name)) | |
with ThreadPoolExecutor(max_workers=2 * cpu_count()) as executor: | |
jobs = executor.map(download_uri, s3_uris) | |
list( | |
tqdm( | |
jobs, | |
desc="Downloading files...", | |
total=len(s3_uris), | |
mininterval=5, | |
maxinterval=15, | |
) | |
) | |
    def download_folder(self, bucket_name: str, remote_folder: str, destination: str):
        """Download a remote folder in an S3 bucket to a local folder
        # list all object keys that start with remote_folder
        object_keys = self.list_objects(bucket_name, remote_folder)

        def download_object(key: str):
            """Download a single object"""
            relative_path = PurePosixPath(key).relative_to(remote_folder)
            local_path = Path(destination, relative_path)
            if key.endswith("/"):  # the object is a directory marker
                local_path.mkdir(exist_ok=True, parents=True)
            else:  # the object is a file
                local_path.parent.mkdir(exist_ok=True, parents=True)
                self.download_file(bucket_name, key, str(local_path))

        with ThreadPoolExecutor(max_workers=2 * cpu_count()) as executor:
            jobs = executor.map(download_object, object_keys)
            n_jobs = len(object_keys)
            list(
                tqdm(
                    jobs,
                    desc="Downloading...",
                    total=n_jobs,
                    mininterval=5,
                    maxinterval=15,
                )
            )
    def generate_url(self, bucket: str, key: str, expire: int = 3600):
        """Generate a download url for sharing a single file

        Parameters
        ----------
        bucket : str
            bucket name
        key : str
            key (path) to the file object
        expire : int, optional, default 3600
            Seconds after which the generated url will expire

        Returns
        -------
        url : str
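
        Examples
        --------
        Illustrative usage; the bucket and key are made up, and the shape of
        the returned url depends on the endpoint:

        >>> client = S3Client()
        >>> url = client.generate_url("my_bucket", "folder/file.txt", expire=600)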
""" | |
url = self.__client.generate_presigned_url( | |
"get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=expire | |
) | |
return url | |
    def generate_urls(
        self, bucket: str, folder: Optional[str] = None, expire: int = 3600
    ) -> List[str]:
        """Recursively generate urls for each file under the folder

        Parameters
        ----------
        bucket : str
            bucket name
        folder : str, optional, default None
            path to the folder. If None, generate urls for all objects under the bucket
        expire : int, optional, default 3600
            Seconds after which the generated urls will expire

        Returns
        -------
        urls : List[str]
        """
        paths = self.list_objects(bucket, folder)
        return list(
            map(
                lambda path: self.generate_url(bucket, path, expire),
                filter(lambda path: not path.endswith("/"), paths),
            )
        )
    def stream_objects(self, s3_uris: List[str]) -> Generator:
        """Yield readable file objects, given a list of full S3 uris.

        Parameters
        ----------
        s3_uris : List[str]
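
        Yields
        ------
        body : botocore.response.StreamingBody
            a readable file-like object

        Examples
        --------
        Illustrative usage; the uri is made up:

        >>> client = S3Client()
        >>> for body in client.stream_objects(["s3://my_bucket/a.txt"]):
        ...     data = body.read()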
""" | |
get_object = self.__client.get_object | |
for s3_uri in s3_uris: | |
assert s3_uri.lower().startswith("s3"), f"{s3_uri} is not a valid s3 uri" | |
bucket_name, key = get_bucket_name_key(s3_uri) | |
response = get_object(Bucket=bucket_name, Key=key) | |
yield response["Body"] | |
    def copy(self, source: List[str], destination: List[str]):
        """Copy objects from source s3 uris to the corresponding destination s3 uris

        Parameters
        ----------
        source : List[str]
            a list of s3 uris
        destination : List[str]
            a list of s3 uris, matched pairwise with ``source``
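
        Examples
        --------
        Illustrative usage; the uris are made up:

        >>> client = S3Client()
        >>> client.copy(["s3://src_bucket/a.txt"], ["s3://dst_bucket/a.txt"])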
""" | |
_copy = self.__client.copy | |
def __copy(src_bucket_key_pair, dest_bucket_key_pair): | |
src_bucket, src_key = src_bucket_key_pair | |
dest_bucket, dest_key = dest_bucket_key_pair | |
_copy({"Bucket": src_bucket, "Key": src_key}, dest_bucket, dest_key) | |
with ThreadPoolExecutor(max_workers=2 * cpu_count()) as executor: | |
jobs = executor.map( | |
__copy, | |
map(get_bucket_name_key, source), | |
map(get_bucket_name_key, destination), | |
) | |
n_jobs = len(source) | |
list( | |
tqdm( | |
jobs, | |
desc="Copying...", | |
total=n_jobs, | |
mininterval=5, | |
maxinterval=15, | |
) | |
) | |
    def delete(self, bucket: str, keys: List[str], **kwargs):
        """Delete a list of objects from a bucket
        max_objects_per_request = 1000
        delete_objects = partial(self.__client.delete_objects, Bucket=bucket, **kwargs)

        def batch_key_gen():
            key_gen = iter(keys)
            batch = [
                {"Key": key} for key, _ in zip(key_gen, range(max_objects_per_request))
            ]
            while batch:
                yield batch
                batch = [
                    {"Key": key}
                    for key, _ in zip(key_gen, range(max_objects_per_request))
                ]

        with ThreadPoolExecutor(max_workers=2 * cpu_count()) as pool:
            for batch_keys in batch_key_gen():
                pool.submit(delete_objects, Delete={"Objects": batch_keys})