eval_reward.py
| """Process inference output files.""" | |
| from collections import defaultdict | |
| import csv | |
| import glob | |
| import json | |
| import os | |
| from fastchat.llm_judge.common import load_questions | |
| from fastchat.model import get_conversation_template | |
| import numpy as np | |
| import torch | |
| from tqdm import tqdm | |
| from transformers import LlamaTokenizer | |
| from alignment import ( | |
| LlamaRewardModel, | |
| get_kbit_device_map, | |
| get_tokenizer, | |
| ) | |
def process_local(csv_fname):
    """Reads per-query rewards from a local eval CSV, de-duplicating queries."""
    eval_rewards = []
    seen_queries = set()
    with open(csv_fname, newline="") as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            query = row["query"]
            # Only the first occurrence of each query is counted.
            if query in seen_queries:
                continue
            seen_queries.add(query)
            eval_rewards.append(float(row["reward"]))
    return eval_rewards
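
# For reference, process_local assumes a CSV with at least "query" and
# "reward" columns (column names taken from the code above; the rows below
# are illustrative only):
#
#   query,reward
#   "What is the capital of France?",1.372
#   "Summarize this article.",0.841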
def process_mtbench(rmodel_tokenizer, rmodel, questions, out_fname):
    """Scores each turn of the MT-Bench answers in out_fname with the reward model."""
    answers = {}
    with open(out_fname, "r") as fin:
        for line in fin:
            record = json.loads(line)
            answers[record["question_id"]] = record

    mtbench_rewards = defaultdict(list)
    for question in tqdm(questions):
        assert len(question["turns"]) == 2
        qid = question["question_id"]
        responses = answers[qid]["choices"][0]
        conv = get_conversation_template("zephyr")
        for j in range(len(question["turns"])):
            qs = question["turns"][j]
            resp = responses["turns"][j]
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], resp)
            # Score the conversation so far, i.e. turn j is scored with turns 0..j
            # in context.
            prompt = conv.get_prompt()
            with torch.no_grad():
                rmodel_inputs = rmodel_tokenizer(
                    [prompt], padding=True, truncation=True, return_tensors="pt"
                ).to(rmodel.device)  # Move inputs to the model's device.
                rewards = [r.squeeze().to(dtype=torch.float) for r in rmodel(**rmodel_inputs)]
            assert len(rewards) == 1
            # Store plain floats so np.mean works even when the model ran on GPU.
            mtbench_rewards[j].append(rewards[0].item())
    return mtbench_rewards
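
# Each line of the answers file read above is assumed to follow the FastChat
# llm_judge answer format, roughly (field values illustrative):
#
#   {"question_id": 81, "choices": [{"turns": ["turn-1 answer", "turn-2 answer"]}]}
#
# Only the first choice of each record is scored.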
def reward_model():
    """Loads the UltraRM-13b reward model and its tokenizer in 8-bit."""
    rmodel_name = "openbmb/UltraRM-13b"
    rmodel_tokenizer = LlamaTokenizer.from_pretrained(rmodel_name)
    rmodel = LlamaRewardModel.from_pretrained(
        rmodel_name,
        load_in_8bit=True,  # Requires bitsandbytes.
        torch_dtype=torch.bfloat16,
        device_map=get_kbit_device_map(),
    )
    rmodel.eval()
    # LLaMA tokenizers ship without a pad token; reuse EOS for padded batches.
    if getattr(rmodel_tokenizer, "pad_token", None) is None:
        rmodel_tokenizer.pad_token = rmodel_tokenizer.eos_token
        rmodel.config.pad_token_id = rmodel_tokenizer.eos_token_id
    return rmodel_tokenizer, rmodel
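
# Minimal standalone usage sketch for the pair returned above (the prompt
# text is illustrative; the UltraRM model card uses a "Human: ...
# Assistant: ..." style prompt):
#
#   tok, rm = reward_model()
#   inputs = tok(["Human: Hi\nAssistant: Hello!"], return_tensors="pt").to(rm.device)
#   with torch.no_grad():
#       reward = rm(**inputs)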
def main():
    base_dir = "./save/ppo-trl-tune"
    exp_dirs = [fname for fname in glob.glob(os.path.join(base_dir, "*"))
                if not fname.endswith("tmp")]

    bench_name = "mt_bench"
    question_file = f"./eval/FastChat/fastchat/llm_judge/data/{bench_name}/question.jsonl"
    questions = load_questions(question_file, None, None)

    rmodel_tokenizer, rmodel = reward_model()
    for exp_dir in exp_dirs:
        # Only report the two runs of interest.
        if ("lr9e-06-kl0.2-gam1-lam0.95-vf0.1-b64-e4" not in exp_dir and
                "lr1e-05-kl0.2-gam1-lam0.95-vf0.1-b64-e2" not in exp_dir):
            continue

        local_eval_dict = {}
        for eval_csv in sorted(glob.glob(os.path.join(exp_dir, "eval_local_*.csv"))):
            eval_rewards = process_local(eval_csv)
            local_eval_dict[os.path.basename(eval_csv)] = round(np.mean(eval_rewards), 3)
        # result_str = " | ".join(f"{k}: {v}" for k, v in sorted(local_eval_dict.items()))
        # print(f"{os.path.basename(exp_dir)}: {result_str}")

        mtbench_eval_dict = {}
        for eval_fname in sorted(glob.glob(os.path.join(exp_dir, "eval_mt-bench_*.jsonl"))):
            eval_rewards = process_mtbench(rmodel_tokenizer, rmodel, questions, eval_fname)
            rewards = [(f"turn{idx}", round(np.mean(turn_rewards), 3))
                       for idx, turn_rewards in eval_rewards.items()]
            mtbench_eval_dict[os.path.basename(eval_fname)] = rewards

        print(f"Experiment {exp_dir}")
        print(f"Local eval: {local_eval_dict}")
        print(f"Mt-bench eval: {mtbench_eval_dict}\n")


if __name__ == "__main__":
    main()
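
# Usage sketch (assumes the hard-coded paths above exist on this machine):
#
#   python eval_reward.py
#
# For each matching experiment directory, the script prints the mean reward
# from every eval_local_*.csv and the per-turn mean UltraRM rewards for every
# eval_mt-bench_*.jsonl answer file.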