Created August 20, 2023 09:38
Multi-GPU training with LoRA and 8-bit quantization failed
The script below loads GPT-J-6B in 8-bit, attaches a LoRA adapter with PEFT, wraps the model in a PyTorch Lightning module, and fails when the Trainer is given two devices.
#%% LOAD MODEL OBJECTS
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", load_in_8bit=True)

from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, inference_mode=False, r=4, lora_alpha=32, lora_dropout=0.1
)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)

# Flags read by Hugging Face's Trainer for model parallelism; kept as in the
# original, although it is unclear whether Lightning honors them.
setattr(model, "model_parallel", True)
setattr(model, "is_parallelizable", True)

# print_trainable_parameters() prints and returns None, so no print() wrapper.
model.print_trainable_parameters()
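#%% (Optional) sanity check -- a hedged addition, not part of the original gist.
# Confirm where the 8-bit weights were placed and how many parameters will
# actually train; the percentage should match print_trainable_parameters() above.
first_device = next(model.parameters()).device
print(f"model weights are on: {first_device}")
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
print(f"trainable: {n_trainable:,} / {n_total:,} ({100 * n_trainable / n_total:.2f}%)")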
#%%
import pytorch_lightning as pl

class LitAutoEncoder(pl.LightningModule):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop; the forward pass returns a
        # ModelOutput, so take its .loss attribute rather than the whole object.
        out = self.base_model(**batch, labels=batch["input_ids"])
        return out.loss

    def validation_step(self, batch, batch_idx):
        # validation_step defines the eval loop; same forward pass as training.
        out = self.base_model(**batch, labels=batch["input_ids"])
        return out.loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

pl_model = LitAutoEncoder(model)
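#%% (Optional) a minimal sketch of an alternative configure_optimizers, assuming
# you only want optimizer state for the LoRA adapter weights. PEFT freezes the
# base model, so filtering on requires_grad leaves just the adapters.
def lora_only_adam(module, lr=1e-3):
    trainable = [p for p in module.parameters() if p.requires_grad]
    return torch.optim.Adam(trainable, lr=lr)

# usage (hypothetical helper, not called by the script below):
# optimizer = lora_only_adam(pl_model)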
#%% Load Data
from datasets import load_dataset
from torch.utils.data import DataLoader

class Dataset(torch.utils.data.Dataset):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.dataset = load_dataset("truthful_qa", "generation")["validation"]

    def __getitem__(self, idx):
        text = self.dataset[idx]["best_answer"]
        encoded = self.tokenizer(
            text, return_tensors="pt", padding="max_length", max_length=10, truncation=True
        )
        # return_tensors="pt" adds a leading batch dimension of 1; squeeze it out
        # so the DataLoader can stack items into (batch_size, max_length) tensors.
        return {k: v.squeeze(0) for k, v in encoded.items()}

    def __len__(self):
        return len(self.dataset)

ds = Dataset(tokenizer)
dl = DataLoader(ds, batch_size=2)
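#%% (Optional) shape check -- a hedged addition to verify the squeeze in
# __getitem__ above: each batch value should be (batch_size, max_length).
sample = next(iter(dl))
print({k: tuple(v.shape) for k, v in sample.items()})  # expect (2, 10) per key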
#%%
batch = next(iter(dl))
pl_model.training_step(batch=batch, batch_idx=1)
# %%
# Initialize a trainer
from pytorch_lightning import Trainer

trainer = Trainer(devices=2, max_epochs=1)
trainer.fit(pl_model, dl)
# %%
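# %% (Note) A hedged guess at the failure, not a confirmed diagnosis: with
# devices=2, Lightning defaults to DDP, which replicates the module across
# GPUs, and bitsandbytes 8-bit weights generally cannot be moved or copied
# that way. A single-device run gives a sanity baseline:
trainer_single = Trainer(devices=1, max_epochs=1)
# trainer_single.fit(pl_model, dl)  # uncomment to try the single-GPU baseline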