```python
# train_grpo.py
#
# See https://github.com/willccbb/verifiers for ongoing developments
#
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
```
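The gist above is cut off after the imports. For context, here is a minimal sketch of how these pieces typically fit together with trl's `GRPOTrainer`; the dataset, reward function, and model id below are illustrative assumptions, not part of the original script.

```python
# Minimal GRPO sketch. Everything below the imports is an assumption for
# illustration: the dataset, the toy reward, and the model id are placeholders.
dataset = load_dataset("trl-lib/tldr", split="train")  # any dataset with a "prompt" column

def length_reward(completions, **kwargs):
    # Toy reward: prefer completions near ~200 characters. Real setups use
    # task-specific checks (e.g., regex-verified answers) or a learned reward model.
    return [-abs(len(c) - 200) / 200.0 for c in completions]

training_args = GRPOConfig(output_dir="grpo-out")

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # placeholder model id
    reward_funcs=length_reward,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```

Passing a model id string lets the trainer load the model and tokenizer itself; a `LoraConfig` could also be passed via `peft_config` to train adapters instead of full weights.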
Maybe you've heard about this technique but haven't completely understood it, especially the PPO part. This explanation might help.
We will focus on text-to-text language models 📝, such as GPT-3, BLOOM, and T5. Models like BERT, which are encoder-only, are not addressed.
Reinforcement Learning from Human Feedback (RLHF) has been applied successfully in ChatGPT, hence its recent surge in popularity. 📈
RLHF is especially useful in two scenarios 🌟:

- You can’t create a good loss function
  - Example: how do you calculate a metric to measure if the model’s output was funny?
- You want to train with production data, but you can’t easily label your production data