Skip to content

Instantly share code, notes, and snippets.

@twobob
Last active February 8, 2025 00:59
Show Gist options
  • Save twobob/7c3722fdb682c1f0b247d8b617502887 to your computer and use it in GitHub Desktop.
Save twobob/7c3722fdb682c1f0b247d8b617502887 to your computer and use it in GitHub Desktop.
import re
from datasets import load_dataset, Dataset
# Load and prep dataset

# System message placed first in every prompt built by get_gsm8k_questions; it
# instructs the model to answer inside <reasoning> and <answer> tags.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
# Template with {reasoning} and {answer} placeholders laid out in the same tag
# format (the backslash after the opening quotes suppresses the leading newline).
# NOTE(review): not referenced by any code visible in this file.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
def extract_xml_answer(text: str) -> str:
    """Return the content of the last ``<answer>`` section, whitespace-stripped.

    Everything following the final ``<answer>`` tag is kept, then truncated at
    the first ``</answer>`` that follows it. When a tag is absent the matching
    cut is simply skipped, so untagged text comes back whole (stripped).
    """
    after_open_tag = text.rpartition("<answer>")[2]
    inner = after_open_tag.partition("</answer>")[0]
    return inner.strip()
def extract_hash_answer(text: str) -> str | None:
if "####" not in text:
return None
return text.split("####")[1].strip()
def extract_boxed_answer(text: str) -> str | None:
"""
Extracts the content inside the first occurrence of a LaTeX-style \boxed{...} in the given text.
Supports nested curly braces.
Returns the extracted answer as a string, or None if no properly balanced \boxed{...} is found.
"""
start_marker = r'\boxed{'
start_idx = text.find(start_marker)
if start_idx == -1:
return None
# Position immediately after the opening brace of \boxed{
content_start = start_idx + len(start_marker)
brace_count = 1
i = content_start
# Traverse the text to find the matching closing brace, accounting for nested braces
while i < len(text) and brace_count > 0:
char = text[i]
if char == '{':
brace_count += 1
elif char == '}':
brace_count -= 1
i += 1
# If the braces are unbalanced, return None
if brace_count != 0:
return None
# Extract the content inside the outermost \boxed{...}
answer = text[content_start:i - 1]
return answer.strip()
# For 1-shot prompting, insert an example user/assistant exchange between the
# system and user messages in the 'prompt' list below.
def get_gsm8k_questions(split = "train", fraction: float = 0.01) -> Dataset:
    """Load a leading slice of AI-MO/NuminaMath-CoT formatted as chat prompts.

    Each row gains a ``prompt`` column (system message followed by the row's
    ``problem`` as the user message) and an ``answer`` column holding the
    contents of the first ``\\boxed{...}`` in the row's ``solution`` (``None``
    when no balanced box is found). The ``source`` and ``messages`` columns are
    dropped.

    Args:
        split: Name of the dataset split to load.
        fraction: Portion of the split to keep, taken from the start.
            Defaults to 0.01, i.e. the first 1% of the rows.

    Returns:
        The transformed dataset subset.

    Raises:
        ValueError: If ``fraction`` is not within [0, 1].
    """
    # Validate before loading so a bad argument fails fast.
    if not 0.0 <= fraction <= 1.0:
        raise ValueError(f"fraction must be between 0 and 1, got {fraction!r}")

    data = load_dataset('AI-MO/NuminaMath-CoT', 'default')[split]  # type: ignore
    subset_size = int(fraction * len(data))
    # Keep only the first `fraction` of the rows (1% by default).
    data = data.select(range(subset_size))  # type: ignore
    data = data.map(lambda x: {  # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['problem']}
        ],
        'answer': extract_boxed_answer(x['solution'])
    })  # type: ignore
    data = data.remove_columns(["source", "messages"])
    return data  # type: ignore
# Built once at import time, using get_gsm8k_questions' default "train" split.
dataset = get_gsm8k_questions()
# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Score 2.0 for each completion whose extracted answer equals the reference.

    The text inside the completion's ``<answer>`` tags (via
    ``extract_xml_answer``) is compared for exact string equality against the
    corresponding entry of ``answer``; mismatches score 0.0. The first
    question, reference answer, raw response and extracted answer are printed
    for inspection on every call.
    """
    responses = [item[0]['content'] for item in completions]
    # Last message of the first prompt is taken as the question text for logging.
    question = prompts[0][-1]['content']
    extracted = list(map(extract_xml_answer, responses))
    print('-'*20, f"Question:\n{question}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted[0]}")

    rewards = []
    for predicted, expected in zip(extracted, answer):
        rewards.append(2.0 if predicted == expected else 0.0)
    return rewards
def int_reward_func(completions, **kwargs) -> list[float]:
    """Score 0.5 for each completion whose extracted answer is all digits.

    The answer is pulled out with ``extract_xml_answer`` and tested with
    ``str.isdigit``, so signs, decimal points and empty strings score 0.0.
    """
    responses = [item[0]['content'] for item in completions]
    rewards = []
    for response in responses:
        candidate = extract_xml_answer(response)
        rewards.append(0.5 if candidate.isdigit() else 0.0)
    return rewards
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    r"""Reward completions that follow the exact line-by-line tag layout.

    A completion scores 0.5 when, from its very start to its end, it reads
    ``<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n``; otherwise
    it scores 0.0.

    Args:
        completions: Sequence of completions, each a list whose first element
            is a dict with a ``"content"`` string.
        **kwargs: Ignored; accepted for reward-function signature compatibility.

    Returns:
        One float (0.5 or 0.0) per completion, in order.
    """
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets '.' span newlines; without it any reasoning or answer
    # longer than a single line could never match.
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that loosely follow the reasoning/answer tag format.

    A completion scores 0.5 when it begins with a ``<reasoning>...</reasoning>``
    section followed, after optional whitespace, by an ``<answer>...</answer>``
    section; otherwise it scores 0.0. Text after the closing answer tag is
    allowed.

    Args:
        completions: Sequence of completions, each a list whose first element
            is a dict with a ``"content"`` string.
        **kwargs: Ignored; accepted for reward-function signature compatibility.

    Returns:
        One float (0.5 or 0.0) per completion, in order.
    """
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets '.' span newlines; without it, tag contents containing a
    # newline (the very layout the system prompt requests) could never match.
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
def count_xml(text) -> float:
    """Return a partial-credit score for how well ``text`` uses the XML tags.

    Each of the four tag forms earns 0.125 when it occurs exactly once. The
    two answer-tag checks also subtract 0.001 per character found after the
    closing answer tag. In the ``<answer>`` check, if the text has no
    newline-terminated closing tag the split yields the whole text, so the
    penalty is then proportional to the full length.
    """
    score = 0.0

    # Opening and closing reasoning tags: flat credit, no penalty.
    for tag in ("<reasoning>\n", "\n</reasoning>\n"):
        if text.count(tag) == 1:
            score += 0.125

    if text.count("\n<answer>\n") == 1:
        score += 0.125
        after_closed_line = text.rpartition("\n</answer>\n")[2]
        score -= len(after_closed_line)*0.001

    if text.count("\n</answer>") == 1:
        score += 0.125
        after_close_tag = text.rpartition("\n</answer>")[2]
        # The "- 1" forgives one trailing character (e.g. a final newline).
        score -= (len(after_close_tag) - 1)*0.001

    return score
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Return the ``count_xml`` score of each completion's first message content."""
    scores = []
    for completion in completions:
        body = completion[0]["content"]
        scores.append(count_xml(body))
    return scores