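"""Ray Serve deployment-graph example.

Downloads an image, preprocesses it once, then fans out to (a) one of several
ImageNet ResNet-50 classifiers chosen per user by a dynamic dispatcher and
(b) a COCO Mask R-CNN detector, then combines both results into one response.
"""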
from io import BytesIO
from pprint import pprint

import requests
import torch
import torchvision.models as models
from PIL import Image
from pydantic import BaseModel
from torchvision import transforms

import ray
from ray import serve
from ray.serve.drivers import DAGDriver
from ray.dag.input_node import InputNode
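
# Request model; the DAGDriver's http_adapter (defined below) builds one of
# these from the HTTP query parameters of each request.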
class ContentInput(BaseModel):
    image_url: str
    user_id: int
@serve.deployment
def downloader(inp: ContentInput):
    """Download HTTP content; in production this could be business logic downloading from other services."""
    image_bytes = requests.get(inp.image_url).content
    return image_bytes
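
# Shared preprocessing step: its output tensor feeds both the classifier and
# the detector, so each image is decoded and normalized only once.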
@serve.deployment
class Preprocessor:
    """Image preprocessor with ImageNet normalization."""

    def __init__(self):
        self.preprocessor = transforms.Compose(
            [
                transforms.Resize([224, 224]),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                # Keep only the first three channels (drops alpha, if any).
                transforms.Lambda(lambda t: t[:3, ...]),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def preprocess(self, image_payload_bytes: bytes) -> torch.Tensor:
        pil_image = Image.open(BytesIO(image_payload_bytes)).convert("RGB")
        # Add a batch dimension: (C, H, W) -> (1, C, H, W).
        input_array = self.preprocessor(pil_image).unsqueeze(0)
        return input_array
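
# ResNet-50 ImageNet classifier. `version` is echoed back in every response
# so the dynamic dispatch below is observable from the client.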
@serve.deployment
class ImageClassifier:
    def __init__(self, version: int):
        self.version = version
        self.model = models.resnet50(weights="IMAGENET1K_V1").eval()
        label_file = requests.get(
            "https://ray-serve-blog.s3.us-west-2.amazonaws.com/imagenet/imagenet_classes.txt"
        ).text
        self.categories = [s.strip() for s in label_file.split("\n")]

    def forward(self, input_tensor):
        with torch.no_grad():
            output_tensor = self.model(input_tensor)
            probabilities = torch.nn.functional.softmax(output_tensor[0], dim=0)
            top5_prob, top5_catid = torch.topk(probabilities, 5)
            classify_result = [
                (
                    self.categories[top5_catid[i]],
                    top5_prob[i].item(),
                )
                for i in range(top5_prob.size(0))
            ]
            return {
                "classify_result": classify_result,
                "model_version": self.version,
            }
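
# Picks one of the bound classifiers by hashing user_id -- a minimal
# sharding / A/B-testing pattern.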
@serve.deployment
class DynamicDispatch:
    def __init__(self, *classifier_models):
        self.classifier_models = classifier_models

    async def forward(self, inp_tensor, inp: ContentInput):
        chosen_idx = inp.user_id % len(self.classifier_models)
        chosen_model = self.classifier_models[chosen_idx]
        return await chosen_model.forward.remote(inp_tensor)
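
# Mask R-CNN object detector; returns (COCO label ids, bounding boxes) per image.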
@serve.deployment
class ImageDetector:
    def __init__(self):
        self.model = models.detection.maskrcnn_resnet50_fpn(weights="COCO_V1").eval()

    def forward(self, input_tensor):
        with torch.no_grad():
            return [
                (o["labels"].numpy().tolist(), o["boxes"].numpy().tolist())
                for o in self.model(input_tensor)
            ]
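
# Fan-in node: merges the classification and detection outputs into the final response.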
@serve.deployment
def combine(classify_output, detection_output):
    return {
        "resnet_version": classify_output["model_version"],
        "classify_result": classify_output["classify_result"],
        "detection_output": detection_output,
    }
# Bind the deployments into DAG nodes (nothing is deployed until serve.run).
preprocessor = Preprocessor.bind()
classifiers = [ImageClassifier.bind(i) for i in range(3)]
dispatcher = DynamicDispatch.bind(*classifiers)
detector = ImageDetector.bind()
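
# HTTP adapter: DAGDriver calls this to turn query parameters into a
# ContentInput before feeding the graph; it also drives the OpenAPI UI at /docs.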
def input_adapter(image_url: str, user_id: int):
    return ContentInput(image_url=image_url, user_id=user_id)
serve.start()
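
# Wire the nodes together; user_input stands in for the adapted HTTP request.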
with InputNode() as user_input:
    image_bytes = downloader.bind(user_input)
    image_tensor = preprocessor.preprocess.bind(image_bytes)
    classification_output = dispatcher.forward.bind(image_tensor, user_input)
    detection_output = detector.forward.bind(image_tensor)
    local_dag = combine.bind(classification_output, detection_output)

serve_entrypoint = DAGDriver.bind(local_dag, http_adapter=input_adapter)
image_url = "https://ray-serve-blog.s3.us-west-2.amazonaws.com/imagenet/n01833805_hummingbird.jpeg"
# print("Started running DAG locally...")
# user_input = ContentInput(image_url=image_url, user_id=1)
# rst = ray.get(local_dag.execute(user_input))
# pprint(rst)
# print("Running it with Ray Serve")
serve.run(serve_entrypoint)
resp = requests.get("http://localhost:8000/", params={"image_url": image_url, "user_id": 1})
print("response json")
print(resp.json())
# In the CLI (assuming this file is saved as solution.py):
# run `serve run solution.serve_entrypoint`,
# then go to localhost:8000/docs and use the OpenAPI UI, or:
# curl -X 'GET' \
#   'http://localhost:8000/?image_url=https%3A%2F%2Fgithub.com%2FEliSchwartz%2Fimagenet-sample-images%2Fblob%2Fmaster%2Fn01833805_hummingbird.JPEG%3Fraw%3Dtrue&user_id=1' \
#   -H 'accept: application/json'