Skip to content

Instantly share code, notes, and snippets.

@philschmid
Last active May 15, 2023 06:38
Show Gist options
  • Save philschmid/2ef73610b70c6ed8605a23f74c914a69 to your computer and use it in GitHub Desktop.
{
"code": "import matplotlib.pyplot as plt\r\nimport numpy as np\r\nfrom sagemaker import image_uris, model_uris, script_uris, instance_types\r\nfrom sagemaker.predictor import Predictor\r\nfrom sagemaker import get_execution_role\r\nimport json\r\n\r\n\r\nmodel_id, model_version = \"model-txt2img-stabilityai-stable-diffusion-v2-fp16\", \"*\"\r\n\r\ninference_instance_type = instance_types.retrieve_default(\r\n model_id=model_id, model_version=model_version, scope=\"inference\"\r\n)\r\n\r\n# Retrieve the inference docker container uri. This is the base HuggingFace container image for the default model above.\r\ndeploy_image_uri = image_uris.retrieve(\r\n region=None,\r\n framework=None, # automatically inferred from model_id\r\n image_scope=\"inference\",\r\n model_id=model_id,\r\n model_version=model_version,\r\n instance_type=inference_instance_type,\r\n)\r\n\r\n# Retrieve the inference script uri. This includes all dependencies and scripts for model loading, inference handling etc.\r\ndeploy_source_uri = script_uris.retrieve(\r\n model_id=model_id, model_version=model_version, script_scope=\"inference\"\r\n)\r\n\r\n\r\n# Retrieve the model uri. This includes the pre-trained nvidia-ssd model and parameters.\r\nmodel_uri = model_uris.retrieve(\r\n model_id=model_id, model_version=model_version, model_scope=\"inference\"\r\n)\r\n\r\n# To increase the maximum response size from the endpoint.\r\nenv = {\r\n \"MMS_MAX_RESPONSE_SIZE\": \"20000000\",\r\n}\r\n\r\n# Create the SageMaker model instance\r\nmodel = Model(\r\n image_uri=deploy_image_uri,\r\n source_dir=deploy_source_uri,\r\n model_data=model_uri,\r\n entry_point=\"inference.py\", # entry point file in source_dir and present in deploy_source_uri\r\n role=get_execution_role(),\r\n predictor_cls=Predictor,\r\n name=endpoint_name,\r\n env=env,\r\n)\r\n\r\n# deploy the Model. Note that we need to pass Predictor class when we deploy model through Model class,\r\n# for being able to run inference through the sagemaker API.\r\nmodel_predictor = model.deploy(\r\n initial_instance_count=1,\r\n instance_type=inference_instance_type,\r\n predictor_cls=Predictor,\r\n endpoint_name=endpoint_name,\r\n)\r\n\r\n\r\ndef query(model_predictor, text):\r\n \"\"\"Query the model predictor.\"\"\"\r\n\r\n encoded_text = json.dumps(text).encode(\"utf-8\")\r\n\r\n query_response = model_predictor.predict(\r\n encoded_text,\r\n {\r\n \"ContentType\": \"application\/x-text\",\r\n \"Accept\": \"application\/json\",\r\n },\r\n )\r\n return query_response\r\n\r\n\r\ndef parse_response(query_response):\r\n \"\"\"Parse response and return generated image and the prompt\"\"\"\r\n\r\n response_dict = json.loads(query_response)\r\n return response_dict[\"generated_image\"], response_dict[\"prompt\"]\r\n\r\n\r\ndef display_img_and_prompt(img, prmpt):\r\n \"\"\"Display hallucinated image.\"\"\"\r\n plt.figure(figsize=(12, 12))\r\n plt.imshow(np.array(img))\r\n plt.axis(\"off\")\r\n plt.title(prmpt)\r\n plt.show()\r\n",
"links": [{
"label": "SageMaker JumpStart Examples",
"href": "https://github.com/aws/amazon-sagemaker-examples/tree/main/sagemaker-jumpstart"
}]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment