Last active
February 8, 2025 00:59
-
-
Save twobob/7c3722fdb682c1f0b247d8b617502887 to your computer and use it in GitHub Desktop.
Shows the edits needed to use the AI-MO/NuminaMath-CoT dataset for data preparation in the Unsloth Phi-4 (14B) GRPO notebook: https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Phi_4_(14B)-GRPO.ipynb#scrollTo=cXk993X6C2ZZ
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re | |
from datasets import load_dataset, Dataset | |
# Load and prep dataset

# System message placed before every user problem. It asks the model to wrap
# its chain of thought in <reasoning> tags and its final result in <answer>
# tags; the reward functions below parse exactly this layout.
SYSTEM_PROMPT = (
    "\n"
    "Respond in the following format:\n"
    "<reasoning>\n"
    "...\n"
    "</reasoning>\n"
    "<answer>\n"
    "...\n"
    "</answer>\n"
)

# str.format template rendering a (reasoning, answer) pair in the same layout.
XML_COT_FORMAT = (
    "<reasoning>\n"
    "{reasoning}\n"
    "</reasoning>\n"
    "<answer>\n"
    "{answer}\n"
    "</answer>\n"
)
def extract_xml_answer(text: str) -> str:
    """Return the stripped text of the last <answer> section in *text*.

    Everything after the final "<answer>" (or the whole text, if the tag is
    absent) is cut at the first following "</answer>" (if any) and stripped.
    """
    after_open = text.rpartition("<answer>")[2]
    inner, _, _ = after_open.partition("</answer>")
    return inner.strip()
def extract_hash_answer(text: str) -> str | None: | |
if "####" not in text: | |
return None | |
return text.split("####")[1].strip() | |
def extract_boxed_answer(text: str) -> str | None: | |
""" | |
Extracts the content inside the first occurrence of a LaTeX-style \boxed{...} in the given text. | |
Supports nested curly braces. | |
Returns the extracted answer as a string, or None if no properly balanced \boxed{...} is found. | |
""" | |
start_marker = r'\boxed{' | |
start_idx = text.find(start_marker) | |
if start_idx == -1: | |
return None | |
# Position immediately after the opening brace of \boxed{ | |
content_start = start_idx + len(start_marker) | |
brace_count = 1 | |
i = content_start | |
# Traverse the text to find the matching closing brace, accounting for nested braces | |
while i < len(text) and brace_count > 0: | |
char = text[i] | |
if char == '{': | |
brace_count += 1 | |
elif char == '}': | |
brace_count -= 1 | |
i += 1 | |
# If the braces are unbalanced, return None | |
if brace_count != 0: | |
return None | |
# Extract the content inside the outermost \boxed{...} | |
answer = text[content_start:i - 1] | |
return answer.strip() | |
def get_gsm8k_questions(
    split="train",
    fraction: float = 0.01,
    drop_missing_answers: bool = True,
) -> Dataset:
    r"""Load a leading slice of AI-MO/NuminaMath-CoT as chat prompts.

    The name is kept from the GSM8K version of this function so existing
    callers keep working; the data actually comes from
    ``AI-MO/NuminaMath-CoT`` (config ``"default"``).

    Args:
        split: Dataset split to load.
        fraction: Fraction of the split to keep, taken from the start of the
            split. Defaults to 0.01 (the first 1%).
        drop_missing_answers: If True, drop rows whose ``solution`` has no
            balanced ``\boxed{...}``. Such rows would get ``answer=None``,
            which can never match a model response and so only add noise.

    Returns:
        The selected rows with a ``prompt`` column (system message followed
        by the problem as the user message) and an ``answer`` column (the
        boxed answer extracted from ``solution``). The ``source`` and
        ``messages`` columns are removed.

    Raises:
        ValueError: If *fraction* is outside [0, 1].
    """
    if not 0.0 <= fraction <= 1.0:
        raise ValueError(f"fraction must be in [0, 1], got {fraction!r}")

    data = load_dataset('AI-MO/NuminaMath-CoT', 'default')[split]  # type: ignore
    # Keep only the first `fraction` of rows (1% by default) to keep runs short.
    keep = int(fraction * len(data))
    data = data.select(range(keep))  # type: ignore
    data = data.map(lambda x: {  # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['problem']},
        ],
        'answer': extract_boxed_answer(x['solution']),
    })
    if drop_missing_answers:
        data = data.filter(lambda x: x['answer'] is not None)  # type: ignore
    data = data.remove_columns(["source", "messages"])
    return data  # type: ignore
# Build the training set once at import time from the default "train" split.
dataset = get_gsm8k_questions(split="train")

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward 2.0 for each completion whose <answer> text equals its reference.

    Each completion is a message list; only the first message's ``content``
    is scored. The extracted answer is compared with ``==`` to the element of
    *answer* at the same position, giving 2.0 on a match and 0.0 otherwise.
    As a side effect, prints the last message of the first prompt, the first
    reference answer, the first raw response and its extracted answer.
    """
    texts = [msgs[0]['content'] for msgs in completions]
    question = prompts[0][-1]['content']
    predicted = [extract_xml_answer(t) for t in texts]
    print('-'*20, f"Question:\n{question}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{texts[0]}", f"\nExtracted:\n{predicted[0]}")
    return [
        2.0 if guess == gold else 0.0
        for guess, gold in zip(predicted, answer)
    ]
def int_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 for each completion whose <answer> text is a plain integer.

    An integer here is an optional leading minus sign followed by one or more
    ASCII digits (e.g. "42", "-3"). ``str.isdigit`` is deliberately not used:
    it rejects negative integers and accepts non-ASCII digit characters such
    as "²". Each completion is a message list; only the first message's
    ``content`` is scored.
    """
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [
        0.5 if re.fullmatch(r"-?\d+", r, flags=re.ASCII) else 0.0
        for r in extracted_responses
    ]
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 for completions that follow the tag layout exactly.

    The whole response must be "<reasoning>", newline, reasoning, newline,
    "</reasoning>", newline, "<answer>", newline, answer, newline,
    "</answer>", newline. ``re.DOTALL`` lets the reasoning and the answer span
    several lines; without it ``.`` stops at newlines, so any multi-line
    reasoning scored 0. (``$`` also matches just before one final trailing
    newline.) Each completion is a message list; only the first message's
    ``content`` is checked.
    """
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 for completions that start with a reasoning block then an answer block.

    The response must begin with "<reasoning>...</reasoning>", followed by
    optional whitespace and "<answer>...</answer>"; anything after that is
    ignored. ``re.DOTALL`` lets the tag contents span several lines; without
    it, multi-line reasoning never matched. Each completion is a message list;
    only the first message's ``content`` is checked.
    """
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
def count_xml(text) -> float:
    """Score how closely *text* uses the expected tag lines (max about 0.5).

    Adds 0.125 for each of "<reasoning>\\n", "\\n</reasoning>\\n",
    "\\n<answer>\\n" and "\\n</answer>" that occurs exactly once. When
    "\\n<answer>\\n" is present, subtracts 0.001 per character after the last
    "\\n</answer>\\n" (or per character of the whole text if that is absent).
    When "\\n</answer>" is present, subtracts 0.001 per character after it,
    minus one; so a text ending exactly in "\\n</answer>" gains 0.001.
    """
    score = 0.0
    for opening in ("<reasoning>\n", "\n</reasoning>\n"):
        if text.count(opening) == 1:
            score += 0.125
    if text.count("\n<answer>\n") == 1:
        score += 0.125
        tail = text.rpartition("\n</answer>\n")[2]
        score -= len(tail) * 0.001
    if text.count("\n</answer>") == 1:
        score += 0.125
        tail = text.rpartition("\n</answer>")[2]
        score -= (len(tail) - 1) * 0.001
    return score
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Score each completion's tag layout with ``count_xml``.

    Each completion is a message list; only the first message's ``content``
    is scored. Returns one float per completion, in order.
    """
    return list(map(count_xml, (msgs[0]["content"] for msgs in completions)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Phi_4_(14B)-GRPO.ipynb#scrollTo=SSlkA49z2xZB&line=4&uniqifier=1