S3Client
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: LicenseRef-.amazon.com.-AmznSL-1.0
# Licensed under the Amazon Software License http://aws.amazon.com/asl/
"""S3 Client

A helper class which wraps boto3's s3 service
"""
__author__ = "Liutong Zhou"
__version__ = "v2.6"

try:
    import boto3
except (ModuleNotFoundError, NameError):
    r = input("boto3 is not found. Install boto3 from conda? (Y/N)")
    if r.lower() == "y":
        import os
        os.system("conda install boto3 -y")
        import boto3
    else:
        import sys
        sys.exit()

from concurrent.futures import ThreadPoolExecutor
from functools import partial
from multiprocessing import cpu_count
from pathlib import Path, PurePosixPath
from typing import Generator, List, Optional, Set, Tuple
from urllib.parse import urlparse

import botocore
from tqdm import tqdm


def s3_path_join(*args) -> str:
    """Returns the arguments joined by a slash ("/"), similarly to ``os.path.join()`` (on Unix).

    If the first argument is "s3://", then that is preserved.

    Parameters
    ----------
    *args : List[str]
        The strings to join with a slash.

    Returns
    -------
    s3_uri : str
        The joined s3 uri.
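
    Examples
    --------
    Illustrative joins (path components are placeholders):

    >>> s3_path_join("s3://", "my_bucket", "folder", "file.txt")
    's3://my_bucket/folder/file.txt'
    >>> s3_path_join("folder", "subfolder", "file.txt")
    'folder/subfolder/file.txt'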
"""
if args[0].startswith("s3://"):
path = str(PurePosixPath(*args[1:])).lstrip("/")
return str(PurePosixPath(args[0], path)).replace("s3:/", "s3://")
return str(PurePosixPath(*args)).lstrip("/")


def get_bucket_name_key(s3_uri: str) -> Tuple[str, str]:
    """Parse the s3_uri and return the bucket name and object key

    Parameters
    ----------
    s3_uri : str
        full s3 uri, e.g. s3://bucket_name/object_key

    Returns
    -------
    bucket_name : str
    key : str
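
    Examples
    --------
    Illustrative parse (bucket and key names are placeholders):

    >>> get_bucket_name_key("s3://my_bucket/folder/file.txt")
    ('my_bucket', 'folder/file.txt')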
"""
s3_uri_parsed = urlparse(s3_uri)
bucket_name = s3_uri_parsed.netloc
key = s3_uri.partition(f"s3://{bucket_name}")[-1].lstrip("/")
return bucket_name, key


class S3Client:
    """S3 client

    Parameters
    ----------
    endpoint_url : str, optional
    access_key_id : str, optional
    secret_access_key : str, optional
    kwargs : dict
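
    Examples
    --------
    >>> client = S3Client()  # credentials resolved by boto3's default credential chain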
"""
def __init__(
self, endpoint_url=None, access_key_id=None, secret_access_key=None, **kwargs
):
client = boto3.client(
service_name="s3",
endpoint_url=endpoint_url,
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key,
config=botocore.client.Config(max_pool_connections=2 * cpu_count()),
**kwargs,
)
self.__client = client
def create_bucket(self, bucket_name: str):
"""Create a bucket"""
self.__client.create_bucket(Bucket=bucket_name)
def create_folder(self, bucket_name: str, folder: str):
"""Create a folder under bucket. Do nothing if folder already exists"""
folder = str(PurePosixPath(folder)) + "/"
self.__client.put_object(Bucket=bucket_name, Key=folder)
def list_buckets(self) -> List[str]:
"""Return a list of bucket names"""
response = self.__client.list_buckets()
return [bucket["Name"] for bucket in response["Buckets"]]

    def list_objects(self, bucket_name: str, folder: Optional[str] = None) -> List[str]:
        """Return a list of object keys that start with the key of the folder"""
        objects_list = []
        paginator_list_objects_v2 = self.__client.get_paginator("list_objects_v2")
        if folder:
            response_iter = paginator_list_objects_v2.paginate(
                Bucket=bucket_name, Prefix=folder
            )
        else:
            response_iter = paginator_list_objects_v2.paginate(Bucket=bucket_name)
        for response in response_iter:
            contents = response.get("Contents", [])
            if contents:
                objects_list.extend(map(lambda content: content["Key"], contents))
        return objects_list

    def yield_object_keys(
        self,
        bucket_name: str,
        folder: Optional[str] = None,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
    ) -> Generator[str, None, None]:
        """Yield object keys that start with the key of the folder

        Parameters
        ----------
        bucket_name : str
            S3 bucket name
        folder : str, optional
            key to the folder
        prefix : str, optional
            If not None, only yield object keys that start with the specified prefix.
            Default None
        suffix : str, optional
            If not None, only yield object keys that end with the specified suffix.
            Default None

        Yields
        ------
        key : str
            object key
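
        Examples
        --------
        Illustrative usage with placeholder bucket/folder names:

        >>> client = S3Client()
        >>> for key in client.yield_object_keys("my_bucket", "documents", suffix=".txt"):
        ...     print(key)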
"""
paginator_list_objects_v2 = self.__client.get_paginator("list_objects_v2")
if folder:
response_iter = paginator_list_objects_v2.paginate(
Bucket=bucket_name, Prefix=folder
)
else:
response_iter = paginator_list_objects_v2.paginate(Bucket=bucket_name)
for response in response_iter:
contents = response.get("Contents", [])
for content in contents:
key = content["Key"]
if prefix and not key.startswith(prefix):
continue
if suffix and not key.endswith(suffix):
continue
yield key

    def yield_batch_uris(
        self,
        source_folder: str,
        batch_size: int = 1024,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
        exclude_stems: Set[str] = frozenset(),
    ) -> Generator[List[str], None, None]:
        """Yield batches of s3 uris

        Parameters
        ----------
        source_folder : str
            an s3 uri to a folder
        batch_size : int
            maximum number of s3 uris per yielded batch
        prefix : str, optional
            If not None, only yield s3 uris whose object keys start with the specified prefix.
            Default None
        suffix : str, optional
            If not None, only yield s3 uris whose object keys end with the specified suffix.
            Default None
        exclude_stems : Set[str]
            If not empty, an s3 uri whose file_path.stem is in the exclude_stems set won't
            appear in the yielded batches

        Yields
        ------
        s3_uris : List[str]
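
        Examples
        --------
        Illustrative usage with placeholder names:

        >>> client = S3Client()
        >>> for batch in client.yield_batch_uris("s3://my_bucket/documents", batch_size=256):
        ...     print(len(batch))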
"""
assert source_folder.startswith(
"s3://"
), f"{source_folder} is not a valid s3 uri"
bucket_name, folder_key = get_bucket_name_key(source_folder)
key_gen = self.yield_object_keys(
bucket_name, folder_key, prefix=prefix, suffix=suffix
)
batch_uris = []
for key in key_gen:
s3_uri = f"s3://{bucket_name}/" + key
if len(batch_uris) < batch_size:
if Path(s3_uri).stem not in exclude_stems:
batch_uris.append(s3_uri)
else:
yield batch_uris
batch_uris = [s3_uri]
if batch_uris:
yield batch_uris

    def upload_file(self, file_name: str, bucket_name: str, destination: str, **kwargs):
        """Upload a local file to S3 bucket

        Parameters
        ----------
        file_name : str
        bucket_name : str
        destination : str
            the relative file path in the bucket
        kwargs : dict

        Examples
        --------
        >>> client = S3Client()
        >>> client.upload_file("./hello.txt", "my_bucket", "test/hello.txt")
        """
        self.__client.upload_file(file_name, bucket_name, destination, **kwargs)

    def upload_folder(
        self, folder: str, bucket_name: str, destination_folder: str, **kwargs
    ):
        """Upload all items under local folder to the destination folder in s3 bucket

        Parameters
        ----------
        folder : str
        bucket_name : str
        destination_folder : str
        kwargs : dict

        Examples
        --------
        >>> client = S3Client()
        >>> client.upload_folder("./local_folder", "my_bucket", "documents")
        """
        object_paths = list(Path(folder).rglob("*"))

        def _upload_object(object_path: Path):
            """Upload a single local object to destination"""
            relative_path = object_path.relative_to(folder)
            key = str(PurePosixPath(destination_folder, str(relative_path)))
            if object_path.is_dir():
                self.create_folder(bucket_name, key)
            else:
                self.upload_file(str(object_path), bucket_name, key, **kwargs)

        n_objects = len(object_paths)
        with ThreadPoolExecutor(max_workers=2 * cpu_count()) as executor:
            jobs = executor.map(_upload_object, object_paths)
            list(tqdm(jobs, desc="Uploading...", total=n_objects))

    def download_file(self, bucket_name: str, object_name: str, local_file_name: str):
        """Download a file from S3 bucket to local disk

        Examples
        --------
        >>> client = S3Client()
        >>> client.download_file("my_bucket", "remote_folder/file.txt", "local_folder/file.txt")
        """
        self.__client.download_file(bucket_name, object_name, local_file_name)

    def download_s3_uris(self, s3_uris: List[str], to_folder="download"):
        """Download a list of s3 objects to the target folder

        Parameters
        ----------
        s3_uris : List[str]
            full s3 uris, e.g. s3://bucket_name/object_key
        to_folder : str
            folder under which objects will be downloaded
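
        Examples
        --------
        Illustrative call with placeholder uris:

        >>> client = S3Client()
        >>> client.download_s3_uris(["s3://my_bucket/a.txt", "s3://my_bucket/b.txt"], to_folder="download")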
"""
local_path = Path(to_folder)
local_path.mkdir(exist_ok=True, parents=True)
def download_uri(s3_uri: str):
assert s3_uri.lower().startswith("s3"), f"{s3_uri} is not a valid s3 uri"
bucket_name, key = get_bucket_name_key(s3_uri)
self.download_file(bucket_name, key, str(local_path / Path(s3_uri).name))
with ThreadPoolExecutor(max_workers=2 * cpu_count()) as executor:
jobs = executor.map(download_uri, s3_uris)
list(
tqdm(
jobs,
desc="Downloading files...",
total=len(s3_uris),
mininterval=5,
maxinterval=15,
)
)

    def download_folder(self, bucket_name: str, remote_folder: str, destination: str):
        """Download a remote folder in S3 bucket to a local folder"""
        # list all object keys that start with remote_folder
        object_keys = self.list_objects(bucket_name, remote_folder)

        def download_object(key: str):
            """Download a single object"""
            relative_path = PurePosixPath(key).relative_to(remote_folder)
            local_path = Path(destination, relative_path)
            if key.endswith("/"):  # if object is a directory
                local_path.mkdir(exist_ok=True, parents=True)
            else:  # elif object is a file
                local_path.parent.mkdir(exist_ok=True, parents=True)
                self.download_file(bucket_name, key, str(local_path))

        with ThreadPoolExecutor(max_workers=2 * cpu_count()) as executor:
            jobs = executor.map(download_object, object_keys)
            n_jobs = len(object_keys)
            list(
                tqdm(
                    jobs,
                    desc="Downloading...",
                    total=n_jobs,
                    mininterval=5,
                    maxinterval=15,
                )
            )

    def generate_url(self, bucket: str, key: str, expire: int = 3600):
        """Generate a download url for sharing a single file

        Parameters
        ----------
        bucket : str
            bucket name
        key : str
            key (path) to the file object
        expire : int, default : 3600
            Seconds after which the generated url will expire

        Returns
        -------
        url : str
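
        Examples
        --------
        Illustrative call with placeholder names; the result is a presigned url string:

        >>> client = S3Client()
        >>> url = client.generate_url("my_bucket", "folder/file.txt", expire=600)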
"""
url = self.__client.generate_presigned_url(
"get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=expire
)
return url

    def generate_urls(
        self, bucket: str, folder: Optional[str] = None, expire: int = 3600
    ) -> List[str]:
        """Recursively generate urls for each file under the folder

        Parameters
        ----------
        bucket : str
            bucket name
        folder : str, optional, default : None
            path to folder. If None, generate urls for all objects under the bucket
        expire : int, optional, default : 3600
            Seconds after which the generated url will expire

        Returns
        -------
        urls : List[str]
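
        Examples
        --------
        Illustrative call with placeholder names:

        >>> client = S3Client()
        >>> urls = client.generate_urls("my_bucket", folder="documents", expire=3600)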
"""
paths = self.list_objects(bucket, folder)
return list(
map(
lambda path: self.generate_url(bucket, path, expire),
filter(lambda path: not path.endswith("/"), paths),
)
)

    def stream_objects(self, s3_uris: List[str]) -> Generator:
        """Yields file objects that can be read given a list of full S3 uris.

        Parameters
        ----------
        s3_uris : List[str]
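
        Examples
        --------
        Illustrative streaming read with a placeholder uri:

        >>> client = S3Client()
        >>> for body in client.stream_objects(["s3://my_bucket/a.txt"]):
        ...     data = body.read()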
"""
get_object = self.__client.get_object
for s3_uri in s3_uris:
assert s3_uri.lower().startswith("s3"), f"{s3_uri} is not a valid s3 uri"
bucket_name, key = get_bucket_name_key(s3_uri)
response = get_object(Bucket=bucket_name, Key=key)
yield response["Body"]

    def copy(self, source: List[str], destination: List[str]):
        """Copy objects from source bucket to destination bucket

        Parameters
        ----------
        source : List[str]
            a list of s3 uris
        destination : List[str]
            a list of s3 uris
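
        Examples
        --------
        Illustrative copy of one object between placeholder buckets:

        >>> client = S3Client()
        >>> client.copy(["s3://src_bucket/file.txt"], ["s3://dest_bucket/file.txt"])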
"""
_copy = self.__client.copy
def __copy(src_bucket_key_pair, dest_bucket_key_pair):
src_bucket, src_key = src_bucket_key_pair
dest_bucket, dest_key = dest_bucket_key_pair
_copy({"Bucket": src_bucket, "Key": src_key}, dest_bucket, dest_key)
with ThreadPoolExecutor(max_workers=2 * cpu_count()) as executor:
jobs = executor.map(
__copy,
map(get_bucket_name_key, source),
map(get_bucket_name_key, destination),
)
n_jobs = len(source)
list(
tqdm(
jobs,
desc="Copying...",
total=n_jobs,
mininterval=5,
maxinterval=15,
)
)

    def delete(self, bucket: str, keys: List[str], **kwargs):
        """Delete a list of objects from bucket"""
        max_objects_per_request = 1000
        delete_objects = partial(self.__client.delete_objects, Bucket=bucket, **kwargs)

        def batch_key_gen():
            key_gen = iter(keys)
            batch = [
                {"Key": key} for key, _ in zip(key_gen, range(max_objects_per_request))
            ]
            while batch:
                yield batch
                batch = [
                    {"Key": key}
                    for key, _ in zip(key_gen, range(max_objects_per_request))
                ]

        with ThreadPoolExecutor(max_workers=2 * cpu_count()) as pool:
            for batch_keys in batch_key_gen():
                pool.submit(delete_objects, Delete={"Objects": batch_keys})
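

if __name__ == "__main__":
    # Minimal usage sketch. Bucket, folder, and file names below are
    # placeholders; credentials are assumed to come from boto3's default
    # credential chain.
    client = S3Client()
    print(client.list_buckets())
    client.upload_folder("./local_folder", "my_bucket", "documents")
    for url in client.generate_urls("my_bucket", folder="documents", expire=600):
        print(url)
    client.download_folder("my_bucket", "documents", "downloaded_documents")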