Skip to content

Instantly share code, notes, and snippets.

@perryism
Created April 12, 2023 16:21
Show Gist options
  • Save perryism/b17d1acd38f670f1e714e7147b52372e to your computer and use it in GitHub Desktop.
Save perryism/b17d1acd38f670f1e714e7147b52372e to your computer and use it in GitHub Desktop.
Serving an XGBoost model in Vertex AI with a custom container

Overview

I need to create a Vertex AI endpoint with different functions. There are two options: 1. one endpoint per function; 2. one endpoint for all functions. This is an attempt to handle different functions in a single endpoint.

Base image

The base image requires three files.

  1. requirements.txt - it will be used to build the base docker image
  2. handler.py - it is the fastapi handler which we have direct access to request and response objects.
  3. predictor.py - it is supposed to be where the prediction implementations reside

Runtime

During runtime of the container, the Vertex SDK will run the container using the following command.

docker run --rm -p 8082:8080 --name vertex_ai_local -e AIP_HTTP_PORT=8080 -e AIP_STORAGE_URI=/tmp_cpr_local_model -e AIP_HEALTH_ROUTE=/health -e AIP_PREDICT_ROUTE=/predict -v /model_artifact:/tmp_cpr_local_model us-central1-docker.pkg.dev/prosper-nonprod/ml-container-images/xgboost-custom

model_artifact is where the model artifacts are saved. The barebone files expected to be in the artifact folder are as follows:

  1. model.bst - the xgboost binary file
  2. endpoints.py - this is where the prediction logic lives
import xgboost
from xgboost import DMatrix
from typing import List, Dict, Union
class VertexEndpoint:
    """Base class for prediction functions served from a single Vertex AI endpoint.

    Subclasses implement `handle`; the serving handler picks a subclass by
    matching the request's FUNCTION-NAME header against the lowercased class
    name (see `can_handle`).

    NOTE(review): `Vector`, `inputs_to_matrix` and `features` are not defined
    in this module — presumably supplied by the serving base image; confirm
    they are in scope at runtime.
    """

    @classmethod
    def can_handle(cls, request):
        # Header-based dispatch: "FUNCTION-NAME: predict" selects class Predict.
        return request.headers.get("FUNCTION-NAME", "") == cls.__name__.lower()

    def __init__(self, model):
        # model: the loaded xgboost Booster shared by all endpoint instances.
        self.model = model

    def preprocess(self, prediction_input: dict) -> List[Vector]:
        """Extract the instances payload; when rows arrive as feature dicts,
        convert them to a matrix ordered by `features()`."""
        instances = prediction_input["instances"]
        if len(instances) > 0 and type(instances[0]) is dict:
            instances = inputs_to_matrix(instances, features())
        return instances

    def run(self, instances: dict) -> List[Vector]:
        """Full pipeline: preprocess -> handle -> postprocess.

        The argument is the decoded request body (a dict with an "instances"
        key) — the previous `List[...]` annotation was misleading, since
        `preprocess` indexes it with ["instances"].
        """
        return self.postprocess(
            self.handle(self.preprocess(instances))
        )

    def handle(self, instances: List[Union[Vector, Dict[str, float]]]) -> List[Vector]:
        # Subclasses must implement the actual prediction.
        raise NotImplementedError

    def postprocess(self, prediction_results) -> List[Union[Vector, Dict[str, float]]]:
        # Identity by default; subclasses may reshape results.
        return prediction_results
# to test this function, the request header requires "FUNCTION-NAME: predict"
class Predict(VertexEndpoint):
    """Default scoring endpoint: returns raw booster predictions."""

    def handle(self, instances: List[Vector]) -> List[Vector]:
        matrix = DMatrix(instances)
        scores = self.model.predict(matrix)
        return scores.tolist()
# to test this function, the request header requires "FUNCTION-NAME: predict2"
class Predict2(VertexEndpoint):
    """Variant endpoint: returns every raw prediction doubled."""

    def handle(self, instances: List[Vector]) -> List[Vector]:
        doubled = self.model.predict(DMatrix(instances)) * 2
        return doubled.tolist()
