PEFT training example code
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import transformers
from peft import (
    LoraConfig,
    PeftConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_int8_training,
)
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
    Trainer,
)
import torch
from dataclasses import dataclass, field
from typing import Optional
from dataclass_csv import DataclassReader
from torch.utils.data import Dataset, DataLoader
from enum import Enum
# Special tokens that delimit the structured parts of each task-oriented
# dialogue example (context, user/system turns, dialogue states, response).
class SpecialTokens(str, Enum):
    begin_target = "<|begintarget|>"
    end_target = "<|endtarget|>"
    begin_context = "<|begincontext|>"
    end_context = "<|endcontext|>"
    system = "<|system|>"
    user = "<|user|>"
    begin_last_user_utterance = "<|beginlastuserutterance|>"
    end_last_user_utterance = "<|endlastuserutterance|>"
    begin_dsts = "<|begindsts|>"
    end_dsts = "<|enddsts|>"
    begin_dst = "<|begindst|>"
    end_dst = "<|enddst|>"
    begin_belief = "<|beginbelief|>"
    end_belief = "<|endbelief|>"
    begin_response = "<|beginresponse|>"
    end_response = "<|endresponse|>"
    begin_action = "<|beginaction|>"
    end_action = "<|endaction|>"
    begin_user_action = "<|beginuseraction|>"
    end_user_action = "<|enduseraction|>"
    sys_actions = "<|sysactions|>"
    begin_intent = "<|beginintent|>"
    end_intent = "<|endintent|>"
    begin_requested_slots = "<|beginrequestedslots|>"
    end_requested_slots = "<|endrequestedslots|>"
    pad_token = "<|pad|>"
    bos_token = "<|startoftext|>"

    @classmethod
    def list(cls):
        return [c.value for c in cls]
import urllib
import csv
import codecs

# model_name = "EleutherAI/gpt-j-6B"
model_name = "aleksickx/llama-7b-hf"

# Tokenizer with the dialogue special tokens registered.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    pad_token=SpecialTokens.pad_token.value,
    bos_token=SpecialTokens.bos_token.value,
    eos_token=SpecialTokens.end_target.value,
    additional_special_tokens=SpecialTokens.list(),
)
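# Note: eos_token is set to <|endtarget|>, so generation below stops once the
# model finishes emitting the target sequence.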
@dataclass
class DataModel:
    dialog_id: str
    turn_id: str
    context: str
    target: str


class TodDataSet(Dataset):
    def __init__(
        self,
        data,
    ):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
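# data.csv is expected to contain a header row with columns matching the
# DataModel fields (dialog_id, turn_id, context, target); DataclassReader
# maps each CSV row onto a DataModel instance.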
path = "data.csv"
# path = "https://gist.githubusercontent.com/adibMosharrof/d3ff320381d4aeae9b57833500a58536/raw/ccf65c35202df249cc0e01c1d01c9384470de2d0/data"
# import requests
# res = urllib.request.urlopen(path)
# data = list(csv.reader(codecs.iterdecode(res, "utf-8")))

with open(path) as f:
    reader = DataclassReader(f, DataModel)
    data = [r for r in reader]
# data = data[1:100]
# print(data)

# 80/20 train/test split
split = int(len(data) * 0.8)
train_dataset = TodDataSet(data[:split])
test_dataset = TodDataSet(data[split:])
def tokenize_text(text):
    return tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    )


def collate_fn(batch):
    # Training batches: context + target concatenated; labels = input_ids (causal LM).
    input_tokens = []
    mask = []
    for item in batch:
        row = tokenize_text(item.context + item.target)
        input_tokens.append(row["input_ids"][0])
        mask.append(row["attention_mask"][0])
    return {
        "input_ids": torch.stack(input_tokens),
        "attention_mask": torch.stack(mask),
        "labels": torch.stack(input_tokens),
    }


def test_collate_fn(batch):
    # Evaluation batches: context only, so the model has to generate the target.
    input_tokens = []
    mask = []
    target = []
    for item in batch:
        row = tokenize_text(item.context)
        input_tokens.append(row["input_ids"][0])
        mask.append(row["attention_mask"][0])
        target.append(item.target)
    return {
        "input_ids": torch.stack(input_tokens),
        "attention_mask": torch.stack(mask),
        # "target_txt": target,
    }


# These dataloaders are helpers for manual evaluation; the Trainer below builds
# its own dataloader from train_dataset and collate_fn.
def test_dataloader():
    return DataLoader(
        test_dataset, batch_size=10, collate_fn=test_collate_fn, pin_memory=True
    )


def train_dataloader():
    return DataLoader(
        train_dataset, batch_size=10, collate_fn=collate_fn, pin_memory=True
    )
# Load the base model in 8-bit and resize embeddings for the new special tokens.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True,
    device_map="auto",
    # model_name
)
model.resize_token_embeddings(len(tokenizer))
model.enable_input_require_grads()
model.gradient_checkpointing_enable()
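# prepare_model_for_int8_training is imported above but not called here; it is
# typically applied to the 8-bit model before adding LoRA adapters, e.g.
# model = prepare_model_for_int8_training(model)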
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    base_model_name_or_path=model_name,
    # "wte" is the GPT-J/GPT-2 embedding module; for LLaMA models the
    # corresponding module is usually named "embed_tokens".
    modules_to_save=["wte", "lm_head"],
)
model = get_peft_model(model, config)
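# Optionally inspect what fraction of parameters the LoRA adapters train:
# model.print_trainable_parameters()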
train_batch_size = 1  # unused; per_device_train_batch_size below controls the batch size

training_args = TrainingArguments(
    output_dir="output",
    num_train_epochs=1,
    save_total_limit=5,
    per_device_train_batch_size=8,
    warmup_steps=100,
    weight_decay=0.01,
    dataloader_drop_last=True,
    # fp16=True,
    logging_steps=5,
    learning_rate=5e-4,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=collate_fn,
)
# use_cache is incompatible with gradient checkpointing during training.
model.config.use_cache = False
trainer.train()
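# The trained LoRA adapter could be saved at this point, e.g.:
# model.save_pretrained("output/adapter")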
# Quick generation check on an example dialogue context.
batch = tokenizer(
    "<|begincontext|><|user|>I am feeling hungry so I would like to find a place to eat.<|system|>Do you have a specific which you want the eating place to be located at?<|user|>I would like for it to be in San Jose.<|system|>Is there a specific cuisine type you enjoy, such as Mexican, Italian or something else?<|beginlastuserutterance|>I usually like eating the American type of food.<|endlastuserutterance|><|endcontext|>",
    return_tensors="pt",
)
batch = {k: v.to("cuda") for k, v in batch.items()}
model.eval()
output_tokens = model.generate(
    **batch,
    max_new_tokens=256,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)
print("\n\n", tokenizer.decode(output_tokens[0], skip_special_tokens=False))