# Derived from https://towardsdatascience.com/how-to-fine-tune-gpt-2-for-text-generation-ae2ea53bc272
import csv
import os
import random
import re

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm, trange
# NOTE: newer transformers versions deprecate transformers.AdamW in favour of torch.optim.AdamW
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AdamW, get_linear_schedule_with_warmup

tok_delim = re.compile(r'\s+')
SONG_COLS = ['Artist', 'SName']
# ----- Data Prep -----

### Prepare data
lyrics = pd.read_csv('data/lyrics-data.csv')
lyrics = lyrics[lyrics.language == 'en']

artists = pd.read_csv('data/artists-data.csv')
artists.loc[:, "Genres"] = artists.Genres.str.split(";")
artists = artists.explode("Genres")
artists.loc[:, "Genres"] = artists.Genres.str.strip()

### Keep only artists with genre Rock or Pop and Popularity > 5
artists = artists[(artists['Genres'].isin(['Rock', 'Pop'])) & (artists['Popularity'] > 5)]

### Drop duplicated artist rows (keeping 'Rock' over 'Pop')
artists.sort_values('Genres', ascending=False, inplace=True)
artists.drop_duplicates(subset=list(set(artists.columns) - set(['Genres'])), inplace=True, keep='first')

### Join lyrics and artists, then drop the join/bookkeeping columns
df = lyrics.merge(artists[['Artist', 'Genres', 'Link']], left_on='ALink', right_on='Link', how='inner')
df.drop(columns=['ALink', 'SLink', 'Genres', 'Link'], inplace=True)
### Tokenize lyric text on whitespace, add word-count and token columns to df
tmp = df.Lyric.str.split(tok_delim)
def notempty(y): return len(y) > 0
tmp = tmp.apply(lambda x: list(filter(notempty, x)))
lyric_nwords = tmp.apply(len)
df.insert(df.shape[1], 'lyric_nwords', lyric_nwords)
df.insert(df.shape[1], 'lyric_tokens', tmp)

### ... overwrite original lyric strings with simplified (whitespace-normalized) versions
df.loc[:, "Lyric"] = df.lyric_tokens.apply(' '.join)

### Filter out songs with too few (<25) or too many (>=350) words
df = df[(lyric_nwords >= 25) & (lyric_nwords < 350)].reset_index(drop=True)
del lyric_nwords, tmp
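### (optional) check the word-count distribution remaining after the 25-350 filter --
### a quick sketch, not part of the original gist
print(df.lyric_nwords.describe())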
### Create a very small test set to compare generated text with the real endings
test_set = df.sample(n=200, random_state=106)
train_set = df.drop(index=test_set.index).copy()
test_set.reset_index(drop=True, inplace=True)
train_set.reset_index(drop=True, inplace=True)

### Sanity checks
### ... row counts
assert df.shape[0] == (train_set.shape[0] + test_set.shape[0])
### ... confirm no overlapping songs
shared_songs = train_set.loc[:, SONG_COLS].merge(test_set.loc[:, SONG_COLS], how='inner')
assert shared_songs.shape[0] == 0, "ERROR: overlapping songs in test, train sets"

### For the test set only, keep the last 20 words in a new column, then remove them from the original column
test_set.insert(test_set.shape[1], 'True_end_lyrics', test_set.lyric_tokens.str[-20:].apply(' '.join))
test_set.loc[:, 'Lyric'] = test_set.lyric_tokens.str[:-20].apply(' '.join)
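### (optional) quick look at the prepared split -- a sketch for eyeballing one prompt/ending pair;
### it only uses columns created above and is not part of the original gist
print(f"train songs: {train_set.shape[0]}, test songs: {test_set.shape[0]}")
print("example test prompt (first 80 chars):", test_set.Lyric.iloc[0][:80])
print("example held-out ending:", test_set.True_end_lyrics.iloc[0])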
class SongLyrics(Dataset):
    def __init__(self, lyrics: pd.Series, gpt2_type="gpt2", max_length=1022, truncate=0, **kwargs):
        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type, **kwargs)
        self.lyrics = []
        for i, text in lyrics.items():  # .iteritems() was removed in pandas 2.0
            if (truncate > 0) and (i == truncate):
                break
            lyric_toks = self.tokenizer.tokenize(text)
            if len(lyric_toks) > max_length:
                # take a random max_length window so the sequence (plus BOS/EOS) fits GPT-2's 1024-token context
                istart = np.random.randint(len(lyric_toks) - max_length)
                lyric_toks = lyric_toks[istart:(istart + max_length)]
            self.lyrics.append(torch.tensor([
                self.tokenizer.bos_token_id,
                *self.tokenizer.convert_tokens_to_ids(lyric_toks),
                self.tokenizer.eos_token_id
            ]))
        self.lyrics_count = len(self.lyrics)

    def __len__(self):
        return self.lyrics_count

    def __getitem__(self, idx):
        return idx, self.lyrics[idx]

dataset = SongLyrics(train_set.Lyric, gpt2_type="gpt2")
# Get the tokenizer and model
tokenizer = dataset.tokenizer
model = GPT2LMHeadModel.from_pretrained('gpt2')
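### (optional) round-trip check that the encoding in SongLyrics.__init__ looks right: decode the
### first item and confirm it is wrapped in <|endoftext|> -- a sketch, not part of the original gist
_, first_ids = dataset[0]
print("encoded length:", first_ids.shape[0])
print(tokenizer.decode(first_ids)[:120])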
# Helper that packs several (encoded) lyrics into one sequence up to max_seq_len, building mini-batches dynamically.
# Returns (packed_tensor, carry_on, remainder): carry_on=True means the new entry was absorbed (keep packing);
# carry_on=False means the pack is full and the new entry is handed back as the remainder.
def pack_tensor(new_tensor, packed_tensor, max_seq_len):
    if packed_tensor is None:
        return new_tensor, True, None
    if new_tensor.size()[1] + packed_tensor.size()[1] > max_seq_len:
        return packed_tensor, False, new_tensor
    else:
        # prepend the new entry and drop the first token (a BOS) of the existing pack
        packed_tensor = torch.cat([new_tensor, packed_tensor[:, 1:]], dim=1)
        return packed_tensor, True, None
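### (optional) toy illustration of pack_tensor: two short "songs" fit into one packed sequence,
### while a third that would overflow max_seq_len is handed back as the remainder.
### A sketch with made-up tensors, not part of the original gist.
_a = torch.ones(1, 4, dtype=torch.long)
_b = torch.ones(1, 3, dtype=torch.long)
_packed, _carry, _rem = pack_tensor(_a, None, max_seq_len=8)      # first entry starts the pack
_packed, _carry, _rem = pack_tensor(_b, _packed, max_seq_len=8)   # second entry fits -> carry on
_packed, _carry, _rem = pack_tensor(torch.ones(1, 6, dtype=torch.long), _packed, max_seq_len=8)
print(_carry, _rem is None)  # False, False: the pack is full and the new entry came back as the remainder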
# ----- Train -----
def train(
    dataset, model, tokenizer,
    batch_size=16, epochs=5, lr=2e-5,
    max_seq_len=768, warmup_steps=200,
    output_dir="_scratch",
    output_prefix="lyric_gpt2demo",
    test_mode=False,
    save_model_on_epoch=False,
):
    # fall back to CPU when no GPU is available (the original assumed cuda:0)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.train()

    optimizer = AdamW(model.parameters(), lr=lr)
    # NOTE: with num_training_steps=-1 the linear schedule clamps the post-warmup factor to 0,
    # so the learning rate effectively drops to zero right after warmup; pass the real total step
    # count (roughly epochs * number of packed mini-batches) if a linear decay over the run is intended.
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=-1
    )
    train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True, pin_memory=True)

    for epoch in range(epochs):
        optimizer.zero_grad()
        loss = 0
        input_tensor = None

        ### DEBUG/
        # (optional) vector of loss values per minibatch
        losses = []
        # (optional) flags marking the minibatches where `optimizer.step` / `optimizer.zero_grad` were called
        accumulate = torch.zeros(len(train_dataloader), dtype=torch.bool)
        ### /DEBUG

        print(f"Training epoch {epoch}")
        for batch_idx, (idx, entry) in tqdm(enumerate(train_dataloader), total=len(train_dataloader),
                                            mininterval=15, maxinterval=60, miniters=200, leave=False):
            # keep packing entries until the packed sequence would exceed max_seq_len (or this is the last entry)
            (input_tensor, carry_on, remainder) = pack_tensor(entry, input_tensor, max_seq_len)
            if carry_on and ((batch_idx + 1) != len(train_dataloader)):
                continue

            input_tensor = input_tensor.to(device)
            outputs = model(input_tensor, labels=input_tensor)
            loss = outputs[0]
            loss.backward()

            # accumulate gradients; step the optimizer roughly every `batch_size` entries and on the final one
            if (((batch_idx + 1) % batch_size) == 0) or ((batch_idx + 1) == len(train_dataloader)):
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                accumulate[batch_idx] = 1

            # start the next pack from the entry that did not fit
            input_tensor = remainder
            losses.append(loss.detach().cpu().item())

        print(f"avg loss: {np.mean(losses)} for epoch {epoch}")
        if save_model_on_epoch:
            print('saving epoch state')
            torch.save({
                "epoch": epoch,
                "accum_batches": batch_size,
                "lr": lr,
                "max_seq_len": max_seq_len,
                "state_dict": model.state_dict(),
                "losses": losses,
                "accumulate": accumulate,
            },
                os.path.join(output_dir, f"{output_prefix}-{epoch}.torch"),
            )
    return model
model = train(dataset, model, tokenizer, save_model_on_epoch=True)

## (optional) Save fine-tuned model
# torch.save( model.state_dict(), '_scratch/gpt2demo_finetuned.STATE_DICT.torch' )
# model.save_pretrained( '_scratch/lyrics_gpt2demo/model' )
# tokenizer.save_pretrained( '_scratch/lyrics_gpt2demo/tokenizer' )
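## (optional) reload one of the per-epoch checkpoints written by save_model_on_epoch -- a sketch;
## the path assumes the default output_dir / output_prefix and the last epoch (4)
# ckpt = torch.load('_scratch/lyric_gpt2demo-4.torch', map_location='cpu')
# model = GPT2LMHeadModel.from_pretrained('gpt2')
# model.load_state_dict(ckpt['state_dict'])
# model.eval()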
# ----- Generate -----
def generate(
    model,
    tokenizer,
    prompt,
    entry_length=30,   # maximum number of tokens to generate
    top_p=0.8,
    temperature=1.,
):
    model.eval()
    filter_value = -float("Inf")

    with torch.inference_mode():
        entry_finished = False
        generated = torch.tensor(tokenizer.encode(prompt), device=model.device).unsqueeze(0)
        nstart = generated.shape[-1]

        for nth in range(entry_length):
            # labels / loss are not needed for sampling, only the next-token logits
            logits = model(generated).logits
            logits = logits[:, -1, :] / temperature

            # nucleus (top-p) sampling: keep the smallest set of tokens whose cumulative probability exceeds top_p
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices[sorted_indices_to_remove]  # batch size 1 assumed
            logits[:, indices_to_remove] = filter_value

            next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)

            ## Stop early if the end-of-sequence special token was sampled
            entry_finished = (next_token.item() == tokenizer.eos_token_id)
            if entry_finished:
                break

        ngenerated = generated.shape[-1] - nstart
        assert ngenerated == (nth + 1), "sanity check failed; check loop"

        output_list = list(generated.cpu().squeeze().numpy())
        ### Only return the new (generated) text:
        generated_list = output_list[-ngenerated:]
        generated_text = f"{tokenizer.decode(generated_list)}{'' if entry_finished else tokenizer.eos_token}"

    return generated_text
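### (optional) smoke-test the sampler on one hand-written prompt before looping over the test set;
### the prompt string below is just an illustration, not from the dataset
# print(generate(model, tokenizer, "I walked along the river and the moon was", entry_length=30))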
# generate lyrics for the test set
generated_lyrics = [''] * test_set.shape[0]
for i in trange(test_set.shape[0], leave=False):
    generated_lyrics[i] = generate(model, tokenizer, test_set.Lyric.iloc[i])
test_set.insert(test_set.shape[1], 'Generated_lyrics', generated_lyrics)
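### (optional) eyeball a few generated endings next to the held-out true endings -- a sketch,
### not part of the original gist
for _, row in test_set.head(3).iterrows():
    print('PROMPT (tail) :', ' '.join(row.Lyric.split()[-10:]))
    print('TRUE ENDING   :', row.True_end_lyrics)
    print('GENERATED     :', row.Generated_lyrics)
    print('-' * 60)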
# NOTE: This is a work in progress; if you use it, expect to find and fix bugs.