Skip to content

Instantly share code, notes, and snippets.

@mjdietzx
Last active July 13, 2018 04:46
Show Gist options
  • Save mjdietzx/e50fff7be5fa3ba009e99bcc8e28f62d to your computer and use it in GitHub Desktop.
Save mjdietzx/e50fff7be5fa3ba009e99bcc8e28f62d to your computer and use it in GitHub Desktop.
AWS Lambda function for PyTorch deep-learning inference (ResNet-18 pre-trained on ImageNet)
import functools
import io
import urllib.parse

import boto3
import PIL.Image
import torch
from torch.utils import model_zoo
import torchvision
s3_client = boto3.client('s3')

# Per-channel (R, G, B) mean and standard deviation of the ImageNet training
# set. torchvision's ImageNet pre-trained classifiers (such as the ResNet-18
# weights this module loads) expect inputs normalized with these statistics;
# feeding un-normalized [0, 1] tensors silently degrades the predictions.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Standard ImageNet evaluation preprocessing: resize (shorter side -> 256),
# center-crop to 224x224, convert to a float tensor, then normalize.
# LANCZOS is the filter the old PIL.Image.ANTIALIAS alias referred to;
# ANTIALIAS itself was removed in Pillow 10.
valid_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=256, interpolation=PIL.Image.LANCZOS),
    torchvision.transforms.CenterCrop(size=224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
class SetupModel(object):
    """Decorator that loads pre-trained ResNet-18 weights, then wraps a function.

    Applying ``@SetupModel`` runs ``__init__`` once, at the moment the
    decorated function is defined (i.e. when this module is imported). The
    weights are loaded into the class-level ``model`` attribute, so they stay
    in memory for every later call made from the same process.

    Calling the decorated object forwards all arguments to the wrapped
    function and returns its result unchanged.
    """

    # Architecture only; weights are filled in by __init__. ResNet built from
    # BasicBlock with [2, 2, 2, 2] layers is the ResNet-18 configuration.
    model = torchvision.models.resnet.ResNet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2])

    def __init__(self, f):
        self.f = f
        # Expose the wrapped function's __name__, __doc__, etc. on this object.
        functools.update_wrapper(self, f)
        model_url = self._resnet18_weights_url()  # should encrypt models in real life
        # map_location='cpu' lets the checkpoint deserialize on a CPU-only host.
        # /tmp is used as the download cache (be careful writing model to disk).
        state_dict = model_zoo.load_url(model_url, model_dir='/tmp', map_location='cpu')
        self.model.load_state_dict(state_dict)
        # Put layers such as BatchNorm into inference mode.
        self.model.eval()

    @staticmethod
    def _resnet18_weights_url():
        """Return the download URL of torchvision's ImageNet ResNet-18 checkpoint.

        torchvision < 0.13 exposes ``torchvision.models.resnet.model_urls``;
        later releases removed that dict in favor of weight enums with a
        ``.url`` attribute.
        """
        model_urls = getattr(torchvision.models.resnet, 'model_urls', None)
        if model_urls is not None:
            return model_urls['resnet18']
        return torchvision.models.ResNet18_Weights.IMAGENET1K_V1.url

    def __call__(self, *args, **kwargs):
        return self.f(*args, **kwargs)
def predict(r):
    """Run the ResNet-18 model on a single encoded image.

    Args:
        r: Raw bytes of an image file in a format PIL can open.

    Returns:
        The model's output tensor for a batch containing just this image
        (leading batch dimension of size 1).
    """
    with PIL.Image.open(io.BytesIO(r)) as im:
        # convert() returns a new in-memory image, so it remains usable after
        # the context manager closes the original.
        rgb = im.convert('RGB')
    # Add the batch dimension the model expects: (C, H, W) -> (1, C, H, W).
    input_batch = valid_transform(rgb).unsqueeze(0)
    # Inference only: no_grad() stops autograd from recording a graph.
    # The previous Variable(..., volatile=True) has had no effect since
    # PyTorch 0.4, so gradients were being tracked needlessly.
    with torch.no_grad():
        return SetupModel.model(input_batch)
# Decorating runs SetupModel.__init__ once, when this module is imported, so
# the weights are loaded before the first request and persist in memory for
# every request served by the same process.
@SetupModel
def handler(event, _):
    """AWS Lambda entry point for S3 object notifications.

    Fetches each object referenced in ``event['Records']``, runs it through
    ``predict``, and returns the model output as a string.

    Args:
        event: S3 notification event. Each record must provide
            ``record['s3']['bucket']['name']`` and ``record['s3']['object']['key']``.
        _: Lambda context object (unused).

    Returns:
        ``str`` of the output tensor. For a single record this is that
        record's output. For several records, the outputs are concatenated
        along the batch dimension in record order.

    Raises:
        ValueError: if ``event['Records']`` is empty.
    """
    outputs = []
    for record in event['Records']:
        bucket = record['s3']['bucket']['name']
        # Object keys in S3 event notifications are URL-encoded (e.g. a space
        # arrives as '+'), so decode before calling GetObject.
        key = urllib.parse.unquote_plus(record['s3']['object']['key'])
        # being paranoid and not writing user data to disk (should also be encrypted in real life)
        body = s3_client.get_object(Bucket=bucket, Key=key)['Body'].read()
        outputs.append(predict(body))
    if not outputs:
        raise ValueError("event contains no S3 records")
    if len(outputs) == 1:
        return str(outputs[0])
    return str(torch.cat(outputs, dim=0))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment