@josemarcosrf
Last active October 19, 2021 20:04
Stream data from an S3-like bucket (e.g. DigitalOcean Spaces) directly into memory (midv-500 dataset streaming example).
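Usage: export AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY in the environment, set ENDPOINT_URL, BUCKET_NAME and DATASET_PREFIX in s3_stream.py (the second file below), then run the first script.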
import json
import os
import tarfile
from collections import defaultdict

import cv2
import matplotlib.pyplot as plt
import numpy as np
from rich import print

# The config constants live in s3_stream.py alongside S3Stream
from s3_stream import (
    ACCESS_KEY_ID,
    BUCKET_NAME,
    DATASET_PREFIX,
    ENDPOINT_URL,
    SECRET_ACCESS_KEY,
    S3Stream,
)
def midv_500_buffer_read(tar_io_buffer):
    """Reads a midv-500 dataset directory from an
    io buffer with the byte data of a tar.gz file

    Args:
        tar_io_buffer (io.BytesIO): the tarfile io buffer

    Yields:
        tuple: A tuple for each datapoint: (id, image, quad with mask vertices)
    """
    # Read as tarfile
    tar = tarfile.open(fileobj=tar_io_buffer)
    data = defaultdict(dict)

    # Gather all items in the tar file, pairing each image with its mask
    for item in tar.getmembers():
        try:
            _id = os.path.basename(item.name).split(".")[0]
            if item.name.endswith(".tif"):
                data[_id]["img"] = item
            elif item.name.endswith(".json"):
                data[_id]["mask"] = item
        except Exception as e:
            print(f"[red]Error gathering tar items: {e}[/red]")

    # Yield items from the tarfile
    for _id, dat in data.items():
        # Read the mask (the ground-truth quad vertices)
        mask = json.load(tar.extractfile(dat["mask"])).get("quad")
        # Read and decode the image, converting OpenCV's BGR to RGB
        img_stream = tar.extractfile(dat["img"])
        img_array = np.asarray(bytearray(img_stream.read()), dtype=np.uint8)
        img = cv2.imdecode(img_array, cv2.IMREAD_ANYCOLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        yield (_id, img, mask)
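# Minimal sketch: assuming the yielded `mask` is a list of four [x, y]
# vertices (as in the MIDV-500 ground-truth JSON), it can be overlaid
# on the image like so:
def draw_quad(img, quad, color=(255, 0, 0), thickness=3):
    """Hypothetical helper: draws the quad as a closed polygon on a copy of `img`."""
    pts = np.array(quad, dtype=np.int32).reshape((-1, 1, 2))
    return cv2.polylines(
        img.copy(), [pts], isClosed=True, color=color, thickness=thickness
    )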
def load_midv_from_s3():
    """Streams items from an S3 bucket, where each item is a tarfile.
    Each tarfile is extracted in memory and each datapoint is iterated through.
    """
    assert ACCESS_KEY_ID, "Missing 'ACCESS_KEY_ID'"
    assert SECRET_ACCESS_KEY, "Missing 'SECRET_ACCESS_KEY'"

    streamer = S3Stream(ACCESS_KEY_ID, SECRET_ACCESS_KEY, ENDPOINT_URL)
    print(streamer.list_buckets())

    for fname in streamer.list_keys(BUCKET_NAME, DATASET_PREFIX, pbar=False):
        print(f":confetti_ball: {fname} :confetti_ball:")
        io_buffer = streamer.stream_object(BUCKET_NAME, fname)
        for (_id, img, mask) in midv_500_buffer_read(io_buffer):
            print(f"[dim]{_id}: {mask} | {img.shape}[/dim]")
            plt.imshow(img)
            plt.show()
def load_midv_local():
    with open("some-local-path-in-here", "rb") as tar_f:
        for (_id, img, mask) in midv_500_buffer_read(tar_f):
            print(f"[dim]{_id}: {mask} | {img.shape}[/dim]")
            plt.imshow(img)
            plt.show()
if __name__ == "__main__":
    load_midv_from_s3()
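s3_stream.py (the helper module imported above):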
import io
import os
from typing import Iterator

import boto3
from tqdm import tqdm

# Read AWS S3 keys from the environment
ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")

# These are to be configured via CLI
ENDPOINT_URL = "https://ams3.digitaloceanspaces.com"
BUCKET_NAME = "your-bucket-name"
DATASET_PREFIX = "your-obj-prefix"  # acts as a prefix filter on object keys
class S3Stream(object):
    """Simple class to stream data from an S3 bucket

    Attributes:
        client (boto3.client): boto3 client
    """

    def __init__(self, access_key_id, secret_access_key, endpoint_url):
        # Infer the region from the S3 endpoint URL
        # (e.g. "https://ams3.digitaloceanspaces.com" -> "ams3")
        region = endpoint_url.split(".")[0].replace("https://", "")
        session = boto3.session.Session()
        self.client = session.client(
            "s3",
            region_name=region,
            endpoint_url=endpoint_url,
            aws_access_key_id=access_key_id,
            aws_secret_access_key=secret_access_key,
        )
    def list_buckets(self):
        return [b.get("Name") for b in self.client.list_buckets().get("Buckets", [])]

    def stream_object(
        self,
        bucket_name: str,
        obj_key: str,
    ):
        # Download the whole object into an in-memory buffer and rewind it
        io_buffer = io.BytesIO()
        self.client.download_fileobj(bucket_name, obj_key, io_buffer)
        io_buffer.seek(0)
        return io_buffer
    def list_keys(
        self,
        bucket_name: str,
        obj_prefix: str,
        start_after: str = "",
        continuation_token: str = "",
        batch_size: int = 1000,
        pbar: bool = True,
    ) -> Iterator[str]:
        """Streams s3 object keys starting with the specified 'obj_prefix'
        from the given bucket.

        Args:
            bucket_name (str): S3 bucket name
            obj_prefix (str): object prefix to filter by
            start_after (str, optional): item key from where to start fetching
            continuation_token (str, optional): token used for pagination
            batch_size (int, optional): max number of keys fetched per request
            pbar (bool, optional): if True, show a progress bar

        Yields:
            str: each (non-folder) object key
        """
        # Only pass the pagination params when they are set
        # (an empty string is not a valid 'ContinuationToken')
        params = dict(Bucket=bucket_name, Prefix=obj_prefix, MaxKeys=batch_size)
        if start_after:
            params["StartAfter"] = start_after
        if continuation_token:
            params["ContinuationToken"] = continuation_token

        response = self.client.list_objects_v2(**params)
        objects = [c.get("Key") for c in response.get("Contents", [])]
        obj_iter = tqdm(objects) if pbar else objects
        for obj_key in obj_iter:
            if pbar:
                obj_iter.set_description(f"Downloading '{obj_key}'")
            if obj_key.endswith("/"):
                # Skip folders
                continue
            yield obj_key

        # Recurse with the continuation token to fetch the next page of keys
        if response.get("IsTruncated"):
            yield from self.list_keys(
                bucket_name,
                obj_prefix,
                continuation_token=response.get("NextContinuationToken"),
                pbar=pbar,
            )
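Note: stream_object buffers the whole object in memory before returning. For very large objects a chunked read avoids that; below is a minimal sketch using boto3's get_object (the helper name and chunk size are illustrative assumptions):

def stream_object_chunks(client, bucket_name: str, obj_key: str, chunk_size: int = 8 * 1024 * 1024):
    """Hypothetical variant of S3Stream.stream_object that yields the
    object's bytes in chunks instead of buffering everything in memory."""
    body = client.get_object(Bucket=bucket_name, Key=obj_key)["Body"]
    for chunk in body.iter_chunks(chunk_size=chunk_size):
        yield chunk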