@abacaj
Last active February 25, 2025 22:52
extending GRPOTrainer to run gsm8k eval during training
import tqdm
import numpy as np
import torch
import torch.distributed as dist
import transformers
import trl
def extract_xml_answer(text: str) -> str:
    answer = text.split("<final_answer>")[-1]
    answer = answer.split("</final_answer>")[0]
    return answer.strip()
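# e.g. extract_xml_answer("the answer is <final_answer>42</final_answer>") -> "42";
# note that if the tags are missing, the stripped input text is returned unchanged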
def generate_gsm8k(
    model,
    tokenizer,
    tokenized_samples,
    batch_size,
    max_completion_length,
):
    # run eval on the main process only; other ranks return 0
    if dist.get_rank() == 0:
        device = model.device
        predictions = []
        generation_config = transformers.GenerationConfig(
            max_new_tokens=max_completion_length,
            do_sample=False,
            repetition_penalty=1.0,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
        model.eval()
        count = len(tokenized_samples)
        status = tqdm.tqdm(total=count, desc=f"Correct: 0/{count}")
        for i in range(0, count, batch_size):
            batch = tokenized_samples[i : i + batch_size]
            with torch.inference_mode():
                longest = max(len(ids) for ids, _ in batch)
                # pad to the longest prompt on the left side (decoder-only model)
                padded_input_ids = torch.stack([
                    torch.tensor([tokenizer.pad_token_id] * (longest - len(ids)) + ids)
                    for ids, _ in batch
                ]).to(device)
                # mask out pad tokens when generating
                attn_mask = torch.stack([
                    tokens.ne(tokenizer.pad_token_id) for tokens in padded_input_ids
                ]).to(device)
                output = model.generate(
                    input_ids=padded_input_ids,
                    attention_mask=attn_mask,
                    generation_config=generation_config,
                )
                for j, generated in enumerate(output):
                    # strip the prompt tokens, keep only the completion
                    response = tokenizer.decode(
                        generated[len(padded_input_ids[j]):], skip_special_tokens=True
                    )
                    prediction = extract_xml_answer(response)
                    predictions.append(batch[j][1] == prediction)
            status.update(len(batch))
            status.set_description(f"Correct: {sum(predictions)}/{count}")
        return np.mean(predictions)
    return 0
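# usage sketch outside the trainer (argument values here are illustrative):
# under a torchrun/accelerate launch, rank 0 returns the accuracy and every
# other rank returns 0, e.g.
#   samples = tokenize_validation(tokenizer, test_dataset, max_prompt_length=356)
#   acc = generate_gsm8k(model, tokenizer, samples, batch_size=8, max_completion_length=512)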
def tokenize_validation(tokenizer, samples, max_prompt_length):
    tokenized_samples = []
    for sample in samples:
        prompt = sample["prompt"]
        answer = sample["answer"]
        ids = tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            truncation=False,
            max_length=max_prompt_length,
        )
        tokenized_samples.append((ids, answer))
    return tokenized_samples
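# assumed sample format (matching get_gsm8k_questions from willccbb's gist,
# linked in the comment below): "prompt" is a chat-message list and "answer"
# is the gold final answer as a string, e.g.
#   {"prompt": [{"role": "user", "content": "..."}], "answer": "72"}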
class CustomTrainer(trl.GRPOTrainer):
    def evaluate(
        self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"
    ):
        tokenized_samples = tokenize_validation(
            self.processing_class, self.eval_dataset, self.args.max_prompt_length
        )
        eval_acc = generate_gsm8k(
            self.model,
            self.processing_class,
            tokenized_samples,
            self.args.per_device_eval_batch_size,
            self.args.max_completion_length,
        )
        output = {
            f"{metric_key_prefix}_accuracy": eval_acc,
            "epoch": self.state.epoch,
        }
        self.log(output)
        self.control = self.callback_handler.on_evaluate(
            self.args, self.state, self.control, output
        )
        return output
training_args = trl.GRPOConfig(
    output_dir="checkpoints/qwen25-05b",
    bf16=True,
    max_prompt_length=356,
    max_completion_length=512,
    learning_rate=learning_rate,
    # ... rest of config
    eval_steps=20,
    per_device_eval_batch_size=256,  # adjust based on your GPU! may cause OOM errors
    do_eval=True,
    eval_strategy="steps",
)
# get_gsm8k_questions and the reward functions come from willccbb's gist,
# linked in the comment below
dataset = get_gsm8k_questions()
test_dataset = get_gsm8k_questions("test")

trainer = CustomTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args=training_args,
    train_dataset=dataset,
    eval_dataset=test_dataset,
)
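# minimal launch sketch: dist.get_rank() above assumes a distributed context,
# so start the script with torchrun or accelerate launch (the script name is
# illustrative), e.g.
#   torchrun --nproc_per_node=2 grpo_gsm8k.py
# then training, with the gsm8k eval firing every eval_steps, is just:
trainer.train()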
abacaj commented Feb 5, 2025

This gist extends https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb; you can copy just the custom trainer portion.
