```python
"""PPO v2 trainer."""
import logging
import random

from accelerate import PartialState
from datasets import load_dataset
import torch
from transformers import (
    AutoModelForCausalLM,
)
```
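The module docstring above only names the algorithm. The following is a minimal sketch of the standard clipped PPO surrogate loss in plain Python. It assumes per-token log-probabilities and advantages have already been computed. The helper name `ppo_clip_loss` and the default `clip_eps=0.2` are illustrative assumptions, and the sketch is not the loss code of this trainer.

```python
import math
from typing import Sequence


def ppo_clip_loss(
    new_logprobs: Sequence[float],
    old_logprobs: Sequence[float],
    advantages: Sequence[float],
    clip_eps: float = 0.2,  # hypothetical default clipping range
) -> float:
    """Mean clipped PPO surrogate loss over a batch of tokens (hypothetical helper)."""
    losses = []
    for new_lp, old_lp, adv in zip(new_logprobs, old_logprobs, advantages):
        # Probability ratio pi_new / pi_old, computed from log-probabilities.
        ratio = math.exp(new_lp - old_lp)
        clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # PPO maximizes the pessimistic (minimum) surrogate; the loss is its negation.
        losses.append(-min(ratio * adv, clipped_ratio * adv))
    return sum(losses) / len(losses)
```

With a positive advantage, clipping caps how much a large ratio can improve the objective. With a negative advantage, the unclipped term is kept, so the update stays conservative in both directions.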
```python
import shutil

from accelerate import PartialState
from datasets import load_dataset
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
)
```
```python
"""
Conversation prompt templates.

We kindly request that you import fastchat instead of copying this file if you wish to use it.
If you have any changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
"""
import base64
import dataclasses
from enum import auto, IntEnum
```
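The imports above (`dataclasses`, `IntEnum`, `auto`) suggest a dataclass-based template with an enum of separator styles. The following is a heavily simplified sketch of that idea. The names `SeparatorStyle`, `ADD_COLON_SINGLE`, `NO_COLON_SINGLE`, `append_message` and `get_prompt` mirror FastChat's naming, but this is not FastChat's implementation. The rendering rules here are assumptions chosen for illustration.

```python
import dataclasses
from enum import IntEnum, auto
from typing import List, Optional, Tuple


class SeparatorStyle(IntEnum):
    """Hypothetical subset of separator styles; auto() numbers them 1, 2, ..."""
    ADD_COLON_SINGLE = auto()  # "ROLE: message<sep>"
    NO_COLON_SINGLE = auto()   # "ROLEmessage<sep>"


@dataclasses.dataclass
class Conversation:
    """Simplified prompt template (illustrative, not FastChat's class)."""
    name: str
    system_message: str = ""
    roles: Tuple[str, str] = ("USER", "ASSISTANT")
    messages: List[List[Optional[str]]] = dataclasses.field(default_factory=list)
    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
    sep: str = "\n"

    def append_message(self, role: str, message: Optional[str]) -> None:
        # A message of None marks the open slot the model should fill.
        self.messages.append([role, message])

    def get_prompt(self) -> str:
        ret = self.system_message + self.sep if self.system_message else ""
        colon = self.sep_style == SeparatorStyle.ADD_COLON_SINGLE
        for role, message in self.messages:
            role_tag = role + ":" if colon else role
            if message is None:
                # Emit only the role tag so generation continues from here.
                ret += role_tag
            else:
                ret += role_tag + (" " if colon else "") + message + self.sep
        return ret
```

For example, a system message "You are helpful." followed by a user turn "Hi" and an empty assistant slot renders as `"You are helpful.\nUSER: Hi\nASSISTANT:"`. Using an `IntEnum` keeps the style comparable to plain integers, which is convenient when templates are serialized.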
```yaml
gemma-7b-sft:
  prompt_template: "gemma-7b-sft/prompt.txt"
  fn_completions: "huggingface_local_completions"
  completions_kwargs:
    model_name: "kykim0/gemma-7b-ultrachat-sft"
    model_kwargs:
      torch_dtype: 'bfloat16'
    max_new_tokens: 512
    temperature: 0.7
    top_p: 1.0
```
```python
class RewardTrainer(Trainer):
    r"""
    The RewardTrainer can be used to train your custom Reward Model. It is a subclass of the
    `transformers.Trainer` class and inherits all of its attributes and methods. It is recommended to use
    an `AutoModelForSequenceClassification` as the reward model. The reward model should be trained on a dataset
    of paired examples, where each example is a tuple of two sequences. The reward model should be trained to
    predict which example in the pair is more relevant to the task at hand.

    The reward trainer expects a very specific format for the dataset. The dataset should contain at least 4 entries
    if you don't use the default `RewardDataCollatorWithPadding` data collator. The entries should be named
```
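The docstring says the reward model learns to predict which sequence in a pair is preferred. A common objective for this is the pairwise logistic (Bradley-Terry) loss, `-log sigmoid(r_chosen - r_rejected)`, optionally with a margin subtracted. The following is a minimal plain-Python sketch of that objective, assuming scalar rewards per sequence. The function names and the optional `margins` argument are illustrative assumptions, not this trainer's code.

```python
import math
from typing import Optional, Sequence


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def pairwise_reward_loss(
    rewards_chosen: Sequence[float],
    rewards_rejected: Sequence[float],
    margins: Optional[Sequence[float]] = None,  # hypothetical per-pair margins
) -> float:
    """Mean of -log sigmoid(r_chosen - r_rejected - margin) over the batch."""
    if margins is None:
        margins = [0.0] * len(rewards_chosen)
    losses = [
        -log_sigmoid(rc - rr - m)
        for rc, rr, m in zip(rewards_chosen, rewards_rejected, margins)
    ]
    return sum(losses) / len(losses)
```

When both rewards are equal the loss is `log 2`. The loss approaches zero as the chosen reward pulls ahead, and a positive margin demands a larger gap before the loss becomes small.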
```python
# Reward modeling on preference data.
from collections import defaultdict
import logging
import os
from random import sample
import sys

from alignment import (
    DataArguments,
)
```
```python
"""Process inference output files."""
from collections import defaultdict
import csv
import glob
import json
import os

from fastchat.llm_judge.common import load_questions
from fastchat.model import get_conversation_template
```
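Only the imports of this script are shown. The following is a hypothetical sketch of what "processing inference output files" with those imports could look like: it collects answer JSONL files, groups first-turn answers by question and model, and emits a CSV. The record fields (`question_id`, `model_id`, `choices[0]["turns"]`) are an assumption about a FastChat-style answer format. The function names are illustrative only.

```python
from collections import defaultdict
import csv
import glob
import io
import json
import os


def parse_answer_lines(lines, answers=None):
    """Group first-turn answers as {question_id: {model_id: text}} (hypothetical)."""
    answers = defaultdict(dict) if answers is None else answers
    for line in lines:
        if not line.strip():
            continue  # skip blank lines
        record = json.loads(line)
        # Assumed record layout: choices[0]["turns"] holds one string per turn.
        first_turn = record["choices"][0]["turns"][0]
        answers[record["question_id"]][record["model_id"]] = first_turn
    return answers


def load_answer_dir(answer_dir):
    """Read every *.jsonl file in answer_dir into one grouped mapping."""
    answers = defaultdict(dict)
    for path in sorted(glob.glob(os.path.join(answer_dir, "*.jsonl"))):
        with open(path, encoding="utf-8") as f:
            parse_answer_lines(f, answers)
    return answers


def answers_to_csv(answers, models):
    """Render one row per question and one column per model."""
    buf = io.StringIO()
    writer = csv.writer(buf)
    writer.writerow(["question_id", *models])
    for qid in sorted(answers):
        writer.writerow([qid, *(answers[qid].get(m, "") for m in models)])
    return buf.getvalue()
```

Keeping parsing separate from file discovery lets the grouping logic run on any iterable of lines, including an open file.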
```python
import json
import os
import random
import time

import shortuuid
import torch
from tqdm import tqdm

from fastchat.llm_judge.common import load_questions, temperature_config
```
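The imports of `temperature_config`, `shortuuid` and `time` suggest that answers are generated with a per-category sampling temperature and saved as timestamped records. The following is a minimal sketch under those assumptions. The table `EXAMPLE_TEMPERATURE_CONFIG`, its values, the fallback of 0.7, the greedy threshold, and the record fields are all hypothetical. The sketch also uses the stdlib `uuid` in place of `shortuuid` so it needs no third-party packages.

```python
import json
import time
import uuid

# Hypothetical category -> temperature table; the real `temperature_config`
# may use different categories and values.
EXAMPLE_TEMPERATURE_CONFIG = {"writing": 0.7, "math": 0.0, "coding": 0.0}
DEFAULT_TEMPERATURE = 0.7  # assumed fallback for unlisted categories


def pick_sampling(question, temperature_config=EXAMPLE_TEMPERATURE_CONFIG):
    """Return (temperature, do_sample) for a question dict with a 'category' key."""
    temperature = temperature_config.get(question.get("category"), DEFAULT_TEMPERATURE)
    # Treat a (near-)zero temperature as greedy decoding.
    do_sample = temperature >= 1e-4
    return temperature, do_sample


def make_answer_record(question_id, model_id, turns):
    """Build one answer record (assumed layout)."""
    return {
        "question_id": question_id,
        # uuid4 stands in for shortuuid.uuid() to keep the sketch stdlib-only.
        "answer_id": uuid.uuid4().hex,
        "model_id": model_id,
        "choices": [{"index": 0, "turns": turns}],
        "tstamp": time.time(),
    }


def to_jsonl_line(record):
    """Serialize a record as one JSONL line."""
    return json.dumps(record, ensure_ascii=False) + "\n"
```

Appending one line per answer keeps the output file streamable and lets a later processing step read it line by line.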