Extending GRPOTrainer to run a GSM8K eval during training
import tqdm
import numpy as np
import torch
import torch.distributed as dist
import transformers
import trl  # GRPOTrainer / GRPOConfig live in trl, not transformers

def extract_xml_answer(text: str) -> str:
    answer = text.split("<final_answer>")[-1]
    answer = answer.split("</final_answer>")[0]
    return answer.strip()
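
# Illustrative only: the reward format used here wraps the model's answer in
# <final_answer> tags, so for example
#
#   extract_xml_answer("the total is <final_answer> 72 </final_answer>")
#   # -> "72"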

def generate_gsm8k(
    model,
    tokenizer,
    tokenized_samples,
    batch_size,
    max_completion_length,
):
    # run eval on the main process only
    if dist.get_rank() == 0:
        device = model.device
        predictions = []
        generation_config = transformers.GenerationConfig(
            max_new_tokens=max_completion_length,
            do_sample=False,
            repetition_penalty=1.0,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
        model.eval()
        count = len(tokenized_samples)
        status = tqdm.tqdm(tokenized_samples, desc=f"Correct: 0/{count}")
        for i in range(0, count, batch_size):
            batches = tokenized_samples[i : i + batch_size]
            with torch.inference_mode():
                longest = max(len(b[0]) for b in batches)
                # left-pad to the longest prompt in the batch (decoder-only model)
                padded_input_ids = torch.stack([
                    torch.tensor([tokenizer.pad_token_id] * (longest - len(ids)) + ids)
                    for ids, _ in batches
                ]).to(device)
                # mask out the pad tokens during generation
                attn_mask = torch.stack([
                    tokens.ne(tokenizer.pad_token_id) for tokens in padded_input_ids
                ]).to(device)
                output = model.generate(
                    input_ids=padded_input_ids,
                    attention_mask=attn_mask,
                    generation_config=generation_config,
                )
                # index generations with j to avoid shadowing the batch index i
                for j, generated in enumerate(output):
                    response = tokenizer.decode(
                        generated[len(padded_input_ids[j]):], skip_special_tokens=True
                    )
                    prediction = extract_xml_answer(response)
                    predictions.append(batches[j][1] == prediction)
                status.update(len(batches))  # last batch may be smaller than batch_size
                status.set_description(f"Correct: {sum(predictions)}/{count}")
        return np.mean(predictions)
    return 0.0  # non-main ranks skip eval
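
# Note (assumption, not in the original gist): dist.get_rank() raises if no
# process group has been initialized, e.g. a plain single-GPU `python script.py`
# run outside torchrun/accelerate. A defensive variant could check first:
#
#   rank = dist.get_rank() if dist.is_initialized() else 0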

def tokenize_validation(tokenizer, samples, max_prompt_length):
    tokenized_samples = []
    for sample in samples:
        prompt = sample["prompt"]
        answer = sample["answer"]
        ids = tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            truncation=False,  # note: max_length has no effect while truncation is off
            max_length=max_prompt_length,
        )
        tokenized_samples.append((ids, answer))
    return tokenized_samples
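
# Assumed sample format, following get_gsm8k_questions from the linked gist
# (field contents illustrative): a chat-style "prompt" plus a gold "answer" string.
#
#   sample = {
#       "prompt": [
#           {"role": "system", "content": SYSTEM_PROMPT},
#           {"role": "user", "content": "Natalia sold clips to 48 friends..."},
#       ],
#       "answer": "72",
#   }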

class CustomTrainer(trl.GRPOTrainer):
    def evaluate(
        self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"
    ):
        tokenized_samples = tokenize_validation(
            self.processing_class, self.eval_dataset, self.args.max_prompt_length
        )
        eval_acc = generate_gsm8k(
            self.model,
            self.processing_class,
            tokenized_samples,
            self.args.per_device_eval_batch_size,
            self.args.max_completion_length,
        )
        output = {
            f"{metric_key_prefix}_accuracy": eval_acc,
            "epoch": self.state.epoch,
        }
        self.log(output)
        self.control = self.callback_handler.on_evaluate(
            self.args, self.state, self.control, output
        )
        return output
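
# Illustrative only: with eval_strategy="steps" the Trainer calls this override
# every eval_steps; it can also be invoked by hand once the trainer exists:
#
#   metrics = trainer.evaluate()
#   print(metrics["eval_accuracy"])  # fraction of exact-match GSM8K answers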

training_args = trl.GRPOConfig(
    output_dir="checkpoints/qwen25-05b",
    bf16=True,
    max_prompt_length=356,
    max_completion_length=512,
    learning_rate=learning_rate,  # defined elsewhere in the full script
    # ... rest of config
    eval_steps=20,
    per_device_eval_batch_size=256,  # adjust based on your GPU! may cause OOM errors
    do_eval=True,
    eval_strategy="steps",
)

dataset = get_gsm8k_questions()
test_dataset = get_gsm8k_questions("test")

trainer = CustomTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args=training_args,
    train_dataset=dataset,
    eval_dataset=test_dataset,
)
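
# The snippet stops at constructing the trainer; a full run would typically
# finish the same way the linked gist does:

trainer.train()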

This gist is an extension of https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb; you can copy the custom trainer portion from here.