Stream data from an S3-like bucket (e.g. DigitalOcean Spaces) directly into memory (MIDV-500 dataset streaming example).
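For context, here is a minimal usage sketch of the intended flow, assuming the main script and s3_stream.py (below) sit side by side, that the Spaces credentials are exported as the AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY environment variables, and that the objects under the chosen prefix are tar.gz archives ("your-obj-prefix" is a placeholder):

import tarfile
from s3_stream import S3Stream, ACCESS_KEY_ID, SECRET_ACCESS_KEY, ENDPOINT_URL, BUCKET_NAME

streamer = S3Stream(ACCESS_KEY_ID, SECRET_ACCESS_KEY, ENDPOINT_URL)
for key in streamer.list_keys(BUCKET_NAME, "your-obj-prefix", pbar=False):
    # Each object is pulled into an in-memory io.BytesIO buffer; nothing touches disk
    buf = streamer.stream_object(BUCKET_NAME, key)
    with tarfile.open(fileobj=buf) as tar:
        print(key, len(tar.getmembers()), "members")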
import json
import os
import tarfile
from collections import defaultdict

import cv2
import matplotlib.pyplot as plt
import numpy as np
from rich import print
from tqdm import tqdm

# Credentials and bucket configuration live in s3_stream.py
from s3_stream import (
    S3Stream, ACCESS_KEY_ID, SECRET_ACCESS_KEY,
    ENDPOINT_URL, BUCKET_NAME, DATASET_PREFIX,
)
def midv_500_buffer_read(tar_io_buffer):
    """Reads a midv-500 dataset directory from an
    io buffer holding the byte data of a tar.gz file.

    Args:
        tar_io_buffer (io.BytesIO): the tarfile io buffer

    Yields:
        tuple: one tuple per datapoint: (id, image, quad with mask vertices)
    """
    # Read as tarfile
    tar = tarfile.open(fileobj=tar_io_buffer)
    data = defaultdict(dict)

    # Gather all items in the tar file, pairing each image with its mask
    for item in tar.getmembers():
        try:
            _id = os.path.basename(item.name).split(".")[0]
            if item.name.endswith(".tif"):
                data[_id]["img"] = item
            elif item.name.endswith(".json"):
                data[_id]["mask"] = item
        except Exception as e:
            print(f"[red]Error gathering tar items: {e}[/red]")

    # Yield items from the tarfile
    for _id, dat in data.items():
        # Read the mask (quadrilateral vertices)
        mask = json.load(tar.extractfile(dat["mask"])).get("quad")
        # Read and decode the image, converting BGR (OpenCV default) to RGB
        img_stream = tar.extractfile(dat["img"])
        img_array = np.asarray(bytearray(img_stream.read()), dtype=np.uint8)
        img = cv2.imdecode(img_array, cv2.IMREAD_ANYCOLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        yield (_id, img, mask)
def load_midv_from_s3():
    """Streams items from an S3 bucket, where each item is a tarfile.
    Each tarfile is extracted in memory and each datapoint is iterated through.
    """
    assert ACCESS_KEY_ID, "Missing 'ACCESS_KEY_ID'"
    assert SECRET_ACCESS_KEY, "Missing 'SECRET_ACCESS_KEY'"

    streamer = S3Stream(ACCESS_KEY_ID, SECRET_ACCESS_KEY, ENDPOINT_URL)
    print(streamer.list_buckets())

    for fname in streamer.list_keys(BUCKET_NAME, DATASET_PREFIX, pbar=False):
        print(f":confetti_ball: {fname} :confetti_ball:")
        io_buffer = streamer.stream_object(BUCKET_NAME, fname)
        for (_id, img, mask) in midv_500_buffer_read(io_buffer):
            print(f"[dim]{_id}: {mask} | {img.shape}[/dim]")
            plt.imshow(img)
            plt.show()
def load_midv_local():
    with open("some-local-path-in-here", "rb") as tar_f:
        for (_id, img, mask) in midv_500_buffer_read(tar_f):
            print(f"[dim]{_id}: {mask} | {img.shape}[/dim]")
            plt.imshow(img)
            plt.show()


if __name__ == "__main__":
    load_midv_from_s3()
s3_stream.py (the S3Stream helper module imported above):
import io
import os
from typing import Iterator

import boto3
from tqdm import tqdm

# Read AWS S3 keys from the environment
ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")

# These are to be configured via CLI
ENDPOINT_URL = "https://ams3.digitaloceanspaces.com"
BUCKET_NAME = "your-bucket-name"
DATASET_PREFIX = "your-obj-prefix"  # acts as a prefix filter on object keys
class S3Stream(object):
    """Simple class to stream data from an S3 bucket.

    Attributes:
        client (boto3.client): boto3 client
    """

    def __init__(self, access_key_id, secret_access_key, endpoint_url):
        # Infer the region from the S3 endpoint URL (e.g. "ams3")
        region = endpoint_url.split(".")[0].replace("https://", "")
        session = boto3.session.Session()
        self.client = session.client(
            "s3",
            region_name=region,
            endpoint_url=endpoint_url,
            aws_access_key_id=access_key_id,
            aws_secret_access_key=secret_access_key,
        )

    def list_buckets(self):
        return [b.get("Name") for b in self.client.list_buckets().get("Buckets", [])]
    def stream_object(self, bucket_name: str, obj_key: str) -> io.BytesIO:
        """Downloads a single object into an in-memory buffer."""
        # Stream the file into an IO object
        io_buffer = io.BytesIO()
        self.client.download_fileobj(bucket_name, obj_key, io_buffer)
        io_buffer.seek(0)
        return io_buffer
    def list_keys(
        self,
        bucket_name: str,
        obj_prefix: str,
        start_after: str = "",
        continuation_token: str = "",
        batch_size: int = 1000,
        pbar: bool = True,
    ) -> Iterator[str]:
        """Yields the keys of s3 objects starting with the specified 'obj_prefix'
        from the given bucket, following pagination transparently.

        Args:
            bucket_name (str): S3 bucket name
            obj_prefix (str): object prefix to filter by
            start_after (str, optional): item key from where to start fetching
            continuation_token (str, optional): token used for pagination
            batch_size (int, optional): maximum number of keys per listing call
            pbar (bool, optional): if True show a progress bar
        """
        list_kwargs = dict(
            Bucket=bucket_name,
            Prefix=obj_prefix,
            StartAfter=start_after,
            MaxKeys=batch_size,
        )
        # Only pass a continuation token when paginating
        if continuation_token:
            list_kwargs["ContinuationToken"] = continuation_token

        response = self.client.list_objects_v2(**list_kwargs)
        objects = [c.get("Key") for c in response.get("Contents", [])]

        obj_iter = tqdm(objects) if pbar else objects
        for obj_key in obj_iter:
            if pbar:
                obj_iter.set_description(f"Downloading '{obj_key}'")
            if obj_key.endswith("/"):
                # Skip folders
                continue
            yield obj_key

        # If the listing was truncated, keep yielding from the next page
        if response.get("IsTruncated"):
            yield from self.list_keys(
                bucket_name,
                obj_prefix,
                continuation_token=response.get("NextContinuationToken"),
                batch_size=batch_size,
                pbar=pbar,
            )