@simeneide
Created August 20, 2023 09:38
Multi-GPU training with LoRA and 8-bit quantization failed

The script below reproduces the failure: GPT-J 6B is loaded in 8-bit, wrapped with PEFT LoRA adapters, and trained with a two-device PyTorch Lightning Trainer.
#%% LOAD MODEL OBJECTS
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", load_in_8bit=True)
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=4, lora_alpha=32, lora_dropout=0.1)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)
setattr(model, 'model_parallel', True)
setattr(model, 'is_parallelizable', True)
model.print_trainable_parameters()  # prints the trainable/total parameter counts itself
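#%% (sketch) Inspect which submodules received LoRA adapters.
# A minimal sanity-check sketch, assuming the PEFT-wrapped model from above;
# the exact module names depend on the peft version and GPT-J's layer naming,
# so treat the printed names as illustrative only.
lora_module_names = [name for name, _ in model.named_modules() if "lora" in name.lower()]
print(f"Found {len(lora_module_names)} LoRA submodules, e.g. {lora_module_names[:4]}")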
#%%
import torch
import pytorch_lightning as pl

class LitAutoEncoder(pl.LightningModule):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop; the causal LM returns a
        # ModelOutput, so return its .loss for Lightning to backpropagate.
        outputs = self.base_model(**batch, labels=batch["input_ids"])
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        # Same forward pass, used here only to report the validation loss.
        outputs = self.base_model(**batch, labels=batch["input_ids"])
        return outputs.loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
pl_model = LitAutoEncoder(model)
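#%% (sketch) Confirm that only the LoRA adapter weights are trainable.
# A hedged check, assuming prepare_model_for_kbit_training froze the 8-bit base
# weights as expected; parameter names come from peft and may vary between versions.
trainable_names = [name for name, param in pl_model.named_parameters() if param.requires_grad]
print(f"{len(trainable_names)} trainable tensors, e.g. {trainable_names[:3]}")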
#%% Load Data
from datasets import load_dataset
import torch
class Dataset(torch.utils.data.Dataset):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.dataset = load_dataset("truthful_qa", "generation")['validation']

    def __getitem__(self, idx):
        text = self.dataset[idx]['best_answer']
        encoded = self.tokenizer(text, return_tensors="pt", padding="max_length", max_length=10, truncation=True)
        # The tokenizer returns tensors with a leading batch dim of 1;
        # squeeze it so the DataLoader collates to (batch_size, max_length).
        return {key: value.squeeze(0) for key, value in encoded.items()}

    def __len__(self):
        return len(self.dataset)
ds = Dataset(tokenizer)
from torch.utils.data import DataLoader
dl = DataLoader(ds, batch_size=2)
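#%% (sketch) Peek at one collated batch to verify the shapes fed to the model.
# With the squeeze in __getitem__ above, each tensor should come out as
# (batch_size, max_length), i.e. (2, 10) here.
sample_batch = next(iter(dl))
for key, value in sample_batch.items():
    print(key, tuple(value.shape))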
#%% Smoke-test a single batch through the training step
batch = next(iter(dl))
loss = pl_model.training_step(batch=batch, batch_idx=1)
print(loss)
# %%
# Initialize a trainer
from pytorch_lightning import Trainer
trainer = Trainer(devices=2, max_epochs=1)
trainer.fit(pl_model, dl)
# %%