Skip to content

Instantly share code, notes, and snippets.

@CoffeeVampir3
Last active June 28, 2023 05:12
Show Gist options
  • Save CoffeeVampir3/4555447a867f003ac2ac57afafe67bba to your computer and use it in GitHub Desktop.
Save CoffeeVampir3/4555447a867f003ac2ac57afafe67bba to your computer and use it in GitHub Desktop.
train lora notebook
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "aa00aabc-a8ca-4540-b312-83f012313c8b",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling\n",
"import math"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f3b7b26-33ea-443a-b062-4be7f31587f6",
"metadata": {},
"outputs": [],
"source": [
"# Location of the base checkpoint; the next cell passes it to from_pretrained\n",
"# for both the tokenizer and the model.\n",
"model_path = \"models/airoboros-13b-gpt4-1.4\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "55fe0c42-4c7c-47d9-96ed-c308ccde2130",
"metadata": {},
"outputs": [],
"source": [
"# Load the tokenizer that belongs to the base checkpoint.\n",
"tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
"\n",
"# The language-modeling data collator pads every batch through tokenizer.pad,\n",
"# which raises if no pad token is defined (LLaMA-family tokenizers often ship\n",
"# without one). Prefer <unk> over EOS as the fallback: the collator masks pad\n",
"# positions out of the labels, and reusing EOS would also hide real\n",
"# end-of-sequence tokens from the loss.\n",
"if tokenizer.pad_token is None:\n",
"    tokenizer.pad_token = tokenizer.unk_token or tokenizer.eos_token\n",
"\n",
"# Load the base model with 4-bit quantized weights; device_map=\"auto\" delegates\n",
"# layer placement to the library.\n",
"model = AutoModelForCausalLM.from_pretrained(model_path, device_map=\"auto\", load_in_4bit=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f9128be-eebd-4883-ba41-5e2c9bff3a5f",
"metadata": {},
"outputs": [],
"source": [
"from peft import (\n",
" LoraConfig,\n",
" get_peft_model,\n",
" prepare_model_for_kbit_training,\n",
" set_peft_model_state_dict\n",
")\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "35322396-81ee-4ae4-b227-f395e1ed11af",
"metadata": {},
"outputs": [],
"source": [
"# LoRA hyper-parameters: adapter rank, alpha scaling value and adapter dropout.\n",
"lora_rank = 32\n",
"lora_alpha = 64\n",
"lora_dropout = 0.05\n",
"\n",
"# Adapters are attached only to the modules named q_proj and v_proj;\n",
"# bias=\"none\" means no bias parameters are trained.\n",
"config = LoraConfig(\n",
"    task_type=\"CAUSAL_LM\",\n",
"    target_modules=[\"q_proj\", \"v_proj\"],\n",
"    r=lora_rank,\n",
"    lora_alpha=lora_alpha,\n",
"    lora_dropout=lora_dropout,\n",
"    bias=\"none\",\n",
")\n",
"\n",
"# Order matters here: enable gradient checkpointing on the quantized base model,\n",
"# prepare it for k-bit training, and only then wrap it with the LoRA adapters.\n",
"model.gradient_checkpointing_enable()\n",
"model = prepare_model_for_kbit_training(model)\n",
"lora_model = get_peft_model(model, config)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cef5317c-2db2-4292-99c6-1c91725087bb",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"\n",
"# Upper bound on tokens per sample. Without truncation a single over-long sample\n",
"# is fed to the model whole, which can exhaust memory and exceeds the positions\n",
"# the base model supports. 2048 is only a fallback for configs that do not\n",
"# expose max_position_embeddings.\n",
"max_seq_length = getattr(model.config, \"max_position_embeddings\", 2048)\n",
"\n",
"# data_files is given without a split name, so the result is indexed by the\n",
"# \"train\" split later on.\n",
"dataset = load_dataset('json', data_files='data/cleaned.json')\n",
"\n",
"# Tokenize the \"text\" column batch-wise, cutting each sample at max_seq_length.\n",
"data = dataset.map(\n",
"    lambda samples: tokenizer(samples[\"text\"], truncation=True, max_length=max_seq_length),\n",
"    batched=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a626ef51-cff1-459b-9daf-c19ba97d2d60",
"metadata": {},
"outputs": [],
"source": [
"from transformers import IntervalStrategy\n",
"import os\n",
"\n",
"# Directory that receives the periodic checkpoints written by the Trainer.\n",
"output_dir = \"out\"\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"\n",
"# Each optimizer step accumulates batch_size sequences per device, processed\n",
"# micro_batch_size at a time. max(1, ...) keeps the step count valid should\n",
"# micro_batch_size ever be set larger than batch_size.\n",
"micro_batch_size = 16\n",
"batch_size = 256\n",
"gradient_accumulation_steps = max(1, batch_size // micro_batch_size)\n",
"\n",
"warmup_steps = 7\n",
"epochs = 1\n",
"actual_lr = 6e-4\n",
"lr_scheduler_type = 'cosine_with_restarts'\n",
"\n",
"# No evaluation set is used, so evaluation stays at the TrainingArguments default\n",
"# (disabled). The evaluation_strategy keyword is deliberately not passed: newer\n",
"# transformers releases renamed it to eval_strategy, and \"no\" is already the default.\n",
"training_args = TrainingArguments(\n",
"    output_dir=output_dir,\n",
"    # Checkpoint every 30 steps and keep at most 5 checkpoints on disk.\n",
"    save_strategy=IntervalStrategy.STEPS,\n",
"    save_steps=30,\n",
"    save_total_limit=5,\n",
"    per_device_train_batch_size=micro_batch_size,\n",
"    gradient_accumulation_steps=gradient_accumulation_steps,\n",
"    num_train_epochs=epochs,\n",
"    learning_rate=actual_lr,\n",
"    lr_scheduler_type=lr_scheduler_type,\n",
"    warmup_steps=warmup_steps,\n",
"    fp16=True,\n",
"    optim='adamw_bnb_8bit',\n",
"    logging_steps=5,\n",
"    ddp_find_unused_parameters=None,\n",
")\n",
"\n",
"# mlm=False selects causal (next-token) language modeling instead of masked LM.\n",
"trainer = Trainer(\n",
"    model=lora_model,\n",
"    args=training_args,\n",
"    train_dataset=data[\"train\"],\n",
"    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c7427144-098a-4605-ab0a-32d18c4f0d0b",
"metadata": {},
"outputs": [],
"source": [
"# Run the fine-tuning loop with the arguments configured on `trainer`;\n",
"# checkpoints are written to the configured output directory along the way.\n",
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "de59b407-70cb-4bd1-96ef-8c0a79338aa0",
"metadata": {},
"outputs": [],
"source": [
"# Persist the trained model held by the trainer into the \"out\" directory.\n",
"# NOTE(review): trainer.model is the PEFT-wrapped model, so this presumably\n",
"# writes only the LoRA adapter rather than the full base weights — confirm.\n",
"trainer.model.save_pretrained(\"out\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment