Gist by @simon-mo, created July 19, 2022.
from io import BytesIO
from pydantic import BaseModel
from pprint import pprint
import requests
import torch
import torchvision.models as models
from torchvision import transforms
from PIL import Image
import numpy as np
import ray
from ray import serve
from ray.serve.drivers import DAGDriver
from ray.dag.input_node import InputNode


class ContentInput(BaseModel):
    image_url: str
    user_id: int


@serve.deployment
def downloader(inp: "ContentInput"):
    """Download HTTP content; in production this can be business logic
    downloading from other services."""
    image_bytes = requests.get(inp.image_url).content
    return image_bytes


@serve.deployment
class Preprocessor:
    """Image preprocessor with ImageNet normalization."""

    def __init__(self):
        self.preprocessor = transforms.Compose(
            [
                transforms.Resize([224, 224]),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Lambda(lambda t: t[:3, ...]),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def preprocess(self, image_payload_bytes: bytes) -> torch.Tensor:
        pil_image = Image.open(BytesIO(image_payload_bytes)).convert("RGB")
        input_array = self.preprocessor(pil_image).unsqueeze(0)
        return input_array


@serve.deployment
class ImageClassifier:
    def __init__(self, version: int):
        self.version = version
        self.model = models.resnet50(weights="IMAGENET1K_V1").eval()
        label_file = requests.get(
            "https://ray-serve-blog.s3.us-west-2.amazonaws.com/imagenet/imagenet_classes.txt"
        ).text
        self.categories = [s.strip() for s in label_file.split("\n")]

    def forward(self, input_tensor):
        with torch.no_grad():
            output_tensor = self.model(input_tensor)
            probabilities = torch.nn.functional.softmax(output_tensor[0], dim=0)
            top5_prob, top5_catid = torch.topk(probabilities, 5)
        classify_result = [
            (
                self.categories[top5_catid[i]],
                top5_prob[i].item(),
            )
            for i in range(top5_prob.size(0))
        ]
        return {
            "classify_result": classify_result,
            "model_version": self.version,
        }


@serve.deployment
class DynamicDispatch:
    """Route each request to one of several classifier versions based on user_id."""

    def __init__(self, *classifier_models):
        self.classifier_models = classifier_models

    async def forward(self, inp_tensor, inp: "ContentInput"):
        chosen_idx = inp.user_id % len(self.classifier_models)
        chosen_model = self.classifier_models[chosen_idx]
        return await chosen_model.forward.remote(inp_tensor)


@serve.deployment
class ImageDetector:
    def __init__(self):
        self.model = models.detection.maskrcnn_resnet50_fpn(weights="COCO_V1").eval()

    def forward(self, input_tensor):
        with torch.no_grad():
            return [
                (o["labels"].numpy().tolist(), o["boxes"].numpy().tolist())
                for o in self.model(input_tensor)
            ]


@serve.deployment
def combine(classify_output, detection_output):
    return {
        "resnet_version": classify_output["model_version"],
        "classify_result": classify_output["classify_result"],
        "detection_output": detection_output,
    }


# Let's build the DAG here!
preprocessor = Preprocessor.bind()
classifiers = [ImageClassifier.bind(i) for i in range(3)]
dispatcher = DynamicDispatch.bind(*classifiers)
detector = ImageDetector.bind()


def input_adapter(image_url: str, user_id: int):
    return ContentInput(image_url=image_url, user_id=user_id)


serve.start()

with InputNode() as user_input:
    image_bytes = downloader.bind(user_input)
    image_tensor = preprocessor.preprocess.bind(image_bytes)
    classification_output = dispatcher.forward.bind(image_tensor, user_input)
    detection_output = detector.forward.bind(image_tensor)
    local_dag = combine.bind(classification_output, detection_output)

serve_entrypoint = DAGDriver.bind(local_dag, http_adapter=input_adapter)
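
# The resulting graph: user_input -> downloader -> Preprocessor.preprocess; the
# image tensor then fans out to DynamicDispatch (which picks one of the three
# classifier versions by user_id % 3) and ImageDetector in parallel, and
# combine() joins both results into a single response.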
image_url = "https://ray-serve-blog.s3.us-west-2.amazonaws.com/imagenet/n01833805_hummingbird.jpeg"
# print("Started running DAG locally...")
# user_input = ContentInput(image_url=image_url, user_id=1)
# rst = ray.get(local_dag.execute(user_input))
# pprint(rst)
# print("Running it with Ray Serve")
# serve.run(serve_entrypoint)
# resp = requests.get("http://localhost:8000/", params={"image_url": image_url, "user_id": 1})
# print("response json")
# pprint(resp.json())
# in CLI
# run `serve run app.serve_entrypoint`
# go to localhost:8000/docs and use the OpenAPI UI
# or
# curl -X 'GET' \
# 'http://localhost:8000/?image_url=https%3A%2F%2Fgithub.com%2FEliSchwartz%2Fimagenet-sample-images%2Fblob%2Fmaster%2Fn01833805_hummingbird.JPEG%3Fraw%3Dtrue&user_id=1' \
# -H 'accept: application/json'


# Autoscaling demo: a deliberately slow deployment that scales between 0 and 10
# replicas with traffic.
from ray import serve
import time


@serve.deployment
class A:
    def __init__(self):
        print("new replica")

    def __call__(self):
        time.sleep(1)
        return "hi"


config = {
    "min_replicas": 0,
    "max_replicas": 10,
    "target_num_ongoing_requests_per_replica": 1,
    # Sped up for demo purposes.
    "downscale_delay_s": 8,
    "upscale_delay_s": 8,
    "look_back_period_s": 6,
    "metrics_interval_s": 2,
}

a = A.options(autoscaling_config=config).bind()
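
# A minimal driver sketch for the autoscaling demo (an addition, not part of the
# original gist). Assumes a Ray 2.x-era API where `serve.run` deploys a bound
# deployment and returns a handle to it.
if __name__ == "__main__":
    import ray

    handle = serve.run(a)
    # Fire a burst of calls. With target_num_ongoing_requests_per_replica=1 and
    # each call sleeping 1s, the autoscaler should scale A up toward the burst
    # size, then back toward min_replicas=0 after downscale_delay_s of idleness,
    # printing "new replica" as each replica starts.
    refs = [handle.remote() for _ in range(20)]
    print(ray.get(refs))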


# RayService sample manifest for the KubeRay operator.
apiVersion: ray.io/v1alpha1
kind: RayService
metadata:
  name: rayservice-sample
spec:
  serveDeploymentGraphConfig:
    importPath: fruit.deployment_graph
    runtimeEnv: |
      working_dir: "https://github.com/ray-project/test_dag/archive/c620251044717ace0a4c19d766d43c5099af8a77.zip"
    serveConfigs:
      - name: MangoStand
        numReplicas: 1
        userConfig: |
          price: 4
        rayActorOptions:
          numCpus: 0.1
      - name: OrangeStand
        numReplicas: 1
        userConfig: |
          price: 2
        rayActorOptions:
          numCpus: 0.1
      - name: PearStand
        numReplicas: 1
        userConfig: |
          price: 1
        rayActorOptions:
          numCpus: 0.1
      - name: FruitMarket
        numReplicas: 1
        rayActorOptions:
          numCpus: 0.1
      - name: DAGDriver
        numReplicas: 1
        routePrefix: "/"
        rayActorOptions:
          numCpus: 0.1
  rayClusterConfig:
    rayVersion: 'nightly' # should match the Ray version in the container image
    # enableInTreeAutoscaling: true
    ###################### headGroupSpec ######################
    # Head group template and specs.
    headGroupSpec:
      # Kubernetes Service type; valid values are 'ClusterIP', 'NodePort', and 'LoadBalancer'.
      serviceType: ClusterIP
      # The number of head pods in this group (assuming there could be more than one in the future).
      replicas: 1
      # The following params are used to complete the ray start command:
      #   ray start --head --block --port=6379 ...
      rayStartParams:
        port: '6379' # should match the container port named gcs-server
        object-store-memory: '100000000'
        dashboard-host: '0.0.0.0'
        num-cpus: '2' # can be auto-completed from the limits
        node-ip-address: $MY_POD_IP # auto-completed as the head pod IP
        block: 'true'
      # Pod template
      template:
        metadata:
          labels:
            # Custom labels. NOTE: do not define custom labels starting with `raycluster.`;
            # they may be used by the controller.
            # Refer to https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/
            rayCluster: raycluster-sample # will be injected if missing
            rayNodeType: head # will be injected if missing; must be head or worker
            groupName: headgroup # will be injected if missing
          # Annotations for the pod
          annotations:
            key: value
        spec:
          containers:
            - name: ray-head
              #image: rayproject/ray:1.12.1
              image: rayproject/ray:nightly
              imagePullPolicy: Always
              env:
                - name: MY_POD_IP
                  valueFrom:
                    fieldRef:
                      fieldPath: status.podIP
              resources:
                limits:
                  cpu: 2
                  memory: 2Gi
                requests:
                  cpu: 2
                  memory: 2Gi
              ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265 # Ray dashboard
                  name: dashboard
                - containerPort: 10001
                  name: client
                - containerPort: 8000
                  name: serve
    workerGroupSpecs:
      # The pod replicas in this worker group
      - replicas: 1
        minReplicas: 1
        maxReplicas: 5
        # Logical group name; here it is small-group, but it can also be functional.
        groupName: small-group
        # To add worker pods, simply increment `replicas`.
        # To remove worker pods, decrement `replicas` and populate the workersToDelete list;
        # the operator removes pods from the list until the number of replicas is satisfied.
        # When a pod is confirmed deleted, its name is removed from the list below.
        #scaleStrategy:
        #  workersToDelete:
        #    - raycluster-complete-worker-small-group-bdtwh
        #    - raycluster-complete-worker-small-group-hv457
        #    - raycluster-complete-worker-small-group-k8tj7
        # The following params are used to complete the ray start command:
        #   ray start --block --node-ip-address= ...
        rayStartParams:
          node-ip-address: $MY_POD_IP
          block: 'true'
        # Pod template
        template:
          metadata:
            labels:
              key: value
            # Annotations for the pod
            annotations:
              key: value
          spec:
            initContainers:
              # The env var $RAY_IP is set by the operator if missing, with the value of the head service name.
              - name: init-myservice
                image: busybox:1.28
                command: ['sh', '-c', "until nslookup $RAY_IP.$(cat /var/run/secrets/kubernetes.io/serviceaccount/namespace).svc.cluster.local; do echo waiting for myservice; sleep 2; done"]
            containers:
              # Container names must consist of lowercase alphanumeric characters or '-',
              # and must start and end with an alphanumeric character (e.g. 'my-name' or '123-abc').
              - name: machine-learning
                image: rayproject/ray:nightly
                imagePullPolicy: Always
                # Environment variables to set in the container. Optional.
                # Refer to https://kubernetes.io/docs/tasks/inject-data-application/define-environment-variable-container/
                env:
                  - name: RAY_DISABLE_DOCKER_CPU_WARNING
                    value: "1"
                  - name: TYPE
                    value: "worker"
                  - name: CPU_REQUEST
                    valueFrom:
                      resourceFieldRef:
                        containerName: machine-learning
                        resource: requests.cpu
                  - name: CPU_LIMITS
                    valueFrom:
                      resourceFieldRef:
                        containerName: machine-learning
                        resource: limits.cpu
                  - name: MEMORY_LIMITS
                    valueFrom:
                      resourceFieldRef:
                        containerName: machine-learning
                        resource: limits.memory
                  - name: MEMORY_REQUESTS
                    valueFrom:
                      resourceFieldRef:
                        containerName: machine-learning
                        resource: requests.memory
                  - name: MY_POD_NAME
                    valueFrom:
                      fieldRef:
                        fieldPath: metadata.name
                  - name: MY_POD_IP
                    valueFrom:
                      fieldRef:
                        fieldPath: status.podIP
                ports:
                  - containerPort: 80
                    name: client
                lifecycle:
                  preStop:
                    exec:
                      command: ["/bin/sh", "-c", "ray stop"]
                resources:
                  limits:
                    cpu: "2"
                    memory: "2Gi"
                  requests:
                    cpu: "500m"
                    memory: "2Gi"
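
# A deployment sketch (an addition, not part of the original sample). Assuming
# the KubeRay operator is installed and this manifest is saved as
# rayservice-sample.yaml:
#   kubectl apply -f rayservice-sample.yaml
#   kubectl get rayservice rayservice-sample   # wait for the service to become healthy
# Then port-forward the Serve service KubeRay creates (typically <name>-serve-svc):
#   kubectl port-forward svc/rayservice-sample-serve-svc 8000:8000
# and query http://localhost:8000/ as in the Python examples above.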