Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save nov05/95cb7edcbe2e8bb68c9d29bdc00b9ca8 to your computer and use it in GitHub Desktop.
Save nov05/95cb7edcbe2e8bb68c9d29bdc00b9ca8 to your computer and use it in GitHub Desktop.

🟒 AWS S3 data to SageMaker machine learning training


Code snippets are from the following sources:

import torch, time
from statistics import mean, variance
dataset = get_dataset()
dl = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=4)
# Time how long each run of 100 batches takes to come out of the DataLoader
# (data loading only; no training step is executed here).
stats_lst = []
t0 = time.perf_counter()
for batch_idx, batch in enumerate(dl, start=1):
    if batch_idx % 100 == 0:
        t = time.perf_counter() - t0
        print(f'Iteration {batch_idx} Time {t}')
        stats_lst.append(t)
        t0 = time.perf_counter()
# The first interval is dropped as warm-up: t0 is taken before the DataLoader
# iterator (and its workers) is created, so that start-up cost lands in it.
# statistics.variance needs at least two values, so guard short datasets
# (e.g. 100 samples at batch_size=4 is only 25 batches and never reports).
steady_state = stats_lst[1:]
if len(steady_state) >= 2:
    mean_calc = mean(steady_state)
    var_calc = variance(steady_state)
    print(f'mean {mean_calc} variance {var_calc}')
else:
    print(f'Not enough timing samples for mean/variance: '
          f'{len(stats_lst)} interval(s) of 100 batches recorded, need at least 3')
## baseline: time train_step on a single cached batch (no data loading), to compare with the step time when training on the streamed data samples
import torch, time
from statistics import mean, variance
dataset = get_dataset()
dl = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=4)
# Fetch one batch and reuse it for every step, so the timings below measure
# train_step alone, without any data-loading cost.
batch = next(iter(dl))
stats_lst = []
t0 = time.perf_counter()
# 1000 steps, reporting the elapsed time of every 100 (at 100, 200, ..., 1000).
# NOTE(review): if train_step launches asynchronous GPU work, a
# torch.cuda.synchronize() before reading the clock may be needed — confirm.
for batch_idx in range(1, 1001):
    train_step(batch)
    if batch_idx % 100 == 0:
        t = time.perf_counter() - t0
        print(f'Iteration {batch_idx} Time {t}')
        stats_lst.append(t)
        t0 = time.perf_counter()
# Summarize like the data-loading benchmark: drop the first (warm-up) interval.
steady_state = stats_lst[1:]
mean_calc = mean(steady_state)
var_calc = variance(steady_state)
print(f'mean {mean_calc} variance {var_calc}')
  • Create random synthetic data
import webdataset as wds
import numpy as np
from PIL import Image
import io
# Write 100 synthetic samples into a single WebDataset tar file. Each sample is a
# random RGB image plus a single-channel label map with values in [0, num_classes),
# both PNG-encoded.
out_tar = 'wds.tar'
im_width = 1024
im_height = 1024
num_classes = 256
# Use the writer as a context manager so the tar archive is finalized and the
# underlying file is flushed and closed, even if the loop is interrupted.
with wds.TarWriter(out_tar) as sink:
    for i in range(100):
        image = Image.fromarray(np.random.randint(0, high=256,
                      size=(im_height, im_width, 3), dtype=np.uint8))
        label = Image.fromarray(np.random.randint(0, high=num_classes,
                      size=(im_height, im_width), dtype=np.uint8))
        image_bytes = io.BytesIO()
        label_bytes = io.BytesIO()
        image.save(image_bytes, format='PNG')
        label.save(label_bytes, format='PNG')
        sample = {"__key__": str(i),
                  'image': image_bytes.getvalue(),
                  'label': label_bytes.getvalue()}
        sink.write(sample)
  • FastFile input mode for SageMaker estimator
import os, webdataset
def get_dataset(num_files=None):
    """Return a shuffled WebDataset over the tar shards in the SageMaker training channel.

    The channel directory is read from the ``SM_CHANNEL_TRAINING`` environment
    variable and its shards are opened as local file paths.

    Args:
        num_files: number of shards, named ``0.tar`` .. ``{num_files - 1}.tar``,
            to read. If omitted, a module-level ``num_files`` is used when one is
            defined (the previous behavior). Otherwise every ``*.tar`` file in the
            channel directory is used, in sorted order.

    Returns:
        A ``webdataset.WebDataset`` with shard-order shuffling plus a 10-sample
        shuffle buffer.

    Raises:
        KeyError: if ``SM_CHANNEL_TRAINING`` is not set.
        FileNotFoundError: if shards must be discovered and none are found.
    """
    import glob

    ffm = os.environ['SM_CHANNEL_TRAINING']
    if num_files is None:
        # Backward compatibility: callers previously relied on a global `num_files`.
        num_files = globals().get('num_files')
    if num_files is None:
        urls = sorted(glob.glob(os.path.join(ffm, '*.tar')))
        if not urls:
            raise FileNotFoundError(f'No .tar shards found in {ffm}')
    else:
        urls = [os.path.join(ffm, f'{i}.tar') for i in range(num_files)]
    dataset = (
        webdataset
            .WebDataset(urls, shardshuffle=True)  ## shard shuffle
            .shuffle(10)                          ## buffer shuffle
    )
    return dataset
@nov05
Copy link
Author

nov05 commented Jan 29, 2025

🟒 SageMaker ETL options

Summary:

  • SageMaker Processing Jobs: For Python-based ETL tasks, this is the most straightforward option. βœ…
  • Data Wrangler: If you have simple tabular transformations.
  • SageMaker Pipelines: For complex end-to-end pipelines that include both ETL and model training.
  • Combine SageMaker and Glue: For very large or complex data transformations.

  • My own working code βœ…βœ…βœ…
## get my own AWS account number (first line of the secrets file)
with open('../secrets/aws_account_number', 'r') as f:
    aws_account_number = f.readline().strip()
# An empty file (or blank first line) would otherwise surface much later as a
# confusing error when the ECR image URI is built.
if not aws_account_number:
    raise ValueError("'../secrets/aws_account_number' is empty; "
                     "expected the AWS account number on its first line")

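## no need to run this cell (kept for reference only; if you adapt it, pipe the password to `docker login --password-stdin` instead of putting it on the command line)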
# ## To pull ECR image from another AWS account 
# import boto3
# import subprocess
# import base64
# ecr_client = boto3.client('ecr', region_name='us-east-1')
# # Retrieve the authentication token from ECR
# response = ecr_client.get_authorization_token()
# authorization_data = response['authorizationData'][0]
# token = authorization_data['authorizationToken']
# registry_uri = authorization_data['proxyEndpoint']
# decoded_token = base64.b64decode(token).decode('utf-8')
# username, password = decoded_token.split(':')
# # Docker login command
# login_command = f"docker login --username {username} --password {password} {registry_uri}"
# subprocess.run(login_command, shell=True, check=True)
# # Now you can use this image in your SageMaker processing job 
## TODO: Perform any data cleaning or data preprocessing
from sagemaker.processing import ScriptProcessor

## Custom image in this account's private ECR repo (a default SageMaker image would also work)
ecr_image_uri = (f'{aws_account_number}.dkr.ecr.us-east-1.amazonaws.com'
                 '/udacity/p5-amazon-bin-images:latest')
processor = ScriptProcessor(
    command=['python3'],
    image_uri=ecr_image_uri,
    role=sagemaker_role_arn,       # SageMaker execution role
    instance_count=1,
    instance_type='ml.t3.large',   # choose an instance type that fits the workload
    volume_size_in_gb=10,          # small volume: the script streams data instead of staging it
    base_job_name='p5-amazon-bin-images',
)
## Arguments forwarded to the processing script: source bucket/prefixes and destination
script_arguments = [
    '--SM_INPUT_BUCKET', 'aft-vbi-pds',
    '--SM_INPUT_PREFIX_IMAGES', 'bin-images/',
    '--SM_INPUT_PREFIX_METADATA', 'metadata/',
    '--SM_OUTPUT_BUCKET', 'p5-amazon-bin-images',
    '--SM_OUTPUT_PREFIX', 'webdataset/',
]
processor.run(
    code='../scripts_process/test_convert_to_webdataset.py',
    arguments=script_arguments,
)

  • My code: WebDataset.ShardWriter() βœ…βœ…βœ…
    type_prefix is one of 'train/', 'val/', or 'test/'
def convert_dataset(type_prefix, file_list, maxcount=1000):
    """Write (image, label) pairs as WebDataset shards and upload them to S3.

    Each image is read from ``{input_prefix_images}{image_id}.jpg`` in the
    module-level ``input_bucket``. It is written as one sample with members
    ``{image_id}.jpg`` (raw image bytes) and ``{image_id}.cls`` (the label),
    the extensions a reader using ``to_tuple("jpg", "cls")`` expects. Shards of at
    most ``maxcount`` samples are uploaded as
    ``s3://{output_bucket}/{output_prefix}{type_prefix}shard-NNNNNN.tar``.

    Args:
        type_prefix: split sub-prefix, e.g. 'train/', 'val/' or 'test/'.
        file_list: iterable of (image_id, label) pairs.
        maxcount: maximum number of samples per shard.

    Images that cannot be read are skipped with a warning (best effort).
    """
    import posixpath
    import tempfile

    # Write shards into a private temporary directory. Shards left over from a
    # previous call (e.g. a larger 'train/' split) in the working directory are
    # then never picked up and uploaded under this split's prefix, and local copies
    # are removed once uploaded.
    with tempfile.TemporaryDirectory() as shard_dir:
        pattern = os.path.join(shard_dir, "shard-%06d.tar")
        with wds.ShardWriter(pattern, maxcount=maxcount) as sink:
            for image_id, label in file_list:
                image_key = f'{input_prefix_images}{image_id}.jpg'
                try:  # Skip images that cannot be read from S3
                    image_data = read_s3_file(input_bucket, image_key)
                except Exception as e:
                    print(f"⚠️ Skipping image '{image_key}' due to error: {e}")
                    continue
                # "jpg" bytes are stored as-is; "cls" labels (int or str) are
                # encoded by WebDataset's default extension-based encoder.
                sink.write({
                    "__key__": f"{image_id}",
                    "jpg": image_data,
                    "cls": label,
                })
        tar_paths = sorted(glob.glob(os.path.join(shard_dir, "shard-*.tar")))
        tar_list = []
        for tar_file in tar_paths:
            file_name = os.path.basename(tar_file)
            # S3 keys always use '/', regardless of the local OS path separator.
            s3_key = posixpath.join(output_prefix, type_prefix, file_name)
            s3_client.upload_file(tar_file, output_bucket, s3_key)
            tar_list.append(file_name)
    print(f"🟒 Successfully uploaded tar files to "
          f"s3://{output_bucket}/{output_prefix}{type_prefix}:\n"
          f"    {tar_list}")
  • My code: WebDataset.TarWriter() + io.BytesIO() βœ…βœ…βœ…
def convert_dataset(image_keys, num_tar_files):
    """Pack JPEG images and their JSON metadata into one in-memory tar and upload it.

    For every key in ``image_keys`` ending in '.jpg' or '.jpeg', the image and
    ``{input_prefix_metadata}{stem}.json`` are read from ``input_bucket``. They are
    stored as one WebDataset sample keyed by the file stem, with members ``image``
    and ``metadata``. Non-image keys, and images whose image or metadata read fails,
    are skipped with a warning. The resulting tar is uploaded to
    ``s3://{output_bucket}/{output_prefix}data_{num_tar_files}.tar``.
    """
    buffer = io.BytesIO()
    with wds.TarWriter(buffer) as writer:
        for key in image_keys:
            if not key.endswith(('.jpg', '.jpeg')):
                print(f"⚠️ Skipping non-image file: {key}")
                continue
            stem = os.path.splitext(key.rsplit('/', 1)[-1])[0]
            try:
                # The metadata JSON is read first, so a missing JSON skips the image.
                metadata_key = f'{input_prefix_metadata}{stem}.json'
                metadata_data = read_s3_file(input_bucket, metadata_key)
                image_data = read_s3_file(input_bucket, key)
            except Exception as err:
                print(f"⚠️ Skipping image '{key}' due to error: {err}")
                continue
            writer.write({"__key__": f"{stem}",
                          "image": image_data,
                          "metadata": metadata_data})
    # Rewind the finished in-memory archive before streaming it to S3.
    buffer.seek(0)
    s3_key = f'{output_prefix}data_{num_tar_files}.tar'
    s3_client.upload_fileobj(buffer, output_bucket, s3_key)
    print(f"🟒 Successfully uploaded tar file to s3://{output_bucket}/{s3_key}")

@nov05
Copy link
Author

nov05 commented Jan 29, 2025

  • ScriptProcessor official documentation
  • My tutorial: Create custom docker image for SageMaker data processing jobs, create AWS ECR private repo, and upload the image to the repo βœ…βœ…βœ…
  • AWS re:Post, pull ECR image from the repo of another account βœ…

@nov05
Copy link
Author

nov05 commented Jan 29, 2025

import webdataset as wds
from huggingface_hub import get_token
from torch.utils.data import DataLoader

hf_token = get_token()
# This is a plain string, not an f-string, so the shard range takes single braces.
# WebDataset brace-expands it into 1024 shard URLs (0000 .. 1023).
url = "https://huggingface.co/datasets/timm/imagenet-12k-wds/resolve/main/imagenet12k-train-{0000..1023}.tar"
# Stream each shard through curl. Only send the auth header when a token is
# configured (get_token() returns None otherwise).
auth_header = f" -H 'Authorization:Bearer {hf_token}'" if hf_token else ""
url = f"pipe:curl -s -L {url}{auth_header}"
buffer_size = 1000
# Shuffle shard order, then shuffle samples through a 1000-sample buffer, then decode.
dataset = (
    wds.WebDataset(url, shardshuffle=True)
    .shuffle(buffer_size)
    .decode()
)
# Build the loader on the final (shuffled) pipeline.
dataloader = DataLoader(dataset, batch_size=64, num_workers=4)

Generally, datasets in WebDataset format are already shuffled and ready to feed to a DataLoader, but you can still reshuffle the data with WebDataset’s approximate shuffling.

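In addition to shuffling the list of shards, WebDataset can shuffle samples through an in-memory buffer. This adds little overhead to loading speed; the costs are the buffer's memory and a short delay while the buffer first fills.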

@nov05
Copy link
Author

nov05 commented Jan 31, 2025

  • TorchVision dataset
from torchvision import datasets, transforms
# Standard ImageNet per-channel mean and standard deviation for normalization.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
# Random augmentations, applied to the PIL image before tensor conversion.
augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.RandomResizedCrop(224),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
]
to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
]
train_transform = transforms.Compose(augmentations + to_normalized_tensor)
train_dataset = datasets.ImageFolder(task.config.train, transform=train_transform)
  • WebDataset dataset
import webdataset as wds
from torchvision import transforms

def identity(x):
    """Return *x* unchanged.

    Used as the no-op label transform in ``map_tuple``. A module-level function
    (unlike a lambda) can be pickled when the DataLoader uses worker processes.
    """
    return x

import torch  # needed for torch.utils.data.DataLoader below

# This is a plain string, not an f-string, so the shard range takes single braces.
# WebDataset brace-expands it into shard-000000.tar and shard-000001.tar.
path = "s3://p5-amazon-bin-images/webdataset/train/shard-{000000..000001}.tar"
# curl cannot fetch s3:// URLs, so stream each shard to stdout with the AWS CLI
# instead. (Requires the AWS CLI and S3 read credentials in the training environment.)
task.config.train = f"pipe:aws s3 cp {path} -"
# Create the WebDataset pipeline
train_dataset = (
    wds.WebDataset(task.config.train, shardshuffle=True)  ## Shuffle shards
        .shuffle(1000)  # Shuffle samples through a 1000-sample buffer
        .decode("pil")  # Decode image members to PIL images
        .to_tuple("jpg", "cls")  # Tuple of image and label; selected by file extension
        .map_tuple(train_transform, identity)  # Apply the train transforms to the image only
)
# Wrap the dataset in a DataLoader for batching
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=task.config.batch_size,
    num_workers=task.config.num_cpu)
# Example usage in a training loop
for batch_images, batch_labels in train_loader:
    # Training code here
    print(batch_images.shape, batch_labels.shape)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment