Created
November 24, 2024 09:45
-
-
Save zigelboim-misha/4c578bc70f82b16915abdb7dc189e737 to your computer and use it in GitHub Desktop.
Examples of calling AWS Bedrock models with boto3: listing foundation models, the low-level InvokeModel API, and the Converse API with tool use.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
import os
import sys
from datetime import datetime
from typing import List

import boto3
from botocore.exceptions import ClientError

from llm.aws.tools import aws_bedrock_tools
from prompts.system import KUBERNETES_EXPERT_PROMPT, SYSTEM_TOOL_USAGE
# Bedrock model identifiers used by the helpers below.
# NOTE(review): the "us." prefix on the Llama ID is a cross-region inference
# profile; the other two are plain model IDs — confirm both resolve in us-west-2.
LLAMA_3_1_MODEL_ID: str = "us.meta.llama3-1-70b-instruct-v1:0"
MISTRAL_LARGE_MODEL_ID: str = "mistral.mistral-large-2402-v1:0"
CLAUDE_3_5_HAIKU: str = "anthropic.claude-3-5-haiku-20241022-v1:0"
def create_client(service_name: str):
    """
    Create a boto3 client for the given AWS service in us-west-2.

    Bug fix: the original passed module-level names AWS_KEY / AWS_SECRET_KEY
    that are defined nowhere in this file, so every call raised NameError.
    Credentials are now read from the AWS_KEY / AWS_SECRET_KEY environment
    variables when both are set; otherwise boto3's default credential chain
    (env vars, shared config, instance profile, ...) is used.

    Args:
        service_name (str): Service to create a client for, e.g. "bedrock"
            (control plane) or "bedrock-runtime" (inference).

    Returns:
        botocore.client.BaseClient: The client for the specified service.
    """
    access_key = os.environ.get("AWS_KEY")
    secret_key = os.environ.get("AWS_SECRET_KEY")
    client_kwargs = {"region_name": "us-west-2"}
    if access_key and secret_key:
        client_kwargs.update(
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
        )
    return boto3.client(service_name, **client_kwargs)
def list_all_foundation_models():
    """
    List all Bedrock foundation models available in the AWS Region.

    Prints the raw API response and also returns it so callers can use the
    model list programmatically (previously the result was only printed and
    discarded). Returns None if the call fails.

    Returns:
        dict | None: The ListFoundationModels response, or None on error.
    """
    # The control-plane API lives on the "bedrock" service,
    # not "bedrock-runtime".
    client = create_client("bedrock")
    try:
        response = client.list_foundation_models()
    except Exception as e:
        # Best-effort listing: report and return None rather than crash.
        print(f"Error listing all foundation models from aws bedrock: {e}")
        return None
    print(response)
    return response
def invoke_model(model_id: str, messages: List):
    """
    Invoke a Bedrock model via the low-level InvokeModel API.

    Args:
        model_id (str): Bedrock model ID to invoke (e.g. LLAMA_3_1_MODEL_ID).
        messages (List): Chat messages, each {"role": ..., "content": ...};
            they are flattened into a single text prompt.

    Returns:
        dict: Parsed JSON response body from the model.

    Exits:
        Terminates the process with status 1 if the invocation fails.
    """
    client = create_client("bedrock-runtime")
    # Flatten the chat history into a plain-text prompt, ending with an
    # "AI:" cue so the model continues as the assistant.
    prompt = "\n".join(
        f"{msg['role']}: {msg['content']}" for msg in messages) + "\nAI:"
    body = json.dumps({"prompt": prompt})
    try:
        # Send the message to the model, using a basic inference configuration.
        response = client.invoke_model(
            modelId=model_id,
            body=body,
        )
        # The body is a streaming object; read and JSON-decode it once.
        return json.loads(response["body"].read())
    except Exception as e:
        # ClientError subclasses Exception, so one clause covers both
        # (the original `except (ClientError, Exception)` was redundant).
        print(f"ERROR: Can't invoke '{model_id}'. Reason: {e}")
        sys.exit(1)
def converse_with_model(model_id: str, messages: List):
    """
    Send a conversation to a Bedrock model via the Converse API.

    messages[0] must be the system message in Converse format
    ({"role": "system", "content": [{"text": ...}]}); it is passed through
    the dedicated `system` parameter while the remaining messages form the
    chat history. The caller's list is NOT mutated.

    Bug fix: the original aliased `messages` as `chat_history` and popped
    index 0 from it, which (a) mutated the caller's list, (b) made the
    subsequent `messages[0]["content"]` pass the *user* message as the
    system prompt, and (c) still sent the full `messages` list instead of
    the trimmed history.

    Args:
        model_id (str): Bedrock model ID to converse with.
        messages (List): Conversation including the leading system message.

    Returns:
        dict: Raw Converse API response.

    Exits:
        Terminates the process with status 1 if the invocation fails.
    """
    client = create_client("bedrock-runtime")
    # Capture the system prompt BEFORE splitting off the chat history.
    system_prompt = messages[0]["content"]
    chat_history = messages[1:]  # slice copies: caller's list stays intact
    try:
        # Send the chat history to the model with a basic inference
        # configuration plus the tool definitions.
        response = client.converse(
            modelId=model_id,
            messages=chat_history,
            system=system_prompt,
            inferenceConfig={
                "maxTokens": 500,
                "temperature": 0.1,
            },
            toolConfig=aws_bedrock_tools,
        )
    except Exception as e:
        print(f"ERROR: Can't invoke '{model_id}'. Reason: {e}")
        sys.exit(1)
    return response
def test():
    """
    Smoke-test both invocation paths: the low-level InvokeModel API (plain
    role-prefixed strings) and the Converse API (Converse message format),
    printing each response and its wall-clock duration.
    """
    # --- InvokeModel: plain {"role", "content"} string messages ---
    messages = [
        {"role": "system", "content": KUBERNETES_EXPERT_PROMPT + SYSTEM_TOOL_USAGE},
        {"role": "user", "content": "List all namespaces please"},
    ]
    start_time = datetime.now()
    res = invoke_model(model_id=LLAMA_3_1_MODEL_ID, messages=messages)
    duration = datetime.now() - start_time  # fixed typo: was "deration"
    print(f"Duration: {duration}")
    print(res)

    # --- Converse: content is a list of {"text": ...} parts ---
    messages = [
        {"role": "system", "content": [{"text": KUBERNETES_EXPERT_PROMPT + SYSTEM_TOOL_USAGE}]},
        {"role": "user", "content": [{"text": "What is the most popular song on Radio XYZ?"}]},
    ]
    start_time = datetime.now()
    res = converse_with_model(model_id=LLAMA_3_1_MODEL_ID, messages=messages)
    duration = datetime.now() - start_time
    print(f"Duration: {duration}")
    print(res)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment