-
-
Save allanj/0d6fc89e809ceb99e2a4efb067b1280f to your computer and use it in GitHub Desktop.
client.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import openai | |
import asyncio | |
async def get_choice_completion(
    prompt,
    choices,
    *,
    model="meta-llama/Meta-Llama-3-8B",
    base_url="http://127.0.0.1:8000/v1",
    api_key="abc",
):
    """Pick the choice whose ``prompt + choice`` text scores highest.

    For every candidate, the concatenated ``prompt + choice`` string is sent
    to an OpenAI-compatible completions endpoint with ``echo=True`` and
    ``max_tokens=0``, and the returned per-token logprobs are summed. The
    sum runs over every token logprob the server returns for the whole
    string (it is not restricted to the choice, and it is not
    length-normalized).

    Args:
        prompt: Text shared by all candidates.
        choices: Iterable of candidate continuations (strings). Duplicates
            collapse to a single entry in the returned mapping.
        model: Model name forwarded to the completions endpoint.
        base_url: Base URL of the OpenAI-compatible server.
        api_key: API key passed to the client.

    Returns:
        A ``(best_choice, choice_probs)`` tuple, where ``choice_probs`` maps
        each choice to its summed logprob and ``best_choice`` is the key with
        the largest sum (the first one wins on ties).

    Raises:
        ValueError: If ``choices`` is empty.
    """
    # Materialize once so emptiness can be checked before any request is
    # made, and so one-shot iterables are still supported.
    choices = list(choices)
    if not choices:
        raise ValueError("choices must contain at least one candidate")

    choice_probs = {}
    # Initialize an asynchronous OpenAI client; it is closed on exit.
    async with openai.AsyncClient(base_url=base_url, api_key=api_key) as client:
        # Score each prompt + choice sequence, one request at a time.
        for choice in choices:
            response = await client.completions.create(
                model=model,
                prompt=prompt + choice,
                echo=True,  # ask for the submitted text's own tokens back
                logprobs=0,  # request a logprob for each token
                max_tokens=0,  # do not generate any additional tokens
            )
            token_logprobs = response.choices[0].logprobs.token_logprobs
            # Entries may be None (presumably the first token, which has no
            # preceding context -- confirm against the server); skip them.
            choice_probs[choice] = sum(
                logprob for logprob in token_logprobs if logprob is not None
            )

    # Highest summed logprob == most likely sequence.
    best_choice = max(choice_probs, key=choice_probs.get)
    return best_choice, choice_probs
# Demo driver: scores two sample prompts concurrently.
async def main():
    """Run ``get_choice_completion`` on two sample questions at once.

    Returns:
        A list with one ``get_choice_completion`` result per sample
        question, in the order the questions are listed below.
    """
    sky_question = [
        "What color is the sky on a clear day? The color is ",
        ["blue", "green", "red", "yellow"],
    ]
    capital_question = ["The capital of France is ", ["Paris", "Beijing"]]

    # Create all coroutines first, then await them together so the
    # questions are processed concurrently.
    pending = []
    for question, options in (sky_question, capital_question):
        pending.append(get_choice_completion(question, options))
    return await asyncio.gather(*pending)
# Run the demo only when this file is executed as a script, so that
# importing the module does not fire network requests as a side effect.
if __name__ == "__main__":
    data = asyncio.run(main())
    print(data)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment