from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
tokenizer.add_special_tokens({"pad_token": "<PAD>"})
p = tokenizer.pad_token_id
# Note: we never resize the model's embeddings for the new pad token; pad
# positions are replaced with token id 0 (or ignored) before every lookup below.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
def first_true_indices(bools: torch.Tensor, dtype=torch.long) -> torch.Tensor:
    """Per row, return the index of the first True; rows with no True return the row length."""
    row_len = bools.size(-1)
    zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device)
    return torch.min(zero_or_index, dim=-1).values
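# Quick sanity check (illustrative example): the first row's first True is at
# index 1; the second row has none, so it returns the row length, 3.
assert first_true_indices(torch.tensor([[False, True, True], [False, False, False]])).tolist() == [1, 3]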
# assume we have (left-padded queries, right-padded responses):
# | prompt, prompt, prompt | response, response, pad      |
# | pad,    pad,    prompt | response, response, response |
queries = torch.tensor([
    [32, 33, 34],
    [p, p, 34],
]).to(device)
responses = torch.tensor([
    [35, 36, p],
    [35, 36, 38],
]).to(device)
query_responses = torch.cat([queries, responses], dim=1)
attention_mask = query_responses != p
# With left padding, position ids must be computed explicitly so that real
# tokens are numbered from 0 regardless of how many pads precede them.
position_ids = attention_mask.cumsum(1) - attention_mask.long()
# Replace pad ids with a valid token id (0); those positions are masked anyway.
input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
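# Worked example (values derived by hand, not printed by the original gist):
# row 1 of query_responses is [p, p, 34, 35, 36, 38], giving
#   attention_mask = [0, 0, 1, 1, 1, 1]
#   position_ids   = [0, 0, 0, 1, 2, 3]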
output = model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
return_dict=True,
output_hidden_states=True,
)
context_length = 3  # length of the (padded) query
# logits[:, i] predicts token i + 1, so slicing from context_length - 1 up to
# (but excluding) the last position aligns the logits with the response tokens.
logits = output.logits[:, context_length - 1 : -1]
all_logprobs = F.log_softmax(logits, dim=-1)
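# Shape note (illustrative): logits is (batch=2, 3, vocab), one next-token
# distribution per response position; all_logprobs holds the full log-softmax
# over the vocabulary at each of those positions.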
# mask out pad tokens in the response
response_lengths = first_true_indices(responses == p) - 1  # index of the last non-pad response token
response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
padding_mask = response_idxs > response_lengths.unsqueeze(1)
# replace pad ids with 0 so gather indexes a valid vocabulary entry
masked_response = torch.where(~padding_mask, responses, 0)
logprobs = torch.gather(all_logprobs, 2, masked_response.unsqueeze(-1)).squeeze(-1)
# INVALID_LOGPROB = 1.0. Filling with a constant blocks gradients at pad
# positions, and the padded row's return is 0.0 below, so the constant does not
# affect the loss value here either.
logprobs = torch.masked_fill(logprobs, padding_mask, 1.0)
print(logprobs)
# REINFORCE loss: treat each token as an independent action, sum the per-token
# logprobs, and weight each sequence by its (scalar) return.
returns = torch.tensor([0.0, 1.0], device=device)
loss = -(logprobs.sum(1) * returns).mean()
print("reinforce loss", loss)
loss.backward()
print("reinforce loss grad", model.lm_head.weight.grad.sum())
# reset the gradient
model.zero_grad()
# @y0b1byte's style SFT loss
# Re-run the forward pass: the autograd graph from the first forward was freed
# by loss.backward() above.
output = model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
return_dict=True,
output_hidden_states=True,
)
context_length = 3
# For a fair comparison, apply the CE loss only to the response tokens of the
# second example (the one with return 1.0 in the REINFORCE loss above).
logits = output.logits[:, context_length - 1 : -1]
loss = F.cross_entropy(logits[1], responses[1], ignore_index=p, reduction="mean")
print("@y0b1byte's style SFT loss", loss)
loss.backward()
print("@y0b1byte's style SFT loss grad", model.lm_head.weight.grad.sum())