Last active
May 15, 2023 06:38
-
-
Save philschmid/2ef73610b70c6ed8605a23f74c914a69 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{
"""Deploy the Stable Diffusion v2 (fp16) JumpStart model to a SageMaker endpoint.

Retrieves the JumpStart container image, inference script bundle, and model
artifacts, deploys them behind a real-time endpoint, and provides helpers to
query the endpoint, parse its response, and display the generated image.
"""
import json

import matplotlib.pyplot as plt
import numpy as np
from sagemaker import get_execution_role, image_uris, instance_types, model_uris, script_uris
from sagemaker.model import Model  # FIX: Model was used below but never imported
from sagemaker.predictor import Predictor
from sagemaker.utils import name_from_base

model_id, model_version = "model-txt2img-stabilityai-stable-diffusion-v2-fp16", "*"

# FIX: endpoint_name was referenced below but never defined anywhere in the script.
endpoint_name = name_from_base(f"jumpstart-{model_id}")

inference_instance_type = instance_types.retrieve_default(
    model_id=model_id, model_version=model_version, scope="inference"
)

# Retrieve the inference docker container uri. This is the base HuggingFace
# container image for the default model above.
deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    image_scope="inference",
    model_id=model_id,
    model_version=model_version,
    instance_type=inference_instance_type,
)

# Retrieve the inference script uri. This includes all dependencies and
# scripts for model loading, inference handling etc.
deploy_source_uri = script_uris.retrieve(
    model_id=model_id, model_version=model_version, script_scope="inference"
)

# Retrieve the model uri. This includes the pre-trained Stable Diffusion
# model artifacts and parameters. (Original comment said "nvidia-ssd",
# which does not match the model_id above.)
model_uri = model_uris.retrieve(
    model_id=model_id, model_version=model_version, model_scope="inference"
)

# To increase the maximum response size from the endpoint (generated images
# are large base64 payloads).
env = {
    "MMS_MAX_RESPONSE_SIZE": "20000000",
}

# Create the SageMaker model instance.
model = Model(
    image_uri=deploy_image_uri,
    source_dir=deploy_source_uri,
    model_data=model_uri,
    entry_point="inference.py",  # entry point file in source_dir and present in deploy_source_uri
    role=get_execution_role(),
    predictor_cls=Predictor,
    name=endpoint_name,
    env=env,
)

# Deploy the Model. Note that we need to pass Predictor class when we deploy
# model through Model class, for being able to run inference through the
# sagemaker API.
model_predictor = model.deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    predictor_cls=Predictor,
    endpoint_name=endpoint_name,
)


def query(model_predictor, text):
    """Send *text* to the endpoint; return the raw response payload."""
    encoded_text = json.dumps(text).encode("utf-8")

    query_response = model_predictor.predict(
        encoded_text,
        {
            "ContentType": "application/x-text",
            "Accept": "application/json",
        },
    )
    return query_response


def parse_response(query_response):
    """Parse the JSON response; return (generated_image, prompt)."""
    response_dict = json.loads(query_response)
    return response_dict["generated_image"], response_dict["prompt"]


def display_img_and_prompt(img, prmpt):
    """Display the generated image with its prompt as the title."""
    plt.figure(figsize=(12, 12))
    plt.imshow(np.array(img))
    plt.axis("off")
    plt.title(prmpt)
    plt.show()
"links": [{
"label": "SageMaker JumpStart Examples",
"href": "https://github.com/aws/amazon-sagemaker-examples/tree/main/sagemaker-jumpstart"
}]
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.