@billju
Created September 23, 2023 08:53
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer

# Tokenizer: reuse token id 0 as the pad token
tokenizer = AutoTokenizer.from_pretrained('TinyPixel/Llama-2-7B-bf16-sharded')
tokenizer.pad_token_id = 0
tokenizer.padding_side = 'left'

# Chinese Alpaca-GPT4 instruction data; build the 'text' field that SFTTrainer
# reads via dataset_text_field (the prompt template below is an assumed
# Alpaca-style format, adjust it to your data)
train_dataset = load_dataset('json', data_files='alpaca_gpt4_data_zh.json', split='train')
train_dataset = train_dataset.map(lambda x: {
    'text': f"### Instruction:\n{x['instruction']}\n\n### Input:\n{x['input']}\n\n### Response:\n{x['output']}"
})

# Load the base model in 4-bit and prepare it for k-bit (QLoRA-style) training
model = AutoModelForCausalLM.from_pretrained(
    'TinyPixel/Llama-2-7B-bf16-sharded',
    device_map='auto',
    load_in_4bit=True,
    torch_dtype=torch.float16,
)
model.resize_token_embeddings(len(tokenizer))
model = prepare_model_for_kbit_training(model)

# Attach LoRA adapters to the attention query/value projections
model = get_peft_model(model, peft_config=LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=['q_proj', 'v_proj'],
    lora_dropout=0.05,
    bias='none',
    task_type='CAUSAL_LM',
))
model.print_trainable_parameters()

# Supervised fine-tuning on packed 1024-token sequences
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=1024,
    tokenizer=tokenizer,
    packing=True,
)
trainer.train()
trainer.save_model('lora')
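
# Not part of the original gist: a minimal sketch of loading the saved adapter for
# inference. It assumes the adapter directory 'lora' written by trainer.save_model
# above and reuses the tokenizer from the training script; the prompt template is an
# example and should match whatever format was used for training.
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    'TinyPixel/Llama-2-7B-bf16-sharded',
    device_map='auto',
    load_in_4bit=True,
    torch_dtype=torch.float16,
)
inference_model = PeftModel.from_pretrained(base, 'lora')
prompt = '### Instruction:\n写一首关于春天的诗\n\n### Input:\n\n\n### Response:\n'
inputs = tokenizer(prompt, return_tensors='pt').to(inference_model.device)
output = inference_model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(output[0], skip_special_tokens=True))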