# Gist by @reyoung, created September 12, 2023 06:27
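#
# This gist appears to sketch a "confidence gate" experiment on a frozen GPT-2: a small
# auxiliary LM head is attached in front of the last transformer block, and token positions
# the head is already confident about are masked out of that block's attention. Only the
# head is trained; the pretrained GPT-2 weights stay frozen.
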
from itertools import chain
from typing import Optional, Tuple
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Model, default_data_collator
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
import datasets
from torch import nn
import torch.nn.functional as F
import torch
from torch.optim import Adam
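

# LMHead projects hidden states to vocabulary log-probabilities and produces a per-token
# "gate" that is positive only when its most probable token already has probability
# greater than roughly exp(-0.1) ~ 0.9.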
class LMHead(nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.layer = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(requires_grad=False,
                                 # 0.1: the gate opens only when the most probable token
                                 # has probability > exp(-0.1), i.e. > 0.9
                                 data=torch.tensor([0.1], dtype=torch.float32))

    def forward(self, x):
        # Per-token log-probabilities from this intermediate head.
        res = self.layer(x)
        res = F.log_softmax(res, dim=-1)
        # The gate is positive only where the head is already confident
        # (max log-probability above -bias).
        values, _ = torch.max(res, dim=-1)
        gate = F.relu(values + self.bias)
        return res, gate
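

# GPT2BlockWithLMHead wraps the final GPT2Block. When `with_lm_head` is True it also runs
# the auxiliary LMHead and adds a penalty proportional to the gate to the additive
# attention mask, which effectively hides confidently-predicted positions from the wrapped
# block's attention; since only the head is trainable, this is presumably how its gradients
# reach the frozen model's final loss. `forward` mirrors GPT2Block.forward's signature so
# the wrapper can be dropped into the model's ModuleList.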
class GPT2BlockWithLMHead(nn.Module):
    def __init__(self, block: GPT2Block, vocab_size: int):
        super().__init__()
        # The hidden size is taken from the wrapped block's first LayerNorm weight.
        self.head = LMHead(block.ln_1.weight.shape[0], vocab_size=vocab_size)
        self.block = block
        self.res = {}
        self.with_lm_head = True

    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]],
                layer_past: Optional[Tuple[torch.Tensor]] = None,
                attention_mask: Optional[torch.FloatTensor] = None,
                head_mask: Optional[torch.FloatTensor] = None,
                encoder_hidden_states: Optional[torch.Tensor] = None,
                encoder_attention_mask: Optional[torch.FloatTensor] = None,
                use_cache: Optional[bool] = False,
                output_attentions: Optional[bool] = False):
        if self.with_lm_head:
            res, gate = self.head(hidden_states)
            self.res["log_softmax"] = res
            self.res["gate"] = gate
            # Penalize attention to positions the head is already confident about
            # (gate has shape [batch, seq]; the additive mask has shape [batch, 1, 1, seq]).
            attention_mask += (gate * -100000).view(attention_mask.shape)
        return self.block(hidden_states=hidden_states, layer_past=layer_past, attention_mask=attention_mask,
                          head_mask=head_mask, encoder_hidden_states=encoder_hidden_states,
                          encoder_attention_mask=encoder_attention_mask, use_cache=use_cache,
                          output_attentions=output_attentions)
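

# The training script below assumes a tokenized, grouped wikitext dataset has already been
# saved to disk at "wikitext.local", with "input_ids", "attention_mask" and "labels" columns
# that default_data_collator can batch. A rough preparation sketch (an assumption, not part
# of the original gist) would be something like:
#   raw = datasets.load_dataset("wikitext", "wikitext-2-raw-v1")
#   ... tokenize with GPT2Tokenizer, group into fixed-length blocks, set labels = input_ids ...
#   raw.save_to_disk("wikitext.local")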
def main():
    ds = datasets.load_from_disk("wikitext.local")
    train_dataloader = DataLoader(
        ds["train"], shuffle=True, collate_fn=default_data_collator, batch_size=2
    )

    gpt2_model: GPT2LMHeadModel = GPT2LMHeadModel.from_pretrained('gpt2')
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

    # Freeze every pretrained GPT-2 parameter; only the auxiliary head receives gradients.
    for p in gpt2_model.parameters():
        p.requires_grad = False

    # Replace the last transformer block with the wrapped version.
    block: GPT2Block = gpt2_model.base_model.h[-1]
    block_with_lm_head = GPT2BlockWithLMHead(block, vocab_size=gpt2_tokenizer.vocab_size)
    gpt2_model.base_model.h = nn.ModuleList(list(gpt2_model.base_model.h[:-1]) + [block_with_lm_head])
    gpt2_model = gpt2_model.to("cuda")

    optim = Adam(block_with_lm_head.parameters())
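
    # Each step runs two forward passes over the same batch:
    #   1. with the gate active, collecting the head's log-probs and the final LM loss (res1),
    #   2. with the gate disabled and no grad, as a reference LM loss (res2).
    # The total loss is the head's NLL plus 100 * relu(loss1 - loss2), which presumably
    # trains the head to be confident only where masking those positions does not hurt
    # the frozen model's final loss.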
    for batch in train_dataloader:
        batch = {k: v.to("cuda") for k, v in batch.items()}
        optim.zero_grad()

        # Pass 1: gated forward; the auxiliary head stores its log-probs and gate.
        block_with_lm_head.with_lm_head = True
        res1 = gpt2_model(**batch)

        # NLL loss of the intermediate head against the batch labels
        # (note: labels are used as-is here, not shifted as in GPT2LMHeadModel's own loss).
        loss_fn = nn.NLLLoss()
        log_softmax = block_with_lm_head.res["log_softmax"]
        loss = loss_fn(log_softmax.reshape(-1, gpt2_tokenizer.vocab_size), batch["labels"].reshape(-1))
        max_values, _ = torch.max(log_softmax, dim=-1)  # currently unused

        # Pass 2: ungated forward (no grad) as a reference for the final LM loss.
        block_with_lm_head.with_lm_head = False
        with torch.no_grad():
            res2 = gpt2_model(**batch)

        # Penalize the gate whenever masking makes the final LM loss worse than the reference.
        loss_diff = torch.nn.functional.relu(res1["loss"] - res2["loss"])
        part2 = loss_diff * 100
        total_loss = loss + part2
        total_loss.backward()
        print(loss.cpu().item(), loss_diff.cpu().item(), part2.cpu().item())
        optim.step()


if __name__ == '__main__':
    main()