|
# Reference: https://cloud.google.com/blog/topics/developers-practitioners/simplify-model-serving-custom-prediction-routines-vertex-ai
|
|
|
from google.cloud.aiplatform.utils import prediction_utils |
|
from google.cloud.aiplatform.prediction.predictor import Predictor |
|
|
|
import xgboost |
|
from xgboost import DMatrix |
|
import importlib |
|
|
|
import os |
|
from typing import Union, List |
|
|
|
# File name of the serialized XGBoost booster; resolved against the working directory.
MODEL_PATH = "model.bst"

# Optional Python file that defines custom endpoint classes.
ENDPOINT_PATH = "endpoints.py"

# Importable module name for ENDPOINT_PATH: the text before its first ".".
MODULE = ENDPOINT_PATH.partition(".")[0]
|
|
|
class InvalidFunctionError(Exception):
    """Signal that a requested custom prediction function is not registered."""

    def __init__(self, func_name):
        # The message embeds the rejected name so callers can report it as-is.
        message = f"{func_name} is not a supported function"
        super().__init__(message)
|
|
|
import inspect |
|
def get_endpoints(module_name):
    """Return the endpoint classes exposed by a module.

    Args:
        module_name: Importable name of the module to inspect. The module must
            expose a ``VertexEndpoint`` attribute that serves as the base class.

    Returns:
        A list of every class reachable as an attribute of the module that is a
        strict subclass of ``module.VertexEndpoint`` (the base class itself is
        excluded), ordered by attribute name. A list is returned rather than a
        lazy ``filter`` object so the result can be iterated more than once.

    Raises:
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no ``VertexEndpoint`` attribute.
    """
    module = importlib.import_module(module_name)
    # FIXME: circular dependencies
    base = module.VertexEndpoint
    return [
        member
        for _, member in inspect.getmembers(module, inspect.isclass)
        if member is not base and issubclass(member, base)
    ]
|
|
|
class XGBoostPredictor(Predictor):
    """Custom prediction routine that serves an XGBoost booster.

    Besides the plain ``predict`` path, the predictor can expose custom
    endpoints: classes in the ``MODULE`` module (file ``ENDPOINT_PATH``) that
    inherit from that module's ``VertexEndpoint``. ``load`` instantiates each
    of them with the loaded booster and registers the instance under its
    lower-cased class name.
    """

    def __init__(self):
        # Lower-cased endpoint class name -> endpoint instance bound to the model.
        self.predict_functions = {}
        # Assigned by load(); declared here so the attribute always exists.
        self._model = None

    @staticmethod
    def _endpoint_classes():
        """Return the custom endpoint classes, or an empty list if ENDPOINT_PATH is absent."""
        # Importing MODULE without its file would raise ModuleNotFoundError,
        # so every lookup goes through this single guarded helper.
        if not os.path.exists(ENDPOINT_PATH):
            return []
        return list(get_endpoints(MODULE))

    def load(self, artifacts_uri: str):
        """Fetch the model artifacts, load the booster and register endpoints.

        Args:
            artifacts_uri: Location of the model artifacts, passed unchanged to
                ``prediction_utils.download_model_artifacts``.
        """
        # https://github.com/googleapis/python-aiplatform/blob/bb27619d71fe237690f9c14a37461f1ca839822b/google/cloud/aiplatform/utils/prediction_utils.py#L153
        # artifacts_uri can be a local path.
        # MODEL_PATH and ENDPOINT_PATH are relative, so the lookups below rely
        # on the artifacts ending up in the current working directory.
        prediction_utils.download_model_artifacts(artifacts_uri)

        model = xgboost.Booster()
        model.load_model(MODEL_PATH)
        self._model = model

        # Rebuild the registry from scratch so a reload never keeps endpoint
        # instances that are still bound to a previously loaded booster.
        self.predict_functions = {
            endpoint.__name__.lower(): endpoint(model)
            for endpoint in self._endpoint_classes()
        }

    def get_handler(self, request):
        """Return a new endpoint instance able to handle ``request``.

        Args:
            request: Object passed to each endpoint class's ``can_handle``.

        Returns:
            An instance of the first endpoint class whose ``can_handle(request)``
            is truthy, constructed with the loaded model, or ``None`` when no
            endpoint accepts the request or no endpoints file exists.
        """
        for endpoint in self._endpoint_classes():
            if endpoint.can_handle(request):
                return endpoint(self._model)
        return None

    def predict(self, instances: List[List[float]]) -> list:
        """Run the booster on ``instances`` wrapped in a ``DMatrix``.

        ``load`` must have been called first.
        """
        return self._model.predict(DMatrix(instances))

    def predict_custom(self, name, instances: DMatrix) -> list:
        """Dispatch ``instances`` to the custom endpoint registered as ``name``.

        Args:
            name: Endpoint name. Registry keys are lower-cased class names, so
                string names are matched case-insensitively.
            instances: Input forwarded unchanged to the endpoint's ``handle``.

        Returns:
            Whatever the endpoint's ``handle`` returns.

        Raises:
            InvalidFunctionError: If no endpoint is registered under ``name``.
        """
        key = name.lower() if isinstance(name, str) else name
        try:
            endpoint = self.predict_functions[key]
        except (KeyError, TypeError):
            # TypeError covers unhashable names; report them as unsupported too.
            raise InvalidFunctionError(name) from None
        return endpoint.handle(instances)

    def postprocess(self, prediction_results: list) -> dict:
        """Wrap the prediction results in the response dictionary.

        Array-like results exposing ``tolist`` (such as the numpy array
        returned by ``Booster.predict``) are converted to plain Python lists so
        the response can be JSON-serialized; other values pass through
        unchanged.
        """
        to_list = getattr(prediction_results, "tolist", None)
        if callable(to_list):
            prediction_results = to_list()
        return {"predictions": prediction_results}