import torch
# set seed
torch.manual_seed(42)
B, T, D = 1, 4, 10 # batch size, sequence length, vocab size
tensor = torch.rand(B, T, D, requires_grad=True)
labels = torch.tensor([[1, 2, 3, 4]], dtype=torch.long)
print("="*100)
print("baseline gradient")
loss = torch.nn.functional.cross_entropy(
    tensor.view(-1, D), labels.view(-1), reduction="none", ignore_index=-100)
token_count = torch.count_nonzero(labels != -100)
loss = loss.sum() / token_count
loss.backward()
print(f"gradient: {tensor.grad}")
print("="*100)
print("gradient with coefficient")
tensor.grad.zero_()
loss = torch.nn.functional.cross_entropy(
    tensor.view(-1, D), labels.view(-1), reduction="none", ignore_index=-100)
# for example, assume:
#   token 0 is a good action
#   token 1 is a bad action
#   token 2 is a good action
#   token 3 is a bad action
# we want the usual cross-entropy gradient for tokens 0 and 2 (coefficient +1), and a sign-flipped
# gradient for tokens 1 and 3 (coefficient -1) so their logprobs are pushed down instead of up;
# see the note after this section
coefficient = torch.tensor([[1.0, -1.0, 1.0, -1.0]])
loss = loss * coefficient
print(f"loss: {loss}")
print(f"4 token loss mean: {loss[:4].mean()=}")
token_count = torch.count_nonzero(labels != -100)
loss = loss.sum() / token_count
loss.backward()
print(f"gradient: {tensor.grad}")
print("="*100)
print("gradient with log_softmax_and_gather")
def log_softmax_and_gather(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """
    torch compiled version of the common `log_softmax -> gather` operation.
    The compiled version of this operation avoids the (significant) memory overhead of
    allocating a new (batch_size, seq_len, vocab_size) tensor to store the logprobs.
    See https://github.com/allenai/open-instruct/pull/584
    """
    logprobs = logits.log_softmax(dim=-1)
    return torch.gather(logprobs, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
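# note: the docstring above describes the open-instruct version, which is presumably wrapped with
# @torch.compile; the decorator is omitted in this standalone example to avoid a compilation step.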
tensor.grad.zero_()
neg_logprobs = -log_softmax_and_gather(tensor.view(-1, D), labels.view(-1))  # per-token -log p(label), same as the unreduced cross entropy above
loss = neg_logprobs * coefficient
print(f"loss: {loss}")
loss = loss.sum() / token_count
loss.backward()
print(f"gradient: {tensor.grad}")
# ====================================================================================================
# baseline gradient
# gradient: tensor([[[ 0.0310, -0.2180, 0.0188, 0.0335, 0.0189, 0.0234, 0.0166,
# 0.0284, 0.0328, 0.0146],
# [ 0.0326, 0.0232, -0.2194, 0.0226, 0.0269, 0.0197, 0.0310,
# 0.0227, 0.0167, 0.0240],
# [ 0.0229, 0.0272, 0.0235, -0.2099, 0.0194, 0.0229, 0.0250,
# 0.0213, 0.0302, 0.0176],
# [ 0.0342, 0.0142, 0.0320, 0.0237, -0.2315, 0.0297, 0.0235,
# 0.0326, 0.0230, 0.0186]]])
# ====================================================================================================
# gradient with coefficient
# loss: tensor([[ 2.0554, -2.1020, 1.8296, -2.6033]], grad_fn=<MulBackward0>)
# 4 token loss mean: loss[:4].mean()=tensor(-0.2051, grad_fn=<MeanBackward0>)
# gradient: tensor([[[ 0.0310, -0.2180, 0.0188, 0.0335, 0.0189, 0.0234, 0.0166,
# 0.0284, 0.0328, 0.0146],
# [-0.0326, -0.0232, 0.2194, -0.0226, -0.0269, -0.0197, -0.0310,
# -0.0227, -0.0167, -0.0240],
# [ 0.0229, 0.0272, 0.0235, -0.2099, 0.0194, 0.0229, 0.0250,
# 0.0213, 0.0302, 0.0176],
# [-0.0342, -0.0142, -0.0320, -0.0237, 0.2315, -0.0297, -0.0235,
# -0.0326, -0.0230, -0.0186]]])
# ====================================================================================================
# gradient with log_softmax_and_gather
# loss: tensor([[ 2.0554, -2.1020, 1.8296, -2.6033]], grad_fn=<MulBackward0>)
# gradient: tensor([[[ 0.0310, -0.2180, 0.0188, 0.0335, 0.0189, 0.0234, 0.0166,
# 0.0284, 0.0328, 0.0146],
# [-0.0326, -0.0232, 0.2194, -0.0226, -0.0269, -0.0197, -0.0310,
# -0.0227, -0.0167, -0.0240],
# [ 0.0229, 0.0272, 0.0235, -0.2099, 0.0194, 0.0229, 0.0250,
# 0.0213, 0.0302, 0.0176],
# [-0.0342, -0.0142, -0.0320, -0.0237, 0.2315, -0.0297, -0.0235,
# -0.0326, -0.0230, -0.0186]]])