Reinforcement Learning for Language Models

Yoav Goldberg, April 2023.

Why RL?

With the release of the ChatGPT model and follow-up large language models (LLMs), there was a lot of discussion of the importance of "RLHF training", that is, "reinforcement learning from human feedback". I was puzzled for a while as to why RL (Reinforcement Learning) is better than learning from demonstrations (a.k.a. supervised learning) for training language models. Shouldn't learning from demonstrations (or, in language model terminology, "instruction fine-tuning": learning to imitate human-written answers) be sufficient? I came up with a theoretical argument that was somewhat convincing. But I came to realize there is an additional argument which not only supports the case for RL training, but also requires it, in particular for models like ChatGPT. This additional argument is spelled out in (the first half of) a talk by John Schulman from OpenAI. This post pretty much repeats his argument in more words, and also adds some things that John did not say explicitly (but which I am sure he thought about).

I included quite a bit of background to make the post self-contained. You can skip directly to "The core argument" to get to the main dish.

Background: Supervised learning vs RL

Let's briefly explain the two learning scenarios so we are on the same page. Feel free to skip this part if you are "in the know".

Pre-training In both settings, we assume a language model is first pre-trained over a large body of text, with the objective of predicting the next token. So we have a model that, for every sequence of words, can assign a probability over the options for the potential next word. By doing so, it also acquires some sort of an internal representation of the language. After this process, the model is very capable of generating texts, and providing natural continuations to given text prefixes, but it is not good at "communicating". For example, when prompted with a question, it might either answer it OR it might generate a series of additional questions, OR it might say this is an important question that was raised in the context of ..., etc. All of these are valid continuations that follow questions in natural-language texts. We can make the model perform desired language actions by crafting input texts such that their continuation will solve our problem (this is called "prompt engineering"), but this is not a very convenient interaction mode for non-expert users, who just want to ask a question or provide an instruction, and let the model follow it. If we want a model that is capable of consistently answering queries and not only completing them, we need to guide it towards this behavior. This guidance is called "fine-tuning": continue to train the pre-trained model so that it behaves as we would like. (Some people call this "aligning" the model with a desired behavior.)
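To make this concrete, here is a minimal sketch of what a pre-trained model gives us: a probability distribution over the next token for any prefix. It assumes the HuggingFace transformers library and the small public "gpt2" checkpoint purely for illustration; any causal language model would behave similarly.

```python
# Minimal sketch: a pre-trained causal LM assigns probabilities to next tokens.
# Assumes the HuggingFace `transformers` library and the "gpt2" checkpoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

prefix = "The capital of France is"
input_ids = tokenizer(prefix, return_tensors="pt").input_ids

with torch.no_grad():
    logits = model(input_ids).logits          # shape: (1, seq_len, vocab_size)
next_token_probs = torch.softmax(logits[0, -1], dim=-1)

# The model does not "answer"; it only ranks plausible continuations.
top = torch.topk(next_token_probs, k=5)
for prob, token_id in zip(top.values, top.indices):
    print(repr(tokenizer.decode(token_id.item())), f"p={prob.item():.3f}")
```

Nothing here forces the continuation to be an answer rather than, say, another question; that is exactly the gap fine-tuning is meant to close.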

Supervised Training In the supervised learning case (also called learning from demonstrations, or "instruction tuning"), we collect a bunch of human-authored texts that have the form of a question or an instruction followed by the desired output. For example, these texts can be a question followed by its answer, or a task such as summarize the following text {text} followed by its human-authored summary. By continuing to train the model on the same "predict the next token given the prefix" objective, but this time on this collection of instruction-output pairs, the model learns to respond to instructions by performing them. That is, the model receives demonstrations of what a correct output for a given query is, and learns to replicate this output. We hope that it will generalize this behavior to other queries not seen in training.
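Below is a rough sketch of instruction tuning under the same next-token objective, using a couple of made-up demonstration pairs. The model and library are the same illustrative assumptions as above; a real setup would use a large instruction dataset, proper batching, and would typically mask the loss on the prompt tokens.

```python
# A rough sketch of instruction tuning: continue next-token training, but on
# (instruction, desired output) pairs. The pairs below are toy examples.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

pairs = [  # toy demonstrations: (instruction, human-written answer)
    ("Summarize: The cat sat on the mat all day.", "A cat lounged on a mat."),
    ("Translate to French: Good morning.", "Bonjour."),
]

model.train()
for instruction, answer in pairs:
    text = instruction + "\n" + answer + tokenizer.eos_token
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    # Same objective as pre-training (predict each next token), now applied to
    # demonstration data, so the model learns to imitate the given answers.
    loss = model(input_ids, labels=input_ids).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```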

Reinforcement Learning (RL) In the reinforcement learning setup, we provide the model with the instructions, but not with their human-authored answers. Instead, the model generates its own answer. A scoring mechanism (for example, a human) reads the generated answers and tells the model whether they are good or not. The model's objective is to learn how to answer in a way that receives high scores. An alternative mechanism is that the model generates several answers, and the scoring mechanism tells the model which one is the best. The model's objective is to learn to produce the higher-scoring answers and not the lower-scoring ones. In both cases, the model learns by creating its own answers and receiving feedback. (Note: many researchers consider RL more narrowly, based on some technical aspects of the credit assignment mechanism, and for them the question "do we need RL" may boil down to whether we should use this family of techniques or some alternative family. I share their curiosity, but for the purpose of this post I consider any method that uses an external scoring function as RL, regardless of its mechanics.)
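Here is a deliberately simplified sketch of this loop (not the actual RLHF pipeline, which typically uses a learned reward model and an algorithm like PPO): the model samples its own answer, an external scorer judges it, and a REINFORCE-style update raises or lowers the probability of that answer according to the score. The score_answer function is a hypothetical stand-in for human or automatic feedback; the model choice is again just an illustration.

```python
# Simplified RL sketch: sample an answer, score it externally, and scale the
# answer's log-likelihood by the reward (REINFORCE-style update).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6)

def score_answer(question: str, answer: str) -> float:
    """Placeholder reward: in practice a human rating or a reward model."""
    return 1.0 if answer.strip() else -1.0

question = "What are common causes of flu?\n"
prompt_ids = tokenizer(question, return_tensors="pt").input_ids

# 1) The model generates its own answer (sampling, not imitating a reference).
generated = model.generate(prompt_ids, do_sample=True, max_new_tokens=30,
                           pad_token_id=tokenizer.eos_token_id)
answer = tokenizer.decode(generated[0, prompt_ids.shape[1]:],
                          skip_special_tokens=True)

# 2) An external mechanism scores the answer.
reward = score_answer(question, answer)

# 3) REINFORCE: weight the log-likelihood of the sampled answer by the reward.
logits = model(generated).logits[:, :-1]
log_probs = torch.log_softmax(logits, dim=-1)
token_logps = log_probs.gather(-1, generated[:, 1:].unsqueeze(-1)).squeeze(-1)
answer_logp = token_logps[:, prompt_ids.shape[1] - 1:].sum()
loss = -reward * answer_logp
loss.backward()
optimizer.step()
```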

RL is much harder than supervised training for several reasons. One such reason is "credit assignment". The language model generates a sequence of tokens, and only gets a score at the end of the sequence. The signal is weak: which parts of the answer were good and which were bad? A lot of technical work in RL attempts to solve this problem; we put it aside for this post. It is an active research area, but reasonable solutions exist. The other issue is that we need a scoring mechanism to score the answers (either assign a score or compare two answers), and, in the context of language-based tasks, it is hard to come up with an automatic scorer (though that might be changing, as I briefly discuss below). Thus, we are left with "human feedback" for each learning step. This is very expensive and also slow, and the problem is made worse by the fact that each human feedback only gives a rather sparse signal, as we just saw above. Given these difficulties, why should we use RL and not just supervised learning?
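To see how weak the signal is, here is a toy illustration of the pairwise-comparison variant mentioned above: a single preference judgment over two answers becomes one scalar per answer, and every token in each answer inherits that same scalar, regardless of which tokens actually deserved the credit or the blame. The prefer function is a hypothetical placeholder for a human judge.

```python
# Toy illustration of the weak, sequence-level signal (credit assignment).
from typing import List, Tuple

def prefer(answer_a: str, answer_b: str) -> int:
    """Placeholder pairwise judge: 0 means A is better, 1 means B is better."""
    return 0 if len(answer_a) >= len(answer_b) else 1

def sequence_rewards(answers: Tuple[str, str]) -> List[List[float]]:
    winner = prefer(*answers)
    per_token = []
    for i, ans in enumerate(answers):
        tokens = ans.split()
        # The same +1/-1 lands on every token of the answer, whether or not
        # that particular token was responsible for the overall judgment.
        reward = 1.0 if i == winner else -1.0
        per_token.append([reward] * len(tokens))
    return per_token

print(sequence_rewards(("Flu is often caused by influenza viruses.",
                        "The moon causes flu.")))
```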

The diversity argument

Perhaps the most intuitive argument against supervised learning / instruction tuning in the context of language generation models is that we teach the learner to replicate the exact answer given by the demonstrator, while the reality of human language is that there are many different ways to convey the same message, and all of them might be valid answers. We "punish" the model for even slight deviations from our prescribed text, which may confuse it. We may also insist on a phrasing which is hard for the model to learn, while the model already knows how to produce an alternative---and equally valid---answer. We would thus like the diversity afforded by RL training. This is a very intuitive argument, but not a very convincing one, given that supervised learning does seem to work very well in practice, and given the challenges in training RL models. For a long time, I was not convinced that this is a core enough issue, and I am still not convinced.

The theoretical argument

The first "convincing" justification I came up with for RL vs supervied learning in LLMs is that supervised learning allows only positive feedback (we show the model a series of questions and their correct answers) while RL allows also for negative feedback (the model is allowed to generate an answer an get a feedback saying "this is not correct"). From a formal learning theory perspective, there is a big difference between the two: negative feedback is much more powerful. The theoretical argument is, roughly, that when learning only from demonstrations, an adversarial (or neglient..) demonstrator can mislead the learner into learning the wrong hypothesis by witholding some important examples. The demonstrator controls the learning process entirely. However, if you as a learner are allowed to form your own hypotheses and ask the teacher if they are correct (as in the RL setting), even an adversarial teacher can no longer trick you into latching on to a wrong hypothesis. It must disclose that its wrong if you ask about it. The learner is now much more powerful. (Of course, this assumes the adversarial or neglient teacher still plays by the rules and always provides truthful answers. But this is a reasonable assumption to make in a theoretical framework, and it does not hurt the overall argument of why learning from demonstrations is weaker than learning by interaction or by asking questions.)

This is all nice and well, and I do believe this is part of the reason RL is needed. But there is also an additional argument which might be even more important in the context of training large language models to communicate by answering questions.

The core argument

This leads me to the core reason that requires RL-like training. The previous two arguments rely on hypotheses such as "it might be harder for the model to learn" or "a negligent demonstrator may confuse the model", which may or may not hold in practice. In contrast, the current argument provably holds.

There are (at least) three modes of interaction with a language model: (a) text-grounded: we provide the model with a text and an instruction ("summarize this text", "based on this text, what is the population of Israel", "what are the chemical names mentioned in this text", "translate this text to Spanish", etc.), and expect the answer to be fully grounded in the provided text. (b) knowledge-seeking: we provide the model with a question or instruction, and expect a (truthful) answer based on the model's internal knowledge ("What are common causes of flu"). (c) creative: we provide the model with a question or instruction, and expect some creative output ("Write a story about...").

The argument for RL is based on interaction type (b): knowledge-seeking queries in which we expect a truthful (or confident) answer, and the ability of the model to say "I don't know" or refuse to answer in situations in which it is uncertain.

For this type of interaction, we must use RL training, because supervised training teaches the model to lie. The core issue is that we want to encourage the model to answer based on its internal knowledge, but we don't know what this internal knowledge contains. In supervised training, we present the model with a question and its correct answer, and train the model to replicate the provided answer. There are two cases: (1) the model "knows" the answer. In this case, the supervised training correctly pushes it to associate the answer with the question, hopefully pushing it to perform similar steps to answer similar questions in the future. This is the desired behavior. (2) the model does not know the answer. In this case, the supervised training pushes the model to associate the answer with the question anyhow. Now, there are two options. It may push the model to memorize this particular question-answer pair. This is not harmful, but also not very effective, as our aim is for the model to generalize and learn to answer any question, not only the ones in the instruction training data. We want the model to generalize. But if we succeed in training the model to generalize in these cases, then we essentially teach the model to make stuff up! It actively encourages the model to "lie". This is bad.

Because we don't know what the model knows or doesn't know, we cannot avoid case (2), which is a real and serious issue for supervised training. We cannot use pure supervised learning to push the model towards producing truthful answers, and we thus must use RL for this. In contrast to the supervised setting, the RL setting does not actively encourage the model to lie: even if the model does initially guess some answers correctly and learns a "making stuff up" behavior by mistake, in the long run it will get bad scores for made-up answers (which are likely to be incorrect) and learn to adopt a policy that relies on its internal knowledge, or abstains.

Smaller remark: teaching to abstain

When the model doesn't know the answer, we would like it to abstain and respond with "I don't know" or a similar answer. This is non-trivial to do. It is hard in the supervised setting, because we do not know what the model knows or doesn't know. We can push it towards not answering questions of a certain kind ("never answer questions about people") and responding instead with "I don't know". But this is not the intended behavior of abstaining when the answer is unknown, only a very weak proxy for it. However, this is challenging also in the RL setting: the model may never produce an "I don't know" answer to begin with, and so we would have no way of encouraging it to generate such answers. One way around this is to start with some supervised training that teaches the model to produce "I don't know" answers in some cases, and then continue the process with RL. In both the supervised and the RL cases there is the worry that the model will learn to over-generate "I don't know". This is an open research question. One possible family of solutions is to tailor a reward that assigns very high scores to correct answers, medium-low scores to abstaining, and strong negative scores to incorrect answers. But this is not easy to get right.
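Here is a minimal sketch of that kind of reward shaping. The numbers are illustrative rather than tuned, and is_correct / is_abstention are hypothetical placeholders for what would in practice be a reward model or a human judgment.

```python
# Minimal sketch of a reward that favors correct answers, tolerates abstention,
# and strongly penalizes confident wrong answers. Values are illustrative only.
def is_abstention(answer: str) -> bool:
    return "i don't know" in answer.lower()

def is_correct(answer: str, reference: str) -> bool:
    return reference.lower() in answer.lower()   # crude placeholder check

def reward(answer: str, reference: str) -> float:
    if is_correct(answer, reference):
        return 1.0    # high score for a correct answer
    if is_abstention(answer):
        return 0.2    # medium-low score: abstaining beats being wrong
    return -2.0       # strong penalty for a confident wrong answer

print(reward("Paris.", "Paris"))          #  1.0
print(reward("I don't know.", "Paris"))   #  0.2
print(reward("London.", "Paris"))         # -2.0
```

Getting the relative magnitudes right is the hard part: too small a gap and the model never abstains, too large and it abstains everywhere.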

Implications for model stealing / distillation

OpenAI, the company behind the GPT models, has invested a lot of effort in RL-type tuning of its language models. Part of their motivation was to ensure factuality / truthfulness, by encouraging the model to abstain from providing answers when it does not know the answer.

There is a recent trend of taking other, publicly available, base language models, and training them on examples of GPT outputs, in order to replicate the GPT model's impressive behaviors.

Note that this is akin to supervised training / instruction tuning: the models are trained to reproduce the GPT model's answers exactly. This should work well for teaching the model to perform instructions. However, it will not work well for case (b), teaching the model to answer knowledge-seeking queries. The publicly available base model likely knows a different set of facts than the OpenAI model, and training to replicate GPT's answers will suffer from the same issue supervised training suffers from: the model will be encouraged to make up facts in response to these types of queries, and may additionally learn to abstain in cases where it does know the answer but the GPT model did not.

The solution is, then, to train these models with RL. But isn't it too expensive?

Towards RL without human feedback

For a long time, training generative language tasks with RL has been impractical for most players: lacking a reliable automatic scoring metric, RL training requires human feedback for every training sample. This is both expensive and extremely slow, especially for models that need to see thousands, tens of thousands, or even hundreds of thousands of examples to learn.

However, RL training is now becoming practical: first, it seems that large pre-trained language models manage to somehow learn from fewer examples. But, more importantly, they pave the way towards removing humans from the RL loop.

This relies on the observation that for text-grounded tasks the supervised training paradigm is very effective, and that large models can learn to perform some tasks very well. One such task is considering two texts and asking "do these two texts mean the same thing"; another is "are there facts in text A that do not appear in text B". (We can also decompose the task: ask the model to "Generate all question-answer pairs that are answerable from this text", and then for each question ask "Is there an answer to this question in this other text, and what is it".)

Empirically, large language models (or even medium-sized ones) seem to be able to learn to perform such tasks rather reliably using supervised learning. This provides us with an effective automatic scoring metric that we can use in an RL setting: train on the human-provided instruction-response pairs, but rather than trying to replicate the human responses directly, let the model generate its own response, and compare the model-generated response to the human-provided one using a dedicated text-comparison model that was trained in a supervised fashion.
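A rough sketch of what such a human-free scoring loop might look like: a supervised text-comparison model judges whether the policy's generated answer conveys the same content as the human reference, and that judgment becomes the RL reward. It assumes an off-the-shelf NLI checkpoint ("roberta-large-mnli") as the comparison model and the same illustrative "gpt2" policy as before; label order varies between checkpoints, so the entailment index should be checked against model.config.id2label.

```python
# Sketch: replace the human scorer with a supervised text-comparison model.
import torch
from transformers import (AutoModelForCausalLM,
                          AutoModelForSequenceClassification, AutoTokenizer)

policy_tok = AutoTokenizer.from_pretrained("gpt2")
policy = AutoModelForCausalLM.from_pretrained("gpt2")

judge_tok = AutoTokenizer.from_pretrained("roberta-large-mnli")
judge = AutoModelForSequenceClassification.from_pretrained("roberta-large-mnli")

def same_meaning_score(model_answer: str, human_answer: str) -> float:
    """Probability that the model's answer entails the human reference."""
    enc = judge_tok(model_answer, human_answer, return_tensors="pt")
    with torch.no_grad():
        probs = torch.softmax(judge(**enc).logits, dim=-1)
    entailment_idx = judge.config.label2id.get("ENTAILMENT", 2)
    return probs[0, entailment_idx].item()

instruction = "Summarize: The cat sat on the mat all day.\n"
human_answer = "A cat spent the whole day sitting on a mat."

prompt_ids = policy_tok(instruction, return_tensors="pt").input_ids
out = policy.generate(prompt_ids, do_sample=True, max_new_tokens=25,
                      pad_token_id=policy_tok.eos_token_id)
model_answer = policy_tok.decode(out[0, prompt_ids.shape[1]:],
                                 skip_special_tokens=True)

reward = same_meaning_score(model_answer, human_answer)
print(f"automatic reward: {reward:.3f}")   # feed this into an RL update
```

The policy is no longer pushed to reproduce the reference verbatim; any answer the comparison model judges equivalent gets a high reward, which also addresses the diversity concern raised earlier.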