from google.cloud.aiplatform.prediction.handler import PredictionHandler
from fastapi import Response
import json
class CprHandler(PredictionHandler):
    """FastAPI-facing handler that dispatches each request to the endpoint
    class selected by the predictor, then wraps results in the standard
    Vertex AI response envelope."""

    async def handle(self, request):
        """Decode the JSON body, run the matching endpoint's pipeline
        (preprocess -> handle -> postprocess), and return a JSON response.

        Any failure is reported as a 400 with the error message in the body.
        """
        request_body = await request.body()
        prediction_instances = json.loads(request_body)
        try:
            handler = self._predictor.get_handler(request)
            if handler is None:
                # No endpoint claimed this request (unknown FUNCTION-NAME
                # header): fail with a clear message instead of the opaque
                # AttributeError the None would otherwise trigger below.
                func_name = request.headers.get("FUNCTION-NAME", "")
                return Response(
                    content=json.dumps({"error": f"{func_name} is not a supported function"}),
                    status_code=400,
                )
            prediction_results = handler.postprocess(
                handler.handle(handler.preprocess(prediction_instances))
            )
            return Response(content=json.dumps(self._predictor.postprocess(prediction_results)))
        except Exception as ex:
            # Boundary catch: surface the error to the client as a 400
            # rather than letting the server return a 500.
            return Response(content=json.dumps({"error": str(ex)}), status_code=400)
# https://cloud.google.com/blog/topics/developers-practitioners/simplify-model-serving-custom-prediction-routines-vertex-ai
from google.cloud.aiplatform.utils import prediction_utils
from google.cloud.aiplatform.prediction.predictor import Predictor
import xgboost
from xgboost import DMatrix
import importlib
import os
from typing import Union, List
ENDPOINT_PATH = "endpoints.py"
MODULE = ENDPOINT_PATH.split(".")[0]
MODEL_PATH = "model.bst"
class InvalidFunctionError(Exception):
    """Raised when a request names a prediction function that is not registered."""

    def __init__(self, func_name):
        # Keep the message format stable: it is surfaced to API clients.
        message = f"{func_name} is not a supported function"
        super().__init__(message)
import inspect
def get_endpoints(module_name):
    """Yield every class in *module_name* that inherits from VertexEndpoint
    (excluding VertexEndpoint itself)."""
    module = importlib.import_module(module_name)
    # FIXME: circular dependencies
    base = module.VertexEndpoint
    for attr_name in dir(module):
        candidate = getattr(module, attr_name)
        if inspect.isclass(candidate) and candidate is not base and issubclass(candidate, base):
            yield candidate
class XGBoostPredictor(Predictor):
    """Loads an XGBoost booster (and an optional user-supplied endpoints
    module) from the Vertex AI artifact directory and serves predictions."""

    def __init__(self):
        # Maps lowercased endpoint class name -> endpoint instance (filled by load()).
        self.predict_functions = {}

    def load(self, artifacts_uri: str):
        """Fetch model artifacts, load the booster, and register any endpoint
        classes found in the optional endpoints module.

        https://github.com/googleapis/python-aiplatform/blob/bb27619d71fe237690f9c14a37461f1ca839822b/google/cloud/aiplatform/utils/prediction_utils.py#L153
        artifacts_uri can be a local path.
        """
        prediction_utils.download_model_artifacts(artifacts_uri)
        booster = xgboost.Booster()
        booster.load_model(MODEL_PATH)
        self._model = booster
        if os.path.exists(ENDPOINT_PATH):
            for endpoint_cls in get_endpoints(MODULE):
                key = endpoint_cls.__name__.lower()
                self.predict_functions[key] = endpoint_cls(booster)

    def get_handler(self, request):
        """Return a fresh endpoint instance whose can_handle() accepts the
        request, or None when no endpoint matches."""
        for endpoint_cls in get_endpoints(MODULE):
            if endpoint_cls.can_handle(request):
                return endpoint_cls(self._model)

    def predict(self, instances: List[List[float]]) -> list:
        """Default prediction path: score raw feature rows with the booster."""
        return self._model.predict(DMatrix(instances))

    def predict_custom(self, name, instances: DMatrix) -> list:
        """Invoke a registered endpoint by (lowercased) name.

        Raises InvalidFunctionError when the name is not registered.
        """
        if name not in self.predict_functions:
            raise InvalidFunctionError(name)
        return self.predict_functions[name].handle(instances)

    def postprocess(self, prediction_results: list) -> dict:
        """Wrap results in Vertex AI's standard response envelope."""
        return {"predictions": prediction_results}
xgboost==1.6.2
shap
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment