Skip to content

Instantly share code, notes, and snippets.

@bathtimefish
Last active July 17, 2023 04:38
Show Gist options
  • Save bathtimefish/c168e8795e29a6c9c0fa94142620a57f to your computer and use it in GitHub Desktop.
Save bathtimefish/c168e8795e29a6c9c0fa94142620a57f to your computer and use it in GitHub Desktop.
Fine-tuning rinna 3.6B, a Japanese LLM, on WSL Ubuntu 22.04 LTS with CUDA 12.1 on Windows 11
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "11e15e35-cbd7-487f-9f73-baefff77bccb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting sentencepiece\n",
" Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m41.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting bitsandbytes\n",
" Obtaining dependency information for bitsandbytes from https://files.pythonhosted.org/packages/e5/66/24709e338106bd979756d0300737ad60fad4a784e8d521ce374c2f2b3bd8/bitsandbytes-0.40.1.post1-py3-none-any.whl.metadata\n",
" Downloading bitsandbytes-0.40.1.post1-py3-none-any.whl.metadata (9.8 kB)\n",
"Downloading bitsandbytes-0.40.1.post1-py3-none-any.whl (93.3 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m93.3/93.3 MB\u001b[0m \u001b[31m55.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n",
"\u001b[?25hInstalling collected packages: sentencepiece, bitsandbytes\n",
"Successfully installed bitsandbytes-0.40.1.post1 sentencepiece-0.1.99\n",
"Collecting scipy\n",
" Obtaining dependency information for scipy from https://files.pythonhosted.org/packages/14/f2/10fa23f0a6b9b2439c01579ae4a9b1849d4822e972515c8f92584bfda5e9/scipy-1.11.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata\n",
" Downloading scipy-1.11.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (59 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.1/59.1 kB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: numpy<1.28.0,>=1.21.6 in /home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages (from scipy) (1.25.1)\n",
"Downloading scipy-1.11.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.3 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.3/36.3 MB\u001b[0m \u001b[31m82.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n",
"\u001b[?25hInstalling collected packages: scipy\n",
"Successfully installed scipy-1.11.1\n"
]
}
],
"source": [
"# Install dependencies.\n",
"# `%pip` installs into the environment of the running kernel, whereas `!pip` runs\n",
"# whichever pip executable happens to be first on PATH.\n",
"%pip install -Uqq git+https://github.com/huggingface/peft.git\n",
"%pip install -Uqq transformers datasets accelerate\n",
"%pip install sentencepiece bitsandbytes\n",
"%pip install scipy"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3cc7e512-20b7-4afc-b282-3ffec66356f9",
"metadata": {},
"outputs": [],
"source": [
"# Basic parameters\n",
"# Model id handed to AutoTokenizer / AutoModelForCausalLM.from_pretrained\n",
"model_name = \"rinna/japanese-gpt-neox-3.6b\"\n",
"# Dataset id handed to datasets.load_dataset\n",
"dataset = \"kunishou/databricks-dolly-15k-ja\"\n",
"# Path the trained LoRA adapter is written to by save_pretrained\n",
"peft_name = \"lora-rinna-3.6b\"\n",
"# Passed to TrainingArguments as output_dir\n",
"output_dir = \"lora-rinna-3.6b-dolly-15k-ja-results\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "65fcd45e-40cc-435e-95ca-055400e211da",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bf5d7b0268234718a87265e4f12e13b4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading (…)okenizer_config.json: 0%| | 0.00/284 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "333a315a6a374604af2f47625fd899ac",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading spiece.model: 0%| | 0.00/786k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"# Prepare the tokenizer for `model_name`; use_fast=False asks for the non-fast tokenizer implementation\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "7c421353-9732-4b82-8e4d-3e498f29eaf5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}\n",
"bos_token : <s> , 2\n",
"eos_token : </s> , 3\n",
"unk_token : [UNK] , 1\n",
"pad_token : [PAD] , 0\n"
]
}
],
"source": [
"# Inspect the special tokens: the full map first, then one \"name : token , id\" line each\n",
"print(tokenizer.special_tokens_map)\n",
"print(\n",
"    *(\n",
"        f\"{attr} : {getattr(tokenizer, attr)} , {getattr(tokenizer, attr + '_id')}\"\n",
"        for attr in (\"bos_token\", \"eos_token\", \"unk_token\", \"pad_token\")\n",
"    ),\n",
"    sep=\"\\n\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "27a49a10-6cd1-417f-a7c2-464bf81a759b",
"metadata": {},
"outputs": [],
"source": [
"CUTOFF_LEN = 256  # context length: passed to the tokenizer as max_length\n",
"\n",
"\n",
"def tokenize(prompt, tokenizer):\n",
"    \"\"\"Tokenize ``prompt`` with truncation at ``CUTOFF_LEN`` and no padding.\n",
"\n",
"    Returns a dict holding only the ``input_ids`` and ``attention_mask``\n",
"    entries of the tokenizer output.\n",
"    \"\"\"\n",
"    encoded = tokenizer(prompt, truncation=True, max_length=CUTOFF_LEN, padding=False)\n",
"    return {field: encoded[field] for field in (\"input_ids\", \"attention_mask\")}"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "b9c649c1-7481-4614-89bd-743278feb194",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': [3201, 634, 1304, 3], 'attention_mask': [1, 1, 1, 1]}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sanity-check tokenize() on a short sample string\n",
"tokenize(\"hi there\", tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a170df55-7bcc-48df-843a-7e1c25104898",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6d8c859662714bb7965f2ffb4c93d01b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading readme: 0%| | 0.00/355 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading and preparing dataset json/kunishou--databricks-dolly-15k-ja to /home/btf/.cache/huggingface/datasets/kunishou___json/kunishou--databricks-dolly-15k-ja-e353786ca55017da/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96...\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "870ebc23a23b4fc3bc682b41fe70ae32",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data files: 0%| | 0/1 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2704c7b4121645c389e457ad3b7bfed3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0.00/17.1M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d3849f15e5fa4fefb4776ca59b6cbc22",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Extracting data files: 0%| | 0/1 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating train split: 0 examples [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset json downloaded and prepared to /home/btf/.cache/huggingface/datasets/kunishou___json/kunishou--databricks-dolly-15k-ja-e353786ca55017da/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96. Subsequent calls will reuse this data.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "38d58345a7c946668f4405f5ec52e726",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from datasets import load_dataset\n",
"\n",
"# Prepare the dataset (`dataset` is the dataset id string defined with the basic parameters)\n",
"data = load_dataset(dataset)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "72ac2645-0d22-4a78-bf03-41a31aeb4acc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'category': 'information_extraction',\n",
" 'input': 'ステイルメイトとは、チェスにおいて、手番が回ってきたプレーヤーがチェックされておらず、合法的な手がない状態のことである。ステイルメイトの結果、引き分けとなる。終盤では、ステイルメイトは劣勢にあるプレイヤーが負けるのではなく、ゲームを引き分けることを可能にする戦術である[2]。より複雑なポジションでは、ステイルメイトはより稀で、通常は優勢側が不注意な場合にのみ成功する詐欺の形をとる[引用] ステイルメイトは終盤研究や他のチェスの問題においても共通のテーマである。\\n\\nステイルメイトが引き分けに統一されたのは19世紀である。それ以前は、ステイルメイトしているプレイヤーの勝利、引き分け、負けとみなされたり、反則となったり、ステイルメイトしているプレイヤーはターンを失うことになったりと、その扱いは様々であった。ステイルメイトのルールは、チェス以外のチャトランガ系ゲームごとに異なる。',\n",
" 'index': '5',\n",
" 'output': 'いいえ。\\nステイルメイトとは、引き分けた状態のことです。どちらがより多くの駒を捕獲したか、または優勢であるかは関係ない',\n",
" 'instruction': 'ステイルメイトの時に、私の方が多くの駒を持っていたら、私の勝ちですか?'}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Inspect one record of the train split\n",
"data[\"train\"][5]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "b26c717b-d208-4b47-b6e8-f497462a34bb",
"metadata": {},
"outputs": [],
"source": [
"# Prompt template\n",
"def generate_prompt(data_point):\n",
"    \"\"\"Render one dataset record as a single-line prompt.\n",
"\n",
"    The prompt consists of an instruction section, an input section (only when\n",
"    ``data_point[\"input\"]`` is truthy) and a response section, separated by\n",
"    blank lines. Every newline in the assembled text is then replaced by ``<NL>``.\n",
"    \"\"\"\n",
"    context = data_point[\"input\"]\n",
"    sections = [f\"### 指示:\\n{data_point['instruction']}\"]\n",
"    if context:\n",
"        sections.append(f\"### 入力:\\n{context}\")\n",
"    sections.append(f\"### 回答:\\n{data_point['output']}\")\n",
"\n",
"    # newline -> <NL>\n",
"    return \"\\n\\n\".join(sections).replace(\"\\n\", \"<NL>\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "f6aa74dd-9584-4101-8141-b03da8dcdbe5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"### 指示:<NL>ステイルメイトの時に、私の方が多くの駒を持っていたら、私の勝ちですか?<NL><NL>### 入力:<NL>ステイルメイトとは、チェスにおいて、手番が回ってきたプレーヤーがチェックされておらず、合法的な手がない状態のことである。ステイルメイトの結果、引き分けとなる。終盤では、ステイルメイトは劣勢にあるプレイヤーが負けるのではなく、ゲームを引き分けることを可能にする戦術である[2]。より複雑なポジションでは、ステイルメイトはより稀で、通常は優勢側が不注意な場合にのみ成功する詐欺の形をとる[引用] ステイルメイトは終盤研究や他のチェスの問題においても共通のテーマである。<NL><NL>ステイルメイトが引き分けに統一されたのは19世紀である。それ以前は、ステイルメイトしているプレイヤーの勝利、引き分け、負けとみなされたり、反則となったり、ステイルメイトしているプレイヤーはターンを失うことになったりと、その扱いは様々であった。ステイルメイトのルールは、チェス以外のチャトランガ系ゲームごとに異なる。<NL><NL>### 回答:<NL>いいえ。<NL>ステイルメイトとは、引き分けた状態のことです。どちらがより多くの駒を捕獲したか、または優勢であるかは関係ない\n"
]
}
],
"source": [
"# Check the rendered prompt for one record\n",
"print(generate_prompt(data[\"train\"][5]))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "7e90295c-1d06-4afd-918a-9d4dc68bb00d",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/13015 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/2000 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"VAL_SET_SIZE = 2000  # number of examples held out for validation\n",
"SEED = 42  # shared by the split and the shuffles so reruns see the same row order\n",
"\n",
"\n",
"def _tokenize_example(example):\n",
"    \"\"\"Render a dataset record as a prompt and tokenize it.\"\"\"\n",
"    return tokenize(generate_prompt(example), tokenizer)\n",
"\n",
"\n",
"# Split into training and validation sets\n",
"train_val = data[\"train\"].train_test_split(\n",
"    test_size=VAL_SET_SIZE, shuffle=True, seed=SEED\n",
")\n",
"# Seed the shuffles as well: an unseeded shuffle() yields a different order on every run\n",
"train_data = train_val[\"train\"].shuffle(seed=SEED).map(_tokenize_example)\n",
"val_data = train_val[\"test\"].shuffle(seed=SEED).map(_tokenize_example)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "6a1532e6-bf38-4828-9dfd-989940cf90f9",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "84f50c4899e24b6db87d61d583da451a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading (…)lve/main/config.json: 0%| | 0.00/534 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "dc067a45729946869eaea2de34d331df",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading model.safetensors: 0%| | 0.00/7.37G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from transformers import AutoModelForCausalLM\n",
"\n",
"# Prepare the base model\n",
"# load_in_8bit=True requests 8-bit weights (the training log shows bitsandbytes MatMul8bitLt in use).\n",
"# device_map=\"auto\" leaves device placement to the library.\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" model_name,\n",
" load_in_8bit=True,\n",
" device_map=\"auto\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "802f3c82-f8ce-4db6-90b1-9e6f70f31bc8",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/peft/utils/other.py:102: FutureWarning: prepare_model_for_int8_training is deprecated and will be removed in a future version. Use prepare_model_for_kbit_training instead.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"trainable params: 3,244,032 || all params: 3,610,489,344 || trainable%: 0.08985020286491116\n"
]
}
],
"source": [
"from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType\n",
"\n",
"# LoRA parameters\n",
"lora_config = LoraConfig(\n",
"    r=8,\n",
"    lora_alpha=16,\n",
"    target_modules=[\"query_key_value\"],\n",
"    lora_dropout=0.05,\n",
"    bias=\"none\",\n",
"    task_type=TaskType.CAUSAL_LM,\n",
")\n",
"\n",
"# Pre-process the quantized model for training.\n",
"# prepare_model_for_int8_training is deprecated (it emits a FutureWarning pointing to\n",
"# prepare_model_for_kbit_training), so the replacement is called directly.\n",
"model = prepare_model_for_kbit_training(model)\n",
"\n",
"# Wrap the model with the LoRA adapter\n",
"model = get_peft_model(model, lora_config)\n",
"\n",
"# Report how many parameters are trainable\n",
"model.print_trainable_parameters()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "fc748ce8-c517-4dac-b21a-90758ab76302",
"metadata": {},
"outputs": [],
"source": [
"import transformers\n",
"# Evaluate and checkpoint on the same cadence; log more often\n",
"eval_steps = save_steps = 200\n",
"logging_steps = 20\n",
"\n",
"# Prepare the trainer\n",
"trainer = transformers.Trainer(\n",
"    model=model,\n",
"    args=transformers.TrainingArguments(\n",
"        output_dir=output_dir,\n",
"        # optimisation\n",
"        num_train_epochs=3,\n",
"        learning_rate=3e-4,\n",
"        auto_find_batch_size=True,\n",
"        # step-based evaluation and checkpointing\n",
"        evaluation_strategy=\"steps\",\n",
"        eval_steps=eval_steps,\n",
"        save_strategy=\"steps\",\n",
"        save_steps=save_steps,\n",
"        save_total_limit=3,\n",
"        # logging / reporting\n",
"        logging_steps=logging_steps,\n",
"        report_to=\"none\",\n",
"        push_to_hub=False,\n",
"    ),\n",
"    train_dataset=train_data,\n",
"    eval_dataset=val_data,\n",
"    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "89b352b1-0d62-43c3-9f29-dec1ff6ceadf",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
" warnings.warn(\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py:230: UserWarning: where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead. (Triggered internally at ../aten/src/ATen/native/TensorCompare.cpp:493.)\n",
" attn_scores = torch.where(causal_mask, attn_scores, mask_value)\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='4881' max='4881' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [4881/4881 1:49:09, Epoch 3/3]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>200</td>\n",
" <td>1.975400</td>\n",
" <td>1.996691</td>\n",
" </tr>\n",
" <tr>\n",
" <td>400</td>\n",
" <td>1.986400</td>\n",
" <td>1.968600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>600</td>\n",
" <td>2.031600</td>\n",
" <td>1.949893</td>\n",
" </tr>\n",
" <tr>\n",
" <td>800</td>\n",
" <td>1.946100</td>\n",
" <td>1.939998</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1000</td>\n",
" <td>1.860500</td>\n",
" <td>1.934100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1200</td>\n",
" <td>1.880600</td>\n",
" <td>1.929003</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1400</td>\n",
" <td>1.885300</td>\n",
" <td>1.920053</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1600</td>\n",
" <td>1.908900</td>\n",
" <td>1.919520</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1800</td>\n",
" <td>1.916600</td>\n",
" <td>1.918674</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2000</td>\n",
" <td>1.863200</td>\n",
" <td>1.918816</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2200</td>\n",
" <td>1.851900</td>\n",
" <td>1.913040</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2400</td>\n",
" <td>1.847900</td>\n",
" <td>1.916626</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2600</td>\n",
" <td>1.834800</td>\n",
" <td>1.908343</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2800</td>\n",
" <td>1.841200</td>\n",
" <td>1.912908</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3000</td>\n",
" <td>1.927900</td>\n",
" <td>1.910244</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3200</td>\n",
" <td>1.818600</td>\n",
" <td>1.906890</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3400</td>\n",
" <td>1.785800</td>\n",
" <td>1.911840</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3600</td>\n",
" <td>1.844500</td>\n",
" <td>1.912773</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3800</td>\n",
" <td>1.847500</td>\n",
" <td>1.911637</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4000</td>\n",
" <td>1.817400</td>\n",
" <td>1.908101</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4200</td>\n",
" <td>1.768700</td>\n",
" <td>1.911005</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4400</td>\n",
" <td>1.787500</td>\n",
" <td>1.906800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4600</td>\n",
" <td>1.793100</td>\n",
" <td>1.908361</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4800</td>\n",
" <td>1.755400</td>\n",
" <td>1.910091</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n"
]
}
],
"source": [
"# Run training with the model's use_cache flag switched off\n",
"model.config.use_cache = False\n",
"try:\n",
"    trainer.train()\n",
"finally:\n",
"    # Switch the flag back on even when training fails or is interrupted, so the\n",
"    # model object left in the kernel is not stuck with use_cache disabled.\n",
"    model.config.use_cache = True\n",
"\n",
"# Save the LoRA model\n",
"trainer.model.save_pretrained(peft_name)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "91274393-0b47-4785-b2de-7b1b939ce5bc",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment