import torch

# set seed
torch.manual_seed(42)
B, T, D = 1, 4, 10  # batch size, sequence length, vocab size
tensor = torch.rand(B, T, D, requires_grad=True)
labels = torch.Tensor([[1, 2, 3, 4]]).long()

print("=" * 100)
print("baseline gradient")
loss = torch.nn.functional.cross_entropy(
    tensor.view(-1, D), labels.view(-1), reduction="none", ignore_index=-100)
token_count = torch.count_nonzero(labels != -100)
loss = loss.sum() / token_count
loss.backward()
print(f"gradient: {tensor.grad}")
print("="*100) | |
print("gradient with coefficient") | |
tensor.grad.zero_() | |
loss = torch.nn.functional.cross_entropy( | |
tensor.view(-1, D), labels.view(-1), reduction="none", ignore_index=-100) | |
# for example, here we assume | |
# token 0 is a good action | |
# token 1 is a bad action | |
# token 2 is a good action | |
# token 3 is a bad action | |
# I'd like to see positive gradient for token 0 and token 2, and negative gradient for token 1 and token 3 | |
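# Multiplying a token's loss by a coefficient scales that token's gradient by the same
# factor, so a coefficient of -1 simply flips the sign of that token's gradient row:
# gradient descent then pushes probability mass away from that token's label instead of
# toward it. (Compare rows 1 and 3 in the output below with the baseline: they are negated.)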
coefficient = torch.Tensor([[1, -1, 1, -1]])
loss = loss * coefficient
print(f"loss: {loss}")
print(f"4 token loss mean: {loss[:4].mean()=}")
token_count = torch.count_nonzero(labels != -100)
loss = loss.sum() / token_count
loss.backward()
print(f"gradient: {tensor.grad}")

print("=" * 100)
print("gradient with log_softmax_and_gather")

def log_softmax_and_gather(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """
    torch compiled version of the common `log_softmax -> gather` operation.
    The compiled version of this operation avoids the (significant) memory overhead of
    allocating a new (batch_size, seq_len, vocab_size) tensor to store the logprobs.
    See https://github.com/allenai/open-instruct/pull/584
    """
    logprobs = logits.log_softmax(dim=-1)
    return torch.gather(logprobs, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
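
# Sanity check: for labels that are not ignore_index, -log_softmax_and_gather(x, y) is the
# same quantity as torch.nn.functional.cross_entropy(x, y, reduction="none"), which is why
# this section should reproduce the "gradient with coefficient" result exactly.
torch.testing.assert_close(
    -log_softmax_and_gather(tensor.view(-1, D), labels.view(-1)),
    torch.nn.functional.cross_entropy(tensor.view(-1, D), labels.view(-1), reduction="none"),
)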

tensor.grad.zero_()
# note: this is the *negative* logprob of each label token, i.e. the per-token cross entropy
logprobs = -log_softmax_and_gather(tensor.view(-1, D), labels.view(-1))
loss = logprobs * coefficient
print(f"loss: {loss}")
loss = loss.sum() / token_count
loss.backward()
print(f"gradient: {tensor.grad}")

# ====================================================================================================
# baseline gradient
# gradient: tensor([[[ 0.0310, -0.2180,  0.0188,  0.0335,  0.0189,  0.0234,  0.0166,
#                      0.0284,  0.0328,  0.0146],
#                    [ 0.0326,  0.0232, -0.2194,  0.0226,  0.0269,  0.0197,  0.0310,
#                      0.0227,  0.0167,  0.0240],
#                    [ 0.0229,  0.0272,  0.0235, -0.2099,  0.0194,  0.0229,  0.0250,
#                      0.0213,  0.0302,  0.0176],
#                    [ 0.0342,  0.0142,  0.0320,  0.0237, -0.2315,  0.0297,  0.0235,
#                      0.0326,  0.0230,  0.0186]]])
# ====================================================================================================
# gradient with coefficient
# loss: tensor([[ 2.0554, -2.1020,  1.8296, -2.6033]], grad_fn=<MulBackward0>)
# 4 token loss mean: loss[:4].mean()=tensor(-0.2051, grad_fn=<MeanBackward0>)
# gradient: tensor([[[ 0.0310, -0.2180,  0.0188,  0.0335,  0.0189,  0.0234,  0.0166,
#                      0.0284,  0.0328,  0.0146],
#                    [-0.0326, -0.0232,  0.2194, -0.0226, -0.0269, -0.0197, -0.0310,
#                     -0.0227, -0.0167, -0.0240],
#                    [ 0.0229,  0.0272,  0.0235, -0.2099,  0.0194,  0.0229,  0.0250,
#                      0.0213,  0.0302,  0.0176],
#                    [-0.0342, -0.0142, -0.0320, -0.0237,  0.2315, -0.0297, -0.0235,
#                     -0.0326, -0.0230, -0.0186]]])
# ====================================================================================================
# gradient with log_softmax_and_gather
# loss: tensor([[ 2.0554, -2.1020,  1.8296, -2.6033]], grad_fn=<MulBackward0>)
# gradient: tensor([[[ 0.0310, -0.2180,  0.0188,  0.0335,  0.0189,  0.0234,  0.0166,
#                      0.0284,  0.0328,  0.0146],
#                    [-0.0326, -0.0232,  0.2194, -0.0226, -0.0269, -0.0197, -0.0310,
#                     -0.0227, -0.0167, -0.0240],
#                    [ 0.0229,  0.0272,  0.0235, -0.2099,  0.0194,  0.0229,  0.0250,
#                      0.0213,  0.0302,  0.0176],
#                    [-0.0342, -0.0142, -0.0320, -0.0237,  0.2315, -0.0297, -0.0235,
#                     -0.0326, -0.0230, -0.0186]]])