GPT2DoubleHeadsModel used for actual classification, not choice selection
import torch
import transformers
import tqdm
import copy
import numpy as np
from logzero import logger
# some tiny data -- sentiment classification + LM
DATA = [
    [{'text': 'This is good . [CLS]',
      'class': 1},
     {'text': 'This is bad . [CLS]',
      'class': 0}],
    [{'text': 'I liked it . [CLS]',
      'class': 1},
     {'text': 'I hated it . [CLS]',
      'class': 0}],
    [{'text': 'It was great . [CLS]',
      'class': 1},
     {'text': 'It was bad . [CLS]',
      'class': 0}],
]
class GPT2DoubleHeadsSC(transformers.GPT2DoubleHeadsModel):
    """GPT2DoubleHeadsModel variant whose second head projects to 2 class labels,
    i.e. it acts as a sequence classifier rather than a multiple-choice scorer."""

    def __init__(self, config):
        transformers.GPT2PreTrainedModel.__init__(self, config)
        config.num_labels = 2  # XXX This is the only thing changed w.r.t. GPT2DoubleHeadsModel
        self.transformer = transformers.GPT2Model(config)
        self.lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.multiple_choice_head = transformers.modeling_utils.SequenceSummary(config)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()
class DataLoader:
    """Pre-tokenizes the toy data into batches carrying LM labels, classification
    labels, and the position of the [CLS] token used by the classification head."""

    def __init__(self, data, tokenizer):
        self.data = []
        self.tokenizer = tokenizer
        for batch in data:
            tokenized = [self.tokenizer(i['text']) for i in batch]
            self.data.append({'input_ids': torch.tensor([i['input_ids'] for i in tokenized]),
                              # LM labels = the input itself (shifted internally by the model)
                              'labels': torch.tensor([i['input_ids'] for i in tokenized]),
                              'attention_mask': torch.tensor([i['attention_mask'] for i in tokenized]),
                              # index of [CLS] in each sequence, fed to the classification head
                              'mc_token_ids': torch.tensor([i['input_ids'].index(tokenizer.cls_token_id) for i in tokenized]),
                              'mc_labels': torch.tensor([i['class'] for i in batch])})

    def __iter__(self):
        for batch in self.data:
            yield copy.copy(batch)

    def __len__(self):
        return len(self.data)
class Trainer:

    def __init__(self,
                 model,
                 train_data_loader,
                 epochs: int,
                 optimizer,
                 scheduler,
                 logger=logger):
        self.model = model
        self.device = model.device
        self.train_data_loader = train_data_loader
        self.epochs = epochs
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.logger = logger

    def train(self):
        self.logger.info('Starting training...')
        for epoch in range(self.epochs):
            self.logger.info(f'====== Epoch {epoch}/{self.epochs} Training ======')
            self.model.train()
            ep_loss = 0
            for step, batch in enumerate(tqdm.tqdm(self.train_data_loader)):
                output = self.model(**batch)
                # Backpropagate the sum of the LM loss and the classification loss
                loss = output.loss + output.mc_loss
                ep_loss += loss.item()
                loss.backward()
                # Optimizer and scheduler steps
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
            self.logger.debug(f'Epoch loss: {ep_loss}')
def test_training():
    transformers.set_seed(42)
    tokenizer = transformers.GPT2Tokenizer.from_pretrained("distilgpt2")
    tokenizer.add_special_tokens({"cls_token": "[CLS]"})
    model = GPT2DoubleHeadsSC.from_pretrained('distilgpt2')
    model.resize_token_embeddings(len(tokenizer))
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = transformers.get_constant_schedule(optimizer)
    loader = DataLoader(DATA, tokenizer)
    trainer = Trainer(model, loader, 20, optimizer, scheduler)

    # overfit the model on this data
    trainer.train()

    # testing that the model is really overfit
    model.eval()
    toks_total, toks_corr, cls_total, cls_corr = 0, 0, 0, 0
    for batch in loader:
        with torch.no_grad():
            output = model(**{'attention_mask': batch['attention_mask'], 'input_ids': batch['input_ids']})
        # next-token accuracy (LM head) and class prediction accuracy (classification head)
        toks_preds = batch['input_ids'].numpy()[:, 1:] == torch.argmax(output.logits, dim=-1).numpy()[:, :-1]
        cls_preds = batch['mc_labels'].numpy() == torch.argmax(output.mc_logits, dim=-1).numpy()
        toks_corr += np.sum(toks_preds)
        toks_total += np.prod(toks_preds.shape)
        cls_corr += np.sum(cls_preds)
        cls_total += np.prod(cls_preds.shape)
    logger.info(f'Token accuracy: {toks_corr / toks_total}')
    logger.info(f'Classification accuracy: {cls_corr / cls_total}')


if __name__ == '__main__':
    test_training()
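

# --- Usage sketch (not part of the original gist) ---
# A minimal illustration of how the trained/overfit model could be queried for one new
# sentence. `classify` is a hypothetical helper added here for illustration; it assumes
# the same ' [CLS]'-terminated input convention as DATA above.
def classify(model, tokenizer, text):
    """Return the predicted class (0/1) for a single sentence ending in ' [CLS]'."""
    enc = tokenizer(text, return_tensors='pt')
    # position of the [CLS] token, used by the classification head
    cls_pos = torch.tensor([enc['input_ids'][0].tolist().index(tokenizer.cls_token_id)])
    with torch.no_grad():
        output = model(input_ids=enc['input_ids'],
                       attention_mask=enc['attention_mask'],
                       mc_token_ids=cls_pos)
    return int(torch.argmax(output.mc_logits, dim=-1).item())

# e.g., after trainer.train():
#   classify(model, tokenizer, 'This is good . [CLS]')  # expected to return 1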