# Gist by @keyboardAnt (created November 8, 2024)
import numpy as np
from typing import Callable

# Placeholder functions to be implemented
def get_lora_adapter(username: str) -> Callable:
    """Returns a callable for the user's LoRA adapter."""
    pass

def get_reft_adapter(username: str) -> Callable:
    """Returns a callable for the user's ReFT adapter."""
    pass

def get_prompted_base(username: str) -> Callable:
    """Returns a callable for the prompted base model."""
    pass
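
# A minimal sketch of how `get_lora_adapter` might load a per-user adapter
# with Hugging Face PEFT. The base model name ("gpt2") and the adapter path
# layout (f"adapters/{username}/lora") are assumptions for illustration,
# not part of the original gist:
def sketch_load_lora(username: str):
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    base = AutoModelForCausalLM.from_pretrained("gpt2")
    # Attach the user-specific LoRA weights on top of the shared base model.
    model = PeftModel.from_pretrained(base, f"adapters/{username}/lora")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    return model, tokenizer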
def get_random_chats(username: str, num_of_chats: int) -> list[dict]:
    """Fetches random chat data for the given username."""
    pass

def get_random_prompt(chat: dict) -> str:
    """Extracts a random prompt from the chat: the chat is cropped just
    before a randomly chosen message by the user, so that the user's
    message is the next turn to be predicted."""
    pass
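
# A minimal sketch of the cropping described above, assuming each chat is a
# dict like {"username": str, "messages": [{"sender": str, "text": str}]}.
# This schema is an assumption for illustration, not part of the original gist:
import random

def sketch_crop_chat(chat: dict) -> str:
    # Indices of turns written by the user the models should imitate.
    user_turns = [i for i, m in enumerate(chat["messages"])
                  if m["sender"] == chat["username"]]
    cut = random.choice(user_turns)  # the user message the models must predict
    history = chat["messages"][:cut]
    return "\n".join(f"{m['sender']}: {m['text']}" for m in history)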
def get_perplexity(model_fn: Callable, prompt: str) -> float:
    """Calculates and returns the perplexity of the model's response to the prompt."""
    pass
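
# A minimal sketch of the perplexity computation, assuming the adapter
# callable can expose the per-token log-probabilities of its response
# (that interface is an assumption, not part of the original gist):
def sketch_perplexity(token_log_probs: np.ndarray) -> float:
    # Perplexity is the exponential of the mean negative log-likelihood per token.
    return float(np.exp(-np.mean(token_log_probs)))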
def normalize_st_sum_equals_one(perplexities: np.ndarray) -> np.ndarray:
    """Normalizes the array of perplexities so that it sums to one."""
    return perplexities / perplexities.sum()
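
# For example, normalize_st_sum_equals_one(np.array([2.0, 3.0, 5.0])) returns
# array([0.2, 0.3, 0.5]); a smaller share means a lower (better) perplexity
# relative to the other candidates on the same prompt.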
def get_llm_as_a_judge_winner(prompt: str, candidate_fns: list[Callable]) -> int:
    """Returns the index of the winning model, as judged by an LLM, for the given prompt."""
    pass
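
# A minimal sketch of an LLM-as-a-judge comparison. The `judge` callable
# (a strong LLM that replies with a single index) is a hypothetical
# interface for illustration, not part of the original gist:
def sketch_judge_winner(judge: Callable, prompt: str, responses: list[str]) -> int:
    numbered = "\n".join(f"[{i}] {r}" for i, r in enumerate(responses))
    question = (
        f"Prompt:\n{prompt}\n\nCandidate responses:\n{numbered}\n\n"
        "Reply with only the index of the response that best imitates the user."
    )
    # The judge is expected to reply with a bare index such as "1".
    return int(judge(question).strip())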
# Candidate functions (the per-user model getters defined above)
candidate_fns = [
    get_lora_adapter,
    get_reft_adapter,
    get_prompted_base,
]
# Initialize tracking variables
candidate_perplexity = np.zeros(len(candidate_fns))
candidate_wins = [0] * len(candidate_fns)

# List of usernames to evaluate
usernames = ["user1", "user2", "user3"]

# Main evaluation loop
for username in usernames:
    chats: list[dict] = get_random_chats(username, num_of_chats=50)
    for chat in chats:
        prompt = get_random_prompt(chat)
        # Calculate perplexity for each candidate
        perplexities = [get_perplexity(fn(username), prompt) for fn in candidate_fns]
        # Normalize perplexities so each prompt contributes equally
        normalized_perplexity = normalize_st_sum_equals_one(np.array(perplexities))
        # Update cumulative perplexity tracking
        candidate_perplexity += normalized_perplexity
        # Determine the winner and update the win count
        winner_index = get_llm_as_a_judge_winner(prompt, [fn(username) for fn in candidate_fns])
        candidate_wins[winner_index] += 1
# The `candidate_perplexity` array and `candidate_wins` list now hold the evaluation results
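
# Summarize the results: a lower cumulative normalized perplexity and a
# higher win count both favor a candidate. The display names below are
# illustrative labels for the three getters, not part of the original gist.
candidate_names = ["lora_adapter", "reft_adapter", "prompted_base"]
for name, ppl, wins in zip(candidate_names, candidate_perplexity, candidate_wins):
    print(f"{name}: cumulative normalized perplexity={ppl:.3f}, wins={wins}")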