Last active
September 14, 2023 12:03
-
-
Save kishida/c20f82dad6732336cb21b575ad91b987 to your computer and use it in GitHub Desktop.
lora-line1.7b-dolly
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"gpuType": "T4", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/kishida/53855070e830b2f379b7cab211ae23a4/lora-ipynb.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"このnotebookは`line-corporation/japanese-large-lm-1.7b`のモデルを`kunishou/databricks-dolly-15k-ja`のデータセットを用いてLoRA tuningするためのコードの例です。以下の例では、学習を1 epochを行います。T4 GPUで実行すると30分ほどかかります。\n", | |
"\n", | |
"- モデル:https://huggingface.co/line-corporation/japanese-large-lm-1.7b\n", | |
        "- データ:https://github.com/kunishou/databricks-dolly-15k-ja\n",
"\n", | |
"\n", | |
"また、ここで用いている設定は暫定的なもので、必要に応じて調整してください。\n", | |
"\n", | |
"stockmark/gpt-neox-japanese-1.4b LoRAのnotebookを変更したものです。 \n", | |
"https://huggingface.co/stockmark/gpt-neox-japanese-1.4b/blob/main/notebooks/LoRA.ipynb" | |
], | |
"metadata": { | |
"id": "BPGgCZtMdMsv" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# ライブラリのインストール" | |
], | |
"metadata": { | |
"id": "hCZH9e6EcZyj" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "cmn52bx3v5Ha" | |
}, | |
"outputs": [], | |
"source": [ | |
"!python3 -m pip install -U pip\n", | |
"!python3 -m pip install transformers accelerate datasets peft sentencepiece" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# 準備" | |
], | |
"metadata": { | |
"id": "4t3Cqs9_ce3J" | |
} | |
}, | |
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import datasets\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments\n",
        "from peft import get_peft_model, LoraConfig, TaskType, PeftModel, PeftConfig\n",
        "\n",
        "model_name = \"line-corporation/japanese-large-lm-1.7b\"\n",
        "peft_model_name = \"peft_model\"  # directory the trained LoRA adapter is saved to / loaded from\n",
        "\n",
        "# Chat-style prompt used when a sample provides extra context (`input`).\n",
        "prompt_template_cqa = \"\"\"ユーザー: 次の情報を元に質問に答えてください。{input}\n",
        "システム: わかりました。\n",
        "ユーザー: {instruction}\n",
        "システム: \"\"\"\n",
        "# Prompt used for open questions that carry no extra context.\n",
        "prompt_template_oqa = \"\"\"ユーザー: {instruction}\n",
        "システム: \"\"\"\n",
        "\n",
        "def encode(sample):\n",
        "    \"\"\"Tokenize one dolly-style sample into input_ids/labels for causal-LM tuning.\n",
        "\n",
        "    NOTE: relies on the global `tokenizer` that is created in a later cell,\n",
        "    so this function must only be called after that cell has run.\n",
        "    \"\"\"\n",
        "    # Pick the template depending on whether the sample carries context.\n",
        "    if (sample[\"input\"]):\n",
        "        prompt = prompt_template_cqa.format(instruction=sample[\"instruction\"], input=sample[\"input\"])\n",
        "    else:\n",
        "        prompt = prompt_template_oqa.format(instruction=sample[\"instruction\"])\n",
        "    target = sample[\"output\"] + tokenizer.eos_token\n",
        "    # Tokenize prompt and target separately so the prompt length is known.\n",
        "    input_ids_prompt, input_ids_target = tokenizer([prompt, target]).input_ids\n",
        "    input_ids = input_ids_prompt + input_ids_target\n",
        "    labels = input_ids.copy()\n",
        "    # Mask the prompt tokens with -100 so the loss is computed on the answer only.\n",
        "    labels[:len(input_ids_prompt)] = [-100] * len(input_ids_prompt)\n",
        "    return {\"input_ids\": input_ids, \"labels\": labels}\n",
        "\n",
        "def get_collator(tokenizer, max_length):\n",
        "    \"\"\"Return a collate function that truncates, pads and tensorizes a batch.\"\"\"\n",
        "    def collator(batch):\n",
        "        # Truncate every field of every sample to at most max_length tokens.\n",
        "        batch = [{ key: value[:max_length] for key, value in sample.items() } for sample in batch ]\n",
        "        # tokenizer.pad() pads input_ids/attention_mask to a uniform length\n",
        "        # but leaves `labels` untouched, hence the manual padding below.\n",
        "        batch = tokenizer.pad(batch, padding=True)\n",
        "        # Pad labels to the (now uniform) input length with the ignore index -100.\n",
        "        batch[\"labels\"] = [ e + [-100] * (len(batch[\"input_ids\"][0]) - len(e)) for e in batch[\"labels\"] ]\n",
        "        batch = { key: torch.tensor(value) for key, value in batch.items() }\n",
        "        return batch\n",
        "\n",
        "    return collator\n"
      ],
      "metadata": {
        "id": "hNdYMGMRzAVn"
      },
      "execution_count": null,
      "outputs": []
    },
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# データセットとモデルの準備\n" | |
], | |
"metadata": { | |
"id": "UqXxPjJ_cliu" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# prepare dataset\n", | |
"tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)\n", | |
"\n", | |
"dataset_name = \"kunishou/databricks-dolly-15k-ja\"\n", | |
"dataset = datasets.load_dataset(dataset_name)\n", | |
"dataset = dataset.map(encode)\n", | |
"dataset = dataset[\"train\"].train_test_split(0.2)\n", | |
"train_dataset = dataset[\"train\"]\n", | |
"val_dataset = dataset[\"test\"]\n", | |
"\n", | |
"# load model\n", | |
"base_model = AutoModelForCausalLM.from_pretrained(model_name, device_map={\"\": 0}, torch_dtype=torch.float16)\n", | |
"\n", | |
"peft_config = LoraConfig(\n", | |
" task_type=TaskType.CAUSAL_LM,\n", | |
" inference_mode=False,\n", | |
" target_modules=[\"c_attn\"],\n", | |
" r=16,\n", | |
" lora_alpha=32,\n", | |
" lora_dropout=0.05\n", | |
")\n", | |
"\n", | |
"model = get_peft_model(base_model, peft_config)\n", | |
"model.print_trainable_parameters()" | |
], | |
"metadata": { | |
"id": "ZWdN-p7t0Grk" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# LoRA tuning" | |
], | |
"metadata": { | |
"id": "XCrdVAJYc88c" | |
} | |
}, | |
    {
      "cell_type": "code",
      "source": [
        "# Training configuration. Effective train batch size is\n",
        "# per_device_train_batch_size * gradient_accumulation_steps = 16.\n",
        "training_args = TrainingArguments(\n",
        "    output_dir=\"./train_results\",\n",
        "    learning_rate=2e-4,\n",
        "    per_device_train_batch_size=4,\n",
        "    gradient_accumulation_steps=4,\n",
        "    per_device_eval_batch_size=16,\n",
        "    num_train_epochs=1,\n",
        "    logging_strategy='steps',\n",
        "    logging_steps=10,\n",
        "    save_strategy='epoch',\n",
        "    # NOTE(review): this argument was renamed to `eval_strategy` in recent\n",
        "    # transformers releases — keep `evaluation_strategy` only on older versions.\n",
        "    evaluation_strategy='epoch',\n",
        "    # Keep the epoch checkpoint with the lowest eval_loss as the final model.\n",
        "    load_best_model_at_end=True,\n",
        "    metric_for_best_model=\"eval_loss\",\n",
        "    greater_is_better=False,\n",
        "    save_total_limit=2\n",
        ")\n",
        "\n",
        "trainer = Trainer(\n",
        "    model=model,\n",
        "    args=training_args,\n",
        "    train_dataset=train_dataset,\n",
        "    eval_dataset=val_dataset,\n",
        "    # Custom collator from the preparation cell: truncates to 512 tokens,\n",
        "    # pads inputs, and pads labels with -100 so padding is ignored by the loss.\n",
        "    data_collator=get_collator(tokenizer, 512)\n",
        ")\n",
        "\n",
        "trainer.train()\n",
        "model = trainer.model\n",
        "# Save only the (small) LoRA adapter weights, not the full base model.\n",
        "model.save_pretrained(peft_model_name)"
      ],
      "metadata": {
        "id": "4LH9tOCTJVk1"
      },
      "execution_count": null,
      "outputs": []
    },
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# 学習したモデルのロード" | |
], | |
"metadata": { | |
"id": "ORgzOPAqdEZR" | |
} | |
}, | |
    {
      "cell_type": "code",
      "source": [
        "# Re-attach the saved LoRA adapter to the base model.\n",
        "# NOTE(review): in this notebook's flow `base_model` has already had adapter\n",
        "# modules injected by get_peft_model() above; when running this cell after a\n",
        "# kernel restart, re-create `base_model` first — confirm before relying on it.\n",
        "model = PeftModel.from_pretrained(base_model, peft_model_name)"
      ],
      "metadata": {
        "id": "yrExyO9EOvzR"
      },
      "execution_count": null,
      "outputs": []
    },
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# 推論" | |
], | |
"metadata": { | |
"id": "-dttR6tkdG0k" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"prompt = prompt_template_oqa.format(instruction=\"日本で人気のスポーツは?\")\n", | |
"\n", | |
"inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n", | |
"with torch.no_grad():\n", | |
" tokens = model.generate(\n", | |
" **inputs,\n", | |
" max_new_tokens=128,\n", | |
" repetition_penalty=1.1\n", | |
" )\n", | |
"\n", | |
"output = tokenizer.decode(tokens[0], skip_special_tokens=True)\n", | |
"print(output)" | |
], | |
"metadata": { | |
"id": "pC5t9F1GJuFN" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [], | |
"metadata": { | |
"id": "pEM50stxfffV" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment