import numpy as np
from typing import Callable

# Placeholder functions to be implemented
def get_lora_adapter(username: str) -> Callable:
    """Returns a callable function for the LoRA adapter."""
    pass

def get_reft_adapter(username: str) -> Callable:
    """Returns a callable function for the ReFT adapter."""
    pass

def get_prompted_base(username: str) -> Callable:
    """Returns a callable function for the prompted base model."""
    pass
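
# A minimal sketch of how get_lora_adapter might load a per-user adapter with
# the `peft` library. The `base_model` argument and the adapter path pattern
# are assumptions for illustration, not part of the original gist.
def _sketch_get_lora_adapter(base_model, username: str) -> Callable:
    from peft import PeftModel  # dependency assumed for this sketch
    # Assumed path layout: one LoRA adapter directory per user
    return PeftModel.from_pretrained(base_model, f"adapters/lora/{username}")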
def get_random_chats(username: str, num_of_chats: int) -> list[dict]:
    """Fetches random chat data for the given username."""
    pass

def get_random_prompt(chat: dict) -> str:
    """Extracts a random prompt from the chat where the next message is from the username.
    The chat should include the username, and the prompt should be a cropped version of the chat
    with the next message being a random message by the user."""
    pass
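
# A minimal sketch of get_random_prompt, assuming (hypothetically) that a chat
# is shaped like {"username": str, "messages": [{"sender": str, "text": str}, ...]}.
# It picks a random message sent by the user and crops the chat just before it,
# so the "next message" after the prompt is the user's own.
def _sketch_get_random_prompt(chat: dict) -> str:
    import random
    user = chat["username"]
    user_idxs = [i for i, m in enumerate(chat["messages"]) if m["sender"] == user]
    cut = random.choice(user_idxs)
    return "\n".join(f"{m['sender']}: {m['text']}" for m in chat["messages"][:cut])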
def get_perplexity(model_fn: Callable, prompt: str) -> float:
    """Calculates and returns the perplexity of the model's response to the prompt."""
    pass
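
# A minimal sketch of a perplexity computation with Hugging Face transformers,
# assuming `model` is a causal LM and `tokenizer` is its tokenizer (both are
# stand-ins for whatever the adapter callables above wrap). Perplexity is the
# exponential of the mean token-level cross-entropy loss.
def _sketch_perplexity(model, tokenizer, text: str) -> float:
    import torch
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        # Passing labels makes the model return the mean cross-entropy loss
        outputs = model(**inputs, labels=inputs["input_ids"])
    return float(torch.exp(outputs.loss))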
def normalize_st_sum_equals_one(perplexities: np.ndarray) -> np.ndarray:
    """Normalizes the array of perplexities such that the sum equals one."""
    return perplexities / perplexities.sum()
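
# Usage example: normalize_st_sum_equals_one(np.array([2.0, 6.0]))
# returns array([0.25, 0.75]); a lower share still means lower (better) perplexity.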
def get_llm_as_a_judge_winner(prompt: str, candidate_fns: list[Callable]) -> int:
    """Returns the index of the winning model for the given prompt, as judged by an LLM."""
    pass
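
# A minimal sketch of an LLM-as-a-judge round, assuming a hypothetical
# `ask_judge(text) -> str` helper that queries a judge LLM (not part of the
# gist). Each candidate generates a response; the judge replies with the index
# of the best one.
def _sketch_judge_winner(prompt: str, candidates: list[Callable], ask_judge: Callable) -> int:
    responses = [fn(prompt) for fn in candidates]
    numbered = "\n".join(f"{i}: {r}" for i, r in enumerate(responses))
    verdict = ask_judge(
        f"Prompt:\n{prompt}\n\nCandidate responses:\n{numbered}\n\n"
        "Reply with only the index of the best response."
    )
    return int(verdict.strip())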
# Candidate adapter factories; each takes a username and returns a model callable
candidate_fns = [
    get_lora_adapter,
    get_reft_adapter,
    get_prompted_base,
]
# Initialize tracking variables
candidate_perplexity = np.zeros(len(candidate_fns))
candidate_wins = [0] * len(candidate_fns)

# List of usernames to evaluate
usernames = ["user1", "user2", "user3"]

# Main evaluation loop
for username in usernames:
    chats: list[dict] = get_random_chats(username, num_of_chats=50)
    for chat in chats:
        prompt = get_random_prompt(chat)

        # Calculate perplexity for each candidate
        perplexities = [get_perplexity(fn(username), prompt) for fn in candidate_fns]

        # Normalize perplexities
        normalized_perplexity = normalize_st_sum_equals_one(np.array(perplexities))

        # Update cumulative perplexity tracking
        candidate_perplexity += normalized_perplexity

        # Determine the winner and update the win count
        winner_index = get_llm_as_a_judge_winner(prompt, [fn(username) for fn in candidate_fns])
        candidate_wins[winner_index] += 1

# The `candidate_perplexity` array and `candidate_wins` list now hold the evaluation results
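
# Example of reporting the results after the loop (a sketch; the display names
# are assumptions matching the three candidates above):
candidate_names = ["lora", "reft", "prompted_base"]
total = sum(candidate_wins)
for name, ppl, wins in zip(candidate_names, candidate_perplexity, candidate_wins):
    print(f"{name}: cumulative normalized perplexity={ppl:.3f}, wins={wins}/{total}")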