REINFORCE with exponential moving average baseline
"""REINFORCE (DiCE) with exponential moving average baseline. Implements "DiCE: The Infinitely | |
Differentiable Monte Carlo Estimator (https://arxiv.org/abs/1802.05098).""" | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from typing import Optional, Union | |
class Reinforce(nn.Module):
    """REINFORCE (DiCE) with exponential moving average baseline. Implements "DiCE: The
    Infinitely Differentiable Monte Carlo Estimator" (https://arxiv.org/abs/1802.05098).

    Args:
        use_baseline (bool): Subtract the baseline from the losses. Defaults to True.
        beta (float): Exponential moving average decay factor for the baseline.
            Defaults to 0.99.

    Example:
        >>> opt = optim.Adam(model.parameters(), lr=1e-3)
        >>> estimator = Reinforce().to(device)

        In your training loop, sample actions first:

        >>> opt.zero_grad()
        >>> actions = estimator.sample_categorical(logits)

        Then, after you have computed a batch of losses:

        >>> loss = estimator.prepare_losses(losses)
        >>> loss.backward()
        >>> opt.step()
    """
    def __init__(self, use_baseline: bool = True, beta: float = 0.99):
        super().__init__()
        self.use_baseline = use_baseline
        self.beta = beta
        # Buffers for an Adam-style bias-corrected exponential moving average of
        # the mean loss: beta_cumprod tracks beta^t, and loss_mean_biased holds
        # the uncorrected (biased) running mean.
        self.register_buffer("beta_cumprod", torch.tensor(1.0))
        self.register_buffer("loss_mean_biased", torch.tensor(0.0))
        self.logprobs = []
    @staticmethod
    def magic_box(w: torch.Tensor) -> torch.Tensor:
        """MagicBox operator (see https://arxiv.org/abs/1802.05098).

        Args:
            w (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: The result of the MagicBox operator.
        """
        return torch.exp(w - w.detach())
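    # Sanity check of the MagicBox identity (added commentary, not in the
    # original gist): the forward value is exactly 1, while the gradient with
    # respect to w passes through unchanged.
    #
    #     w = torch.tensor(0.5, requires_grad=True)
    #     box = Reinforce.magic_box(w)
    #     assert box.item() == 1.0     # exp(w - w) == exp(0) == 1
    #     box.backward()
    #     assert w.grad.item() == 1.0  # d/dw exp(w - w.detach()) == exp(0)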
    def register_actions(self, logprobs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> None:
        """Register logprobs of actions to attach their grad path before the backward pass.

        Args:
            logprobs (torch.Tensor): Logprobs of actions.
            mask (torch.Tensor, optional): Mask for actions. Defaults to None.
        """
        if mask is not None:
            logprobs = logprobs * mask
        self.logprobs.append(logprobs)
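    # Added commentary (not in the original gist): multiplying the logprobs by
    # a 0/1 mask zeroes both their value and their gradient path, so masked
    # positions contribute nothing to the surrogate loss or its backward pass.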
    def prepare_losses(
        self, losses: torch.Tensor, baseline: Optional[Union[float, torch.Tensor]] = None
    ) -> torch.Tensor:
        """Prepare a batch of losses for the backward pass.

        Args:
            losses (torch.Tensor): Batch of losses to prepare.
            baseline (float or torch.Tensor, optional): Custom baseline to subtract.
                Defaults to None (use the exponential moving average baseline if
                use_baseline is True).

        Returns:
            torch.Tensor: Prepared loss.
        """
        # Bias-corrected mean loss; nan_to_num_() handles the 0/0 on the first call.
        loss_mean = self.loss_mean_biased / (1 - self.beta_cumprod)
        loss_mean.nan_to_num_()
        # Update the exponential moving average of the mean loss.
        self.beta_cumprod.mul_(self.beta)
        self.loss_mean_biased.mul_(self.beta).add_(losses.detach().mean(), alpha=1 - self.beta)
        if baseline is None:
            baseline = loss_mean if self.use_baseline else 0.0
        # Sum the registered logprobs over all non-batch dimensions.
        logprobs = [logprobs.flatten(losses.ndim).sum(losses.ndim) for logprobs in self.logprobs]
        logprobs = sum(logprobs, torch.zeros_like(losses))
        self.logprobs.clear()
        # DiCE surrogate with baseline.
        box = self.magic_box(logprobs)
        surrogates = losses * box + (1 - box) * baseline
        return surrogates.mean()
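    # Added commentary (not in the original gist): because magic_box(logprobs)
    # evaluates to 1, prepare_losses returns exactly losses.mean() in the
    # forward pass and the (1 - box) * baseline term is identically 0. The
    # baseline only enters the backward pass, where the surrogate's gradient
    # is (losses - baseline) * d(logprobs)/d(theta), i.e. the variance-reduced
    # REINFORCE estimator.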
    def sample_categorical(
        self,
        logits: torch.Tensor,
        actions: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Sample from a categorical distribution and register the actions' grad paths for the
        backward pass. If actions are provided, register them instead of sampling.

        Args:
            logits (torch.Tensor): Unnormalized logits of the categorical distribution.
            actions (torch.Tensor, optional): Actions that were taken. Defaults to None.
            mask (torch.Tensor, optional): Mask for tokens. Defaults to None.

        Returns:
            torch.Tensor: Actions that were taken.

        Example:
            If you have sampled tokens from a HuggingFace model, you can use this method to
            register the grad paths of the sampled tokens. You need to obtain logits from the
            model that have a grad_fn:

            >>> logits = model(tokens).logits
            >>> estimator.sample_categorical(logits[:, prompt_len - 1 : -1], tokens[:, prompt_len:])

            Notice how the tokens are shifted one position right relative to the logits they
            were sampled from, and the prompt tokens aren't included. If you cannot exclude
            your prompt or padding tokens with simple slicing, you can provide a mask (1/True
            for token positions that grads should propagate through, 0/False to stop
            gradients).
        """
        if actions is None:
            # Gumbel-max trick: the argmax of logits plus Gumbel noise is an
            # exact sample from the categorical distribution.
            g = torch.rand_like(logits).log_().neg_().log_().neg_()
            actions = torch.argmax(logits + g, dim=-1)
        logprobs = F.log_softmax(logits, dim=-1).gather(-1, actions[..., None])
        self.register_actions(logprobs, mask)
        return actions
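A minimal end-to-end sketch of using the class (not part of the gist; the two-armed
bandit task, batch size, learning rate, and step count below are all illustrative
assumptions):

import torch
from torch import optim

# Toy task: a 2-armed bandit where arm 1 yields reward 1 (loss -1) and arm 0
# yields reward 0. Training should shift probability mass toward arm 1.
logits_param = torch.nn.Parameter(torch.zeros(2))
opt = optim.Adam([logits_param], lr=1e-1)
estimator = Reinforce()

for step in range(200):
    opt.zero_grad()
    logits = logits_param.expand(64, 2)  # the logits need a grad_fn, so expand the leaf
    actions = estimator.sample_categorical(logits)  # samples and registers grad paths
    losses = -(actions == 1).float()  # loss is negative reward
    loss = estimator.prepare_losses(losses)
    loss.backward()
    opt.step()

print(logits_param.softmax(-1).tolist())  # most of the mass should be on arm 1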