Skip to content

Instantly share code, notes, and snippets.

@madhurprash
Last active August 8, 2024 01:58
Show Gist options
  • Save madhurprash/0d1253a92466c7816115a91ab8c9972d to your computer and use it in GitHub Desktop.
Save madhurprash/0d1253a92466c7816115a91ab8c9972d to your computer and use it in GitHub Desktop.
Deploys a model from HuggingFace on Amazon SageMaker using the DJL DeepSpeed LMI (https://github.com/deepjavalibrary/djl-serving).
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "e9bff69e-dff9-4a1e-b89e-a1f0c2eed329",
"metadata": {},
"source": [
"### Deploy from HuggingFace on Amazon SageMaker using the DJL DeepSpeed LMI\n",
"---\n",
"\n",
"This notebook contains the implementation to download a model from `HuggingFace` and use it to deploy the model using a `DJL` Inference Container"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e93f93d1-6aa6-47be-8edb-bf17bbe652ea",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!pip install -U huggingface_hub"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a86f02e0-9773-42f4-b2f9-dcb6c05e8ed1",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"\"\"\"\n",
"Deploys a model from HuggingFace on Amazon SageMaker using the DJL DeepSpeek LMI\n",
"(https://github.com/deepjavalibrary/djl-serving).\n",
"\n",
"1. Model Configuration is used from the configured serving.properties file.\n",
"2. An HF Token is required to download a model from a gated repo from Hugging Face.\n",
"\"\"\"\n",
"import os\n",
"import glob\n",
"import time\n",
"import boto3\n",
"import logging\n",
"import tarfile\n",
"import tempfile\n",
"import sagemaker\n",
"from pathlib import Path\n",
"from urllib.parse import urlparse\n",
"from sagemaker.utils import name_from_base\n",
"from huggingface_hub import snapshot_download\n",
"from typing import Dict, List, Tuple, Optional"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9f2afe2-41d5-4b74-a27a-9466cd0018d6",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# set logger to log information across this notebook\n",
"logging.basicConfig(format='[%(asctime)s] p%(process)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s', level=logging.INFO)\n",
"logger = logging.getLogger(__name__)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f70286e-5497-4c8f-a3a0-a75d8292b096",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# get the role arn that is used to deploy the model on djl\n",
"sm_client = boto3.client(\"sagemaker\")\n",
"s3_client = boto3.client('s3')\n",
"role = sagemaker.get_execution_role()\n",
"logger.info(f\"Role that will be used to deploy the model: {role}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82e68d36-e9bc-4f2d-9cef-97bdb569a9fe",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# get the current region\n",
"boto3_session = boto3.session.Session()\n",
"region = boto3_session.region_name\n",
"logger.info(f\"Region in which the model will be deployed: {region}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5c9b0f7d-6c6b-4105-a1c1-2acad1511fd4",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# s3 bucket where the model is stored and pulled from during deployment\n",
"write_model_to_s3_bucket: str = \"<your-s3-bucket>\"\n",
"HF_TOKEN: str = \"<your-hf-token>\"\n",
"HF_MODEL_ID: str = \"meta-llama/Llama-2-13b-hf\"\n",
"\n",
"# serving properties that are used to deploy the model on djl\n",
"serving_properties: Dict = {\n",
" \"engine\": \"Python\",\n",
" \"option.tensor_parallel_degree\": \"4\",\n",
" \"option.model_id\": HF_MODEL_ID, \n",
" \"option.max_rolling_batch_size\": \"64\",\n",
" \"option.rolling_batch\": \"vllm\", \n",
" \"option.dtype\": \"fp16\"\n",
"}\n",
"\n",
"# configuration instance that is used to get variables in the deploy function\n",
"# set the download_from_hf_place_in_s3 to true if you want to download all model contents into s3. Set\n",
"# the default value of download_from_hf_place_in_s3 to False and refer to the model_id in the `option.model_id`\n",
"# within the serving_properties\n",
"deploy_with_djl_llama2_13b: Dict = {\n",
" \"name\": \"Llama2-7b-g4dn-djl-inference-0.26.0-deepspeed0.12.6-cu121\",\n",
" \"model_id\": HF_MODEL_ID,\n",
" \"download_from_hf_place_in_s3\": False,\n",
" \"model_name\": \"Llama-2-13b-hf\",\n",
" \"model_version\": \"*\", \n",
" \"ep_name\": \"Llama-2-7b-hf-g4dn\",\n",
" \"model_s3_path\": f\"s3://{write_model_to_s3_bucket}/meta-llama/Llama-2-13b-hf\", \n",
" \"serving.properties\": serving_properties,\n",
" \"instance_type\": \"ml.g4dn.12xlarge\",\n",
" \"instance_count\": \"1\",\n",
" \"image_uri\": f\"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:0.27.0-deepspeed0.12.6-cu121\"\n",
"}\n",
"\n",
"logger.info(f\"serving properties that will be used to deploy the model {HF_MODEL_ID}: {serving_properties}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "143517ff-ea90-41ff-859c-6ebf5f2696d2",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def _download_model(model_id: str,\n",
" local_model_path: str,\n",
" allow_patterns: Optional[List] = [\"*\"]) -> str:\n",
" \"\"\"\n",
" Download the model files locally\n",
" \"\"\"\n",
" local_model_path = Path(local_model_path)\n",
" print(f\"Local model path: {local_model_path}\")\n",
" local_model_path.mkdir(exist_ok=True)\n",
" print(f\"Created the local directory: {local_model_path}\")\n",
"\n",
" model_download_path = snapshot_download(\n",
" repo_id=model_id,\n",
" cache_dir=local_model_path,\n",
" allow_patterns=allow_patterns,\n",
" use_auth_token=HF_TOKEN\n",
" )\n",
" print(f\"Uncompressed model downloaded into ... -> {model_download_path}\")\n",
" return model_download_path"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe266cec-a3b6-4f59-ba67-70e5ce0d97cc",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def _upload_dir(localDir: str, awsInitDir: str, bucketName: str, tag: str =\"*.*\"):\n",
" s3 = boto3.resource('s3')\n",
" p = Path(localDir)\n",
" # Iterate over all directories and files within localDir\n",
" for path in p.glob('**/*'):\n",
" if path.is_file():\n",
" rel_path = path.relative_to(p)\n",
" awsPath = os.path.join(awsInitDir, str(rel_path)).replace(\"\\\\\", \"/\")\n",
" logger.info(f\"Uploading {path} to s3://{bucketName}/{awsPath}\")\n",
" logger.info(f\"path: {path}, bucket name: {bucketName}, awsPath: {awsPath}\")\n",
" s3.meta.client.upload_file(path, bucketName, awsPath)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1ad46025-9ae8-46c7-8e23-5662b22526c3",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def _create_and_upload_model_artifact(serving_properties_path: str,\n",
" bucket: str,\n",
" prefix: str) -> str:\n",
" \"\"\"\n",
" Create the model artifact with the updated serving properties within the directory\n",
" \"\"\"\n",
" # Create a tar.gz file containing only the serving.properties file\n",
" tar_file_path = os.path.join(Path(serving_properties_path).parent, 'model.tar.gz')\n",
" with tarfile.open(tar_file_path, \"w:gz\") as tar:\n",
" # Add the serving.properties file\n",
" tar.add(serving_properties_path, arcname='serving.properties')\n",
"\n",
" # Upload the tar.gz file to S3\n",
" key = f\"{prefix}/model.tar.gz\"\n",
" s3_client.upload_file(tar_file_path, bucket, key)\n",
" model_tar_gz_path: str = f\"s3://{bucket}/{key}\"\n",
" logger.info(f\"uploaded model.tar.gz to {model_tar_gz_path}\")\n",
" return model_tar_gz_path"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76a889de-2ad3-4f7a-990e-dbd634ffe062",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def _create_model(experiment_config: Dict,\n",
" inference_image_uri: str,\n",
" s3_model_artifact: str,\n",
" role_arn: str) -> Tuple[str, str]:\n",
" \"\"\"\n",
" # Function to create the SageMaker model\n",
" \"\"\"\n",
" model_name = name_from_base(experiment_config['model_name'])\n",
" env = experiment_config.get('env')\n",
"\n",
" # HF token required for gated model downloads form HF\n",
" hf_dict: Optional[Dict] = None\n",
" if HF_TOKEN is not None:\n",
" logger.info(f\"hf_token is provided, using it to create the model\")\n",
" hf_dict = dict(HUGGING_FACE_HUB_TOKEN=HF_TOKEN)\n",
" else:\n",
" logger.info(f\"hf_token not provided\")\n",
"\n",
" # this gets passed as an env var\n",
" if env:\n",
" if hf_dict:\n",
" # both env and hf_dict exists, so we do a union\n",
" env = env | hf_dict\n",
" else:\n",
" if hf_dict:\n",
" # env var did not exist, but hf_dict did so that\n",
" # is now the env var\n",
" env = hf_dict\n",
"\n",
" if env:\n",
" pc = dict(Image=inference_image_uri,\n",
" ModelDataUrl=s3_model_artifact,\n",
" Environment=env)\n",
" else:\n",
" pc = dict(Image=inference_image_uri,\n",
" ModelDataUrl=s3_model_artifact)\n",
"\n",
" create_model_response = sm_client.create_model(\n",
" ModelName=model_name,\n",
" ExecutionRoleArn=role_arn,\n",
" PrimaryContainer=pc,\n",
" )\n",
" return model_name, create_model_response[\"ModelArn\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "31eceecd-ec83-4d38-865c-f613e3efe8ef",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def _deploy_endpoint(experiment_config: Dict,\n",
" model_name: str) -> Tuple[str, str]:\n",
" \"\"\"\n",
" Function to create and deploy the endpoint\n",
" \"\"\"\n",
" endpoint_config_name = f\"{model_name}-config\"\n",
" endpoint_name = f\"{model_name}-endpoint\"\n",
"\n",
" _ = sm_client.create_endpoint_config(\n",
" EndpointConfigName=endpoint_config_name,\n",
" ProductionVariants=[\n",
" {\n",
" \"VariantName\": \"variant1\",\n",
" \"ModelName\": model_name,\n",
" \"InstanceType\": experiment_config[\"instance_type\"],\n",
" \"InitialInstanceCount\": 1,\n",
" \"ModelDataDownloadTimeoutInSeconds\": 3600,\n",
" \"ContainerStartupHealthCheckTimeoutInSeconds\": 3600,\n",
" },\n",
" ],\n",
" )\n",
"\n",
" create_endpoint_response = sm_client.create_endpoint(\n",
" EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name\n",
" )\n",
" return endpoint_name, create_endpoint_response['EndpointArn']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "83c49d42-be08-4bf0-b415-1851594139a6",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def _check_endpoint_status(endpoint_name: str) -> str:\n",
" \"\"\"\n",
" Function to check the status of the endpoint\n",
" \"\"\"\n",
" resp = sm_client.describe_endpoint(EndpointName=endpoint_name)\n",
" status = resp[\"EndpointStatus\"]\n",
" while status == \"Creating\":\n",
" time.sleep(60)\n",
" resp = sm_client.describe_endpoint(EndpointName=endpoint_name)\n",
" status = resp[\"EndpointStatus\"]\n",
" return status"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6e2bcbf-f3d8-49b7-96ab-c6f0f41297f7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def deploy_model_on_djl(model_config: Dict, serving_properties: Dict, role_arn: str) -> str:\n",
" \"\"\"\n",
" This function is used to deploy the model on the DJL container and create an endpoint.\n",
" This function takes in the following arguments:\n",
"\n",
" Args:\n",
" model_config: Contains the configuration of the model that needs to be deployed\n",
" role_arn: The role arn that is used to create the model that is used to create an endpoint\n",
"\n",
" Returns:\n",
" str: Returns a string which represents the endpoint name of the deployed model ready for \n",
" inference\n",
" \"\"\"\n",
" # download the model from hugging face into the model s3 path defined\n",
" if model_config.get(\"download_from_hf_place_in_s3\") is True:\n",
" with tempfile.TemporaryDirectory() as local_model_path:\n",
" logger.info(f\"created temporary directory {local_model_path}\")\n",
" local_model_path = _download_model(model_config['model_id'],\n",
" local_model_path)\n",
" logger.info(f\"going to upload model files to {model_config['model_s3_path']}\")\n",
"\n",
" o = urlparse(model_config['model_s3_path'], allow_fragments=False)\n",
" _upload_dir(local_model_path, o.path.lstrip('/'), o.netloc) \n",
" logger.info(f\"local model path: {local_model_path}, o.path: {o.path}, o.netloc: {o.netloc}\")\n",
"\n",
" model_artifact = model_config['model_s3_path']\n",
" logger.info(f\"Uncompressed model downloaded into ... -> {model_artifact}\")\n",
"\n",
" logger.info(\"preparing model artifact...\")\n",
"\n",
" # handle serving.properties, we read it from the config and then write it to\n",
" # a local file\n",
" logger.info(f\"write bucket for inserting model.tar.gz into: {write_model_to_s3_bucket}\")\n",
" properties = model_config[\"serving.properties\"]\n",
"\n",
" # create and upload the model.tar.gz, note that this file is just a placeholder\n",
" # it is not the actual model, the actual model binaries are in s3 or HuggingFace\n",
" # and the container will download them when the model endpoint is being created\n",
" logger.info(f\"uploading model.tar.gz to S3,bucket={write_model_to_s3_bucket}, \\\n",
" prefix={model_config['model_id']}\")\n",
" dir_path = os.getcwd()\n",
" serving_properties_path = os.path.join(dir_path, \"serving.properties\")\n",
" serving_props_str = '\\n'.join(f\"{key}={value}\" for key, value in properties.items())\n",
" Path(serving_properties_path).write_text(serving_props_str)\n",
" logger.info(f\"written the following serving.properties \\\n",
" content={properties} to {serving_properties_path}\")\n",
"\n",
" # create and upload the model.tar.gz, note that this file is just a placeholder\n",
" # it is not the actual model, the actual model binaries are in s3 or HuggingFace\n",
" # and the container will download them when the model endpoint is being created\n",
" logger.info(f\"uploading model.tar.gz to S3,bucket={write_model_to_s3_bucket}, \\\n",
" prefix={model_config['model_id']}\")\n",
" model_artifact = _create_and_upload_model_artifact(serving_properties_path,\n",
" write_model_to_s3_bucket,\n",
" model_config['model_id'])\n",
" # create and upload the model artifact to s3\n",
" model_artifact = _create_and_upload_model_artifact(serving_properties_path,\n",
" write_model_to_s3_bucket,\n",
" model_config['model_id'])\n",
" logger.info(f\"model uploaded to: {model_artifact}\")\n",
"\n",
" inference_image_uri = model_config['image_uri']\n",
" logger.info(f\"inference image URI: {inference_image_uri}\")\n",
"\n",
" # create model\n",
" model_name, model_arn = _create_model(model_config,\n",
" inference_image_uri,\n",
" model_artifact,\n",
" role_arn)\n",
" logger.info(f\"created Model: {model_arn}\")\n",
"\n",
" # deploy model\n",
" endpoint_name, _ = _deploy_endpoint(model_config, model_name)\n",
" logger.info(f\"deploying endpoint: {endpoint_name}\")\n",
"\n",
" # check model deployment status\n",
" status = _check_endpoint_status(endpoint_name)\n",
" logger.info(f\"Endpoint status: {status}\")\n",
"\n",
" if status == 'InService':\n",
" logger.info(\"endpoint is in service\")\n",
" else:\n",
" logger.info(\"endpoint is not in service.\")\n",
" return endpoint_name"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "56976863-c086-4eff-acf7-05a9815fad00",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# deploy the model on\n",
"endpoint_name: str = deploy_model_on_djl(deploy_with_djl_llama2_13b, serving_properties, role)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc02976e-9d4f-4ebd-a735-f859ac9f80b5",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"logger.info(f\"The endpoint is deployed: {endpoint_name}\")\n",
"response = sm_client.describe_endpoint(EndpointName=endpoint_name)\n",
"print(f\"Endpoint status: {response['EndpointStatus']}\")"
]
},
{
"cell_type": "markdown",
"id": "5dfd22bf-e93f-443a-a38f-ab3f32b8e277",
"metadata": {},
"source": [
"### Invoke the model endpoint for Inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "72e1d94f-7ae8-4276-ace6-742cd4fea949",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from sagemaker.predictor import Predictor\n",
"from sagemaker.serializers import JSONSerializer\n",
"predictor = Predictor(\n",
" endpoint_name=endpoint_name,\n",
" sagemaker_session=sagemaker.Session(),\n",
" serializer=JSONSerializer()\n",
" )\n",
"predictor"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5fa21577-3b03-4366-8aa6-2528669240e7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"response_from_llama = predictor.predict({\"inputs\":\"What is the color of a rose\"})\n",
"response_from_llama"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "conda_python3",
"language": "python",
"name": "conda_python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment