import os
import json
import random
import textwrap
import re
import math

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torch.optim.adamw import AdamW
from torch.optim.sgd import SGD
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import datasets
from tqdm import tqdm

PRECISION = torch.bfloat16
USERNAME = "exampleuser"
class LoRA(nn.Module):
    """Low-rank adapter: a trainable delta weight W = (A @ B) * (alpha / r)."""

    def __init__(self, dim, dim_out, r=8, alpha=None):
        super().__init__()
        alpha = alpha if alpha is not None else r
        self.scale = alpha / r
        # A gets a Kaiming-uniform init; B starts at zero, so the adapter
        # contributes nothing until training updates it.
        self.A = nn.Parameter(
            nn.init.kaiming_uniform_(torch.randn(dim, r).to(PRECISION), a=math.sqrt(5))
        )
        self.B = nn.Parameter(torch.zeros(r, dim_out).to(PRECISION))

    @property
    def weight(self):
        return (self.A @ self.B) * self.scale

    def forward(self, x):
        return x @ self.weight
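
# Illustrative sanity check (hypothetical helper, not called by the script):
# because B is zero-initialized, a fresh adapter is a no-op, so wrapping a
# frozen model leaves its outputs unchanged until training begins.
def _lora_starts_as_noop(dim=16, r=8):
    lora = LoRA(dim, dim, r=r)
    x = torch.randn(2, dim).to(PRECISION)
    assert torch.all(lora(x) == 0)  # A @ B is all zeros at init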
class LoRAForward(nn.Module):
    """Wraps a transformer block and adds the LoRA delta to its hidden states."""

    def __init__(self, original_layer, lora):
        super().__init__()
        self.original_layer = original_layer
        self.lora = lora

    def forward(self, x, *args, **kwargs):
        # GPT-NeoX blocks return a tuple; position 0 holds the hidden states.
        output = self.original_layer(x, *args, **kwargs)
        prev_output = output[0]
        lora_output = self.lora(x).view(prev_output.shape)
        merged_output = prev_output + lora_output
        return merged_output, *output[1:]
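
# Minimal sketch of the wrapping contract (_TupleBlock is a hypothetical
# stand-in, added only for illustration): LoRAForward assumes the wrapped
# layer returns a tuple whose first element is the hidden states, as
# GPT-NeoX blocks do.
class _TupleBlock(nn.Module):
    def forward(self, x):
        return (x * 2.0,)  # hidden states in position 0, like a transformer block

# wrapped = LoRAForward(_TupleBlock(), LoRA(16, 16))
# (y,) = wrapped(torch.randn(2, 16).to(PRECISION))  # base output + LoRA delta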
class TextDataset(Dataset):
    """Tokenizes a fixed list of texts, padded/truncated to max_length."""

    def __init__(self, texts, tokenizer, max_length=280):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        input_text = self.texts[idx]
        tokenized = self.tokenizer(
            input_text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        return tokenized
class ThePileDataset(IterableDataset):
    """Streams EleutherAI's deduplicated Pile with shuffle buffering."""

    def __init__(self, tokenizer, max_length=280, buffer_size=10_000, seed=42):
        self.the_pile = datasets.load_dataset(
            "EleutherAI/the_pile_deduplicated",
            split="train",
            streaming=True,
        ).shuffle(buffer_size=buffer_size, seed=seed)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        # Hard-coded row count; streaming datasets do not support len() directly.
        return 134318121

    def __iter__(self):
        for row in self.the_pile:
            yield self.tokenizer(
                row["text"],
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
                return_tensors="pt",
            )
def insert_lora_layers(model, config, r=8, alpha=None):
    for i in range(config.num_hidden_layers):
        layer = model.gpt_neox.layers[i]
        # Get input and output dimensions for the LoRA layer
        dim_in = config.hidden_size
        dim_out = config.hidden_size
        lora = LoRA(dim=dim_in, dim_out=dim_out, r=r, alpha=alpha)
        # Replace the existing layer with a LoRAForward container holding the
        # original layer and the new LoRA adapter
        model.gpt_neox.layers[i] = LoRAForward(layer, lora)
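
# Hypothetical helper for illustration: once the base model is frozen (see
# freeze_model below), only the LoRA A/B matrices train, roughly
# num_hidden_layers * 2 * hidden_size * r parameters instead of billions.
def _count_trainable_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)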
def extract_lora_weights(model):
    lora_weights = {}
    for idx, layer in enumerate(model.gpt_neox.layers):
        if isinstance(layer, LoRAForward):
            lora_weights[f"lora_{idx}_A"] = layer.lora.A
            lora_weights[f"lora_{idx}_B"] = layer.lora.B
    return lora_weights
def load_lora_weights(model, lora_weights):
    for idx, layer in enumerate(model.gpt_neox.layers):
        if isinstance(layer, LoRAForward):
            layer.lora.A = lora_weights[f"lora_{idx}_A"]
            layer.lora.B = lora_weights[f"lora_{idx}_B"]
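
# Checkpoint round-trip sketch (hypothetical usage; assumes the LoRA layers
# are already inserted into `model`): only the adapter matrices are saved,
# a few megabytes rather than the full base model.
# torch.save(extract_lora_weights(model), "lora.pt")
# load_lora_weights(model, torch.load("lora.pt"))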
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False
def generate_sample(
    model,
    tokenizer,
    prompt,
    min_length=50,
    max_length=280,
    top_k=50,
    top_p=0.95,
    temperature=0.8,
):
    with torch.no_grad():
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
        generated = model.generate(
            input_ids,
            max_length=max_length,
            min_length=min_length,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            pad_token_id=tokenizer.eos_token_id,
        )
        resp = tokenizer.decode(generated[0], skip_special_tokens=True)
    return resp
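
# Example call (hypothetical prompt; assumes the model/tokenizer setup from
# __main__ below). Sampling is stochastic, so output varies run to run.
# print(generate_sample(model, tokenizer, f"{USERNAME}: I just", temperature=0.8))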
def print_demo_text(model, tokenizer):
    # Sample three continuations, each at a distinct randomly ordered temperature.
    temperatures = [0.7, 0.8, 0.9]
    random.shuffle(temperatures)
    prompts = ["I just", "I think", "When it comes to"]
    for n, (prompt, temperature) in enumerate(zip(prompts, temperatures), start=1):
        generated_text = generate_sample(
            model, tokenizer, f"{USERNAME}: {prompt}", temperature=temperature
        )
        print(
            f"Generated text {n} [{temperature}]:\n"
            f"{textwrap.indent(textwrap.fill(generated_text), '  ')}"
        )
    print("-----")
if __name__ == "__main__":
    # Load the pretrained model and config
    model_name = "EleutherAI/pythia-1.4b-deduped"
    # model_name = "EleutherAI/pythia-160m-deduped"
    # model_name = "EleutherAI/pythia-70m-deduped"
    config = AutoConfig.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, config=config, torch_dtype=PRECISION
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    should_train = True

    # Freeze the pretrained model parameters
    freeze_model(model)

    # Insert LoRA layers
    rank = 8
    lora_alpha = 16.0
    insert_lora_layers(model, config, r=rank, alpha=lora_alpha)

    device = torch.device("cuda")
    # device = torch.device("mps")
    model = model.to(device)
    batch_accumulation = 48
    # lr = 8e-4
    # lr = 0.0064
    lr = 2e-4
    # lr = 1e-4
    # lr = 5e-5
    # lr = 1e-5
    start_epoch, num_epochs = 0, 200
    batch_size = 20

    model_save_dir = (
        f"model_checkpoints_{re.sub('[^a-zA-Z0-9]+', '_', model_name.split('/')[-1])}"
    )
    os.makedirs(model_save_dir, exist_ok=True)

    # Resume from the newest checkpoint, if any. Filenames are zero-padded
    # (epoch_0004_0000001234.pt), so a lexicographic sort is chronological.
    latest_model = (
        sorted(os.listdir(model_save_dir), reverse=True)[0]
        if os.listdir(model_save_dir)
        else None
    )
    if latest_model:
        model_path = os.path.join(model_save_dir, latest_model)
        start_epoch = (
            int(os.path.splitext(os.path.basename(model_path))[0].split("_")[-2]) + 1
        )
        lora_weights = torch.load(model_path)
        load_lora_weights(model, lora_weights)
        print(
            f"Loaded LoRA weights from: {model_path}; resuming at epoch {start_epoch}"
        )

    print_demo_text(model, tokenizer)
    if should_train:
        with open("tweets.json", "r", encoding="utf-8") as f:
            tweets = [f'{USERNAME}: {t["tweet"]["full_text"]}' for t in json.load(f)]
        random.seed(42)
        random.shuffle(tweets)
        # Hold out 10% of tweets as a test split
        split_idx = int(len(tweets) * 0.9)
        texts, tests = tweets[:split_idx], tweets[split_idx:]
        dataset = TextDataset(texts, tokenizer)
        # dataset = ThePileDataset(tokenizer)
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=not isinstance(dataset, IterableDataset),
        )
        optimizer = AdamW(model.parameters(), lr=lr)
        # optimizer = SGD(model.parameters(), lr=lr)
        model.train()

        accum_loss = 0.0
        accum_count = 0
        prev_loss = 0.0
        n_step = 0
        for epoch in range(start_epoch, num_epochs):
            # Wrap the dataloader with tqdm for a progress bar
            progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}")
            for i, batch in enumerate(progress_bar):
                n_step += 1
                # Each dataset item has shape (1, max_length), so the stacked
                # batch is (batch, 1, max_length); squeeze only dim 1 so a
                # final batch of size 1 survives intact.
                input_ids = batch["input_ids"].squeeze(1).to(device)
                attention_mask = batch["attention_mask"].squeeze(1).to(device)
                outputs = model(
                    input_ids=input_ids, attention_mask=attention_mask, labels=input_ids
                )
                # Scale the loss so accumulated gradients average over the
                # effective batch.
                loss = outputs.loss / batch_accumulation
                loss.backward()
                accum_loss += loss.item()
                accum_count += 1
                if ((i + 1) % batch_accumulation == 0) or (i + 1 == len(dataloader)):
                    optimizer.step()
                    optimizer.zero_grad()
                    prev_loss = accum_loss
                    accum_loss = 0.0
                    accum_count = 0
                    # print_demo_text(model, tokenizer)
                    # model_save_path = os.path.join(
                    #     model_save_dir,
                    #     f"epoch_{str(epoch).zfill(4)}_{str(n_step).zfill(10)}.pt",
                    # )
                    # torch.save(extract_lora_weights(model), model_save_path)
                    # print(f"LoRA weights saved at: {model_save_path}")
                progress_bar.set_postfix(
                    {
                        "Prev Loss": prev_loss,
                        "Loss": (
                            (accum_loss * batch_accumulation)
                            / max(float(accum_count), 1.0)
                        ),
                        "Accum": (batch_accumulation - ((i + 1) % batch_accumulation)),
                    }
                )
            # Save the LoRA weights after each epoch
            model_save_path = os.path.join(
                model_save_dir,
                f"epoch_{str(epoch).zfill(4)}_{str(n_step).zfill(10)}.pt",
            )
            lora_weights = extract_lora_weights(model)
            torch.save(lora_weights, model_save_path)
            print(f"LoRA weights saved at: {model_save_path}")
            print_demo_text(model, tokenizer)