train lora notebook
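Fine-tunes a causal language model loaded in 4-bit with a LoRA adapter (peft), using the Hugging Face Trainer on a local JSON text dataset.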
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa00aabc-a8ca-4540-b312-83f012313c8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f3b7b26-33ea-443a-b062-4be7f31587f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = \"models/airoboros-13b-gpt4-1.4\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55fe0c42-4c7c-47d9-96ed-c308ccde2130",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "# load_in_4bit quantizes the base weights via bitsandbytes; device_map=\"auto\" places layers across available devices\n",
    "model = AutoModelForCausalLM.from_pretrained(model_path, device_map=\"auto\", load_in_4bit=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f9128be-eebd-4883-ba41-5e2c9bff3a5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import (\n",
    "    LoraConfig,\n",
    "    get_peft_model,\n",
    "    prepare_model_for_kbit_training\n",
    ")\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35322396-81ee-4ae4-b227-f395e1ed11af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# r is the adapter rank; lora_alpha scales the adapter's contribution (alpha / r is the effective scaling factor)\n",
    "lora_rank = 32\n",
    "lora_alpha = 64\n",
    "lora_dropout = 0.05\n",
    "\n",
    "config = LoraConfig(\n",
    "    r=lora_rank,\n",
    "    lora_alpha=lora_alpha,\n",
    "    target_modules=[\"q_proj\", \"v_proj\"],  # attach adapters to the attention query/value projections\n",
    "    lora_dropout=lora_dropout,\n",
    "    bias=\"none\",\n",
    "    task_type=\"CAUSAL_LM\"\n",
    ")\n",
    "\n",
    "model.gradient_checkpointing_enable()\n",
    "model = prepare_model_for_kbit_training(model)  # readies the 4-bit model for training (e.g. casts norms to fp32)\n",
    "lora_model = get_peft_model(model, config)"
   ]
  },
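  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0e1f2a3-1111-4111-8111-0123456789ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional check (not in the original notebook): peft's built-in summary of\n",
    "# how many parameters the adapter trains versus the frozen base model.\n",
    "lora_model.print_trainable_parameters()"
   ]
  },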
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cef5317c-2db2-4292-99c6-1c91725087bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# expects a JSON file whose records carry a \"text\" field\n",
    "dataset = load_dataset('json', data_files='data/cleaned.json')\n",
    "\n",
    "data = dataset.map(lambda samples: tokenizer(samples[\"text\"]), batched=True)"
   ]
  },
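  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1d2e3f4-2222-4222-8222-0123456789ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (not in the original notebook): confirm the map step\n",
    "# added token ids alongside the raw text before starting a long training run.\n",
    "data[\"train\"][0][\"input_ids\"][:16]"
   ]
  },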
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a626ef51-cff1-459b-9daf-c19ba97d2d60",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import IntervalStrategy\n",
    "import os\n",
    "\n",
    "os.makedirs(\"out\", exist_ok=True)\n",
    "# effective batch size = micro_batch_size * gradient_accumulation_steps = 256\n",
    "micro_batch_size = 16\n",
    "batch_size = 256\n",
    "gradient_accumulation_steps = batch_size // micro_batch_size\n",
    "warmup_steps = 7\n",
    "eval_steps = 100\n",
    "epochs = 1\n",
    "actual_lr = 6e-4\n",
    "lr_scheduler_type = 'cosine_with_restarts'\n",
    "trainer = Trainer(\n",
    "    model=lora_model,\n",
    "    train_dataset=data[\"train\"],\n",
    "    #eval_dataset=tokenized_datasets,\n",
    "    args=TrainingArguments(\n",
    "        save_strategy=IntervalStrategy.STEPS,\n",
    "        save_steps=30,\n",
    "        save_total_limit=5,\n",
    "        per_device_train_batch_size=micro_batch_size,\n",
    "        gradient_accumulation_steps=gradient_accumulation_steps,\n",
    "        warmup_steps=warmup_steps,\n",
    "        num_train_epochs=epochs,\n",
    "        learning_rate=actual_lr,\n",
    "        fp16=True,\n",
    "        optim='adamw_bnb_8bit',  # 8-bit AdamW from bitsandbytes to cut optimizer memory\n",
    "        logging_steps=5,\n",
    "        evaluation_strategy=\"no\",  # evaluation disabled; re-enable along with an eval_dataset above\n",
    "        #eval_steps=math.ceil(eval_steps / gradient_accumulation_steps),\n",
    "        lr_scheduler_type=lr_scheduler_type,\n",
    "        ddp_find_unused_parameters=None,\n",
    "        output_dir=\"out\",\n",
    "    ),\n",
    "    # mlm=False makes the collator copy input_ids into labels for causal-LM training\n",
    "    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7427144-098a-4605-ab0a-32d18c4f0d0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de59b407-70cb-4bd1-96ef-8c0a79338aa0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# saves only the LoRA adapter weights and config to \"out\", not the full base model\n",
    "trainer.model.save_pretrained(\"out\")"
   ]
  },
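  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4e5f6a7-3333-4333-8333-0123456789ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal reload sketch (not in the original notebook), assuming the base\n",
    "# model still fits in memory: peft's PeftModel.from_pretrained attaches the\n",
    "# adapter saved in \"out\" back onto a freshly loaded base model for inference.\n",
    "from peft import PeftModel\n",
    "\n",
    "base = AutoModelForCausalLM.from_pretrained(model_path, device_map=\"auto\", load_in_4bit=True)\n",
    "inference_model = PeftModel.from_pretrained(base, \"out\")"
   ]
  }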
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}