@bathtimefish
Last active July 17, 2023 04:38
Fine-tuning Rinna 3.6b, a Japanese LLM, on WSL Ubuntu 22.04 LTS with CUDA 12.1 on Windows 11
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "289cd868-a0a9-49ad-9a4e-f6fdadefce5e",
"metadata": {},
"outputs": [],
"source": [
"# 基本パラメータ\n",
"model_name = \"rinna/japanese-gpt-neox-3.6b\"\n",
"dataset = \"kunishou/databricks-dolly-15k-ja\"\n",
"peft_name = \"lora-rinna-3.6b\"\n",
"output_dir = \"lora-rinna-3.6b-dolly-15k-ja-results\""
]
},
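{
"cell_type": "markdown",
"id": "added-lora-config-note",
"metadata": {},
"source": [
"The `lora-rinna-3.6b` adapter loaded below was trained in advance. The next cell is a minimal sketch of a `LoraConfig` consistent with the module tree printed by the model-loading cell (LoRA rank 8 on `query_key_value`, dropout `p=0.05`); `lora_alpha=16` is an assumed common default, not a confirmed value."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-lora-config-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Sketch of a LoraConfig consistent with the adapter structure printed below.\n",
"# r=8, lora_dropout=0.05 and target_modules=[\"query_key_value\"] are read off\n",
"# the printed module tree; lora_alpha=16 is an assumption.\n",
"from peft import LoraConfig, TaskType\n",
"\n",
"lora_config = LoraConfig(\n",
"    r=8,\n",
"    lora_alpha=16,\n",
"    target_modules=[\"query_key_value\"],\n",
"    lora_dropout=0.05,\n",
"    bias=\"none\",\n",
"    task_type=TaskType.CAUSAL_LM,\n",
")"
]
},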
{
"cell_type": "code",
"execution_count": 2,
"id": "ccd8c3c8-db46-4bf6-bb3c-4b473e157008",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PeftModelForCausalLM(\n",
" (base_model): LoraModel(\n",
" (model): GPTNeoXForCausalLM(\n",
" (gpt_neox): GPTNeoXModel(\n",
" (embed_in): Embedding(32000, 2816)\n",
" (layers): ModuleList(\n",
" (0-35): 36 x GPTNeoXLayer(\n",
" (input_layernorm): LayerNorm((2816,), eps=1e-05, elementwise_affine=True)\n",
" (post_attention_layernorm): LayerNorm((2816,), eps=1e-05, elementwise_affine=True)\n",
" (attention): GPTNeoXAttention(\n",
" (rotary_emb): RotaryEmbedding()\n",
" (query_key_value): Linear8bitLt(\n",
" in_features=2816, out_features=8448, bias=True\n",
" (lora_dropout): ModuleDict(\n",
" (default): Dropout(p=0.05, inplace=False)\n",
" )\n",
" (lora_A): ModuleDict(\n",
" (default): Linear(in_features=2816, out_features=8, bias=False)\n",
" )\n",
" (lora_B): ModuleDict(\n",
" (default): Linear(in_features=8, out_features=8448, bias=False)\n",
" )\n",
" (lora_embedding_A): ParameterDict()\n",
" (lora_embedding_B): ParameterDict()\n",
" )\n",
" (dense): Linear8bitLt(in_features=2816, out_features=2816, bias=True)\n",
" )\n",
" (mlp): GPTNeoXMLP(\n",
" (dense_h_to_4h): Linear8bitLt(in_features=2816, out_features=11264, bias=True)\n",
" (dense_4h_to_h): Linear8bitLt(in_features=11264, out_features=2816, bias=True)\n",
" (act): GELUActivation()\n",
" )\n",
" )\n",
" )\n",
" (final_layer_norm): LayerNorm((2816,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (embed_out): Linear(in_features=2816, out_features=32000, bias=False)\n",
" )\n",
" )\n",
")"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"from peft import PeftModel, PeftConfig\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"# モデルの準備\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" model_name,\n",
" load_in_8bit=True,\n",
" device_map=\"auto\",\n",
")\n",
"\n",
"# トークナイザーの準備\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)\n",
"\n",
"# LoRAモデルの準備\n",
"model = PeftModel.from_pretrained(\n",
" model, \n",
" peft_name, \n",
" device_map=\"auto\"\n",
")\n",
"\n",
"# 評価モード\n",
"model.eval()"
]
},
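{
"cell_type": "markdown",
"id": "added-sanity-check-note",
"metadata": {},
"source": [
"As a quick sanity check, PEFT's `print_trainable_parameters()` and Transformers' `get_memory_footprint()` can show how small the LoRA adapter is relative to the 8-bit base model. A minimal sketch, assuming both helpers exist in the installed versions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-sanity-check",
"metadata": {},
"outputs": [],
"source": [
"# Sanity check (sketch): LoRA leaves only a small fraction of weights trainable,\n",
"# and load_in_8bit=True keeps the 3.6B base model's memory footprint modest.\n",
"model.print_trainable_parameters()\n",
"print(f\"memory footprint: {model.get_memory_footprint() / 1024**3:.2f} GiB\")"
]
},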
{
"cell_type": "code",
"execution_count": 3,
"id": "142a8f9e-f815-4c11-8e81-6e8b8f5e573b",
"metadata": {},
"outputs": [],
"source": [
"# プロンプトテンプレートの準備\n",
"def generate_prompt(data_point):\n",
" if data_point[\"input\"]:\n",
" result = f\"\"\"### 指示:\n",
"{data_point[\"instruction\"]}\n",
"\n",
"### 入力:\n",
"{data_point[\"input\"]}\n",
"\n",
"### 回答:\n",
"\"\"\"\n",
" else:\n",
" result = f\"\"\"### 指示:\n",
"{data_point[\"instruction\"]}\n",
"\n",
"### 回答:\n",
"\"\"\"\n",
"\n",
" # 改行→<NL>\n",
" result = result.replace('\\n', '<NL>')\n",
" return result"
]
},
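{
"cell_type": "markdown",
"id": "added-template-demo-note",
"metadata": {},
"source": [
"A quick illustration of what `generate_prompt` returns: a single line with `<NL>` markers in place of newlines, matching the format the adapter was fine-tuned on (compare the raw output printed by `generate()` below)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-template-demo",
"metadata": {},
"outputs": [],
"source": [
"# Illustration: the no-input branch of the template, with newlines already\n",
"# replaced by <NL>.\n",
"print(generate_prompt({\"instruction\": \"日本の首都は?\", \"input\": None}))\n",
"# -> ### 指示:<NL>日本の首都は?<NL><NL>### 回答:<NL>"
]
},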
{
"cell_type": "code",
"execution_count": 4,
"id": "76d20987-b203-4696-aa8b-8799bbb225ba",
"metadata": {},
"outputs": [],
"source": [
"# テキスト生成関数の定義\n",
"def generate(instruction,input=None,maxTokens=256):\n",
" # 推論\n",
" prompt = generate_prompt({'instruction':instruction,'input':input})\n",
" input_ids = tokenizer(prompt, \n",
" return_tensors=\"pt\", \n",
" truncation=True, \n",
" add_special_tokens=False).input_ids.cuda()\n",
" outputs = model.generate(\n",
" input_ids=input_ids, \n",
" max_new_tokens=maxTokens, \n",
" do_sample=True,\n",
" temperature=0.7, \n",
" top_p=0.75, \n",
" top_k=40, \n",
" no_repeat_ngram_size=2,\n",
" )\n",
" outputs = outputs[0].tolist()\n",
" print(tokenizer.decode(outputs))\n",
"\n",
" # EOSトークンにヒットしたらデコード完了\n",
" if tokenizer.eos_token_id in outputs:\n",
" eos_index = outputs.index(tokenizer.eos_token_id)\n",
" decoded = tokenizer.decode(outputs[:eos_index])\n",
"\n",
" # レスポンス内容のみ抽出\n",
" sentinel = \"### 回答:\"\n",
" sentinelLoc = decoded.find(sentinel)\n",
" if sentinelLoc >= 0:\n",
" result = decoded[sentinelLoc+len(sentinel):]\n",
" print(result.replace(\"<NL>\", \"\\n\")) # <NL>→改行\n",
" else:\n",
" print('Warning: Expected prompt template to be emitted. Ignoring output.')\n",
" else:\n",
" print('Warning: no <eos> detected ignoring output')"
]
},
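{
"cell_type": "markdown",
"id": "added-warning-fix-note",
"metadata": {},
"source": [
"The first inference cell below warns on stderr that no attention mask or `pad_token_id` was passed to `generate()`. Generation still works, but a variant like the following sketch silences the warnings by forwarding both explicitly (`truncation=True` is dropped because, per the other warning, it is a no-op without a `max_length`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-warning-fix",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: same sampling settings as generate() above, but with attention_mask\n",
"# and pad_token_id passed explicitly to avoid the Transformers warnings.\n",
"prompt = generate_prompt({\"instruction\": \"日本の首都は?\", \"input\": None})\n",
"inputs = tokenizer(prompt, return_tensors=\"pt\", add_special_tokens=False)\n",
"outputs = model.generate(\n",
"    input_ids=inputs.input_ids.cuda(),\n",
"    attention_mask=inputs.attention_mask.cuda(),\n",
"    pad_token_id=tokenizer.eos_token_id,\n",
"    max_new_tokens=256,\n",
"    do_sample=True,\n",
"    temperature=0.7,\n",
"    top_p=0.75,\n",
"    top_k=40,\n",
"    no_repeat_ngram_size=2,\n",
")\n",
"print(tokenizer.decode(outputs[0]))"
]
},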
{
"cell_type": "code",
"execution_count": 5,
"id": "3c4671dc-8d65-4464-a4b1-8c98da972a72",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n",
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:3 for open-end generation.\n",
"/home/btf/.pyenv/versions/3.10.12/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py:230: UserWarning: where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead. (Triggered internally at ../aten/src/ATen/native/TensorCompare.cpp:493.)\n",
" attn_scores = torch.where(causal_mask, attn_scores, mask_value)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"### 指示:<NL>自然言語処理とは?<NL><NL>### 回答:<NL>言語モデルとは、自然な言語を生成するために使用される数学的モデルです。自然語処理は、人間が話す自然で自然に感じられる言語で、コンピューターが理解できる言語に変換するプロセスです</s>\n",
"\n",
"言語モデルとは、自然な言語を生成するために使用される数学的モデルです。自然語処理は、人間が話す自然で自然に感じられる言語で、コンピューターが理解できる言語に変換するプロセスです\n"
]
}
],
"source": [
"# 推論を実行する\n",
"generate(\"自然言語処理とは?\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a67a02a0-dea4-4cdd-9e48-108ef0065aa3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:3 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"### 指示:<NL>日本の首都は?<NL><NL>### 回答:<NL>東京</s>\n",
"\n",
"東京\n"
]
}
],
"source": [
"generate(\"日本の首都は?\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a1c4caa9-bba9-4fd7-8773-1a4ea8be5971",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:3 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"### 指示:<NL>ステイルメイトとはなんですか?<NL><NL>### 回答:<NL>チェスでは、ステイルズメイトは引き分けを意味します。 チェスの引き分けは、ゲームの開始時に、双方がチェックメイトになるまで駒を動かさず、相手が動かした駒が自分の駒にチェックされるまで続く。</s>\n",
"\n",
"チェスでは、ステイルズメイトは引き分けを意味します。 チェスの引き分けは、ゲームの開始時に、双方がチェックメイトになるまで駒を動かさず、相手が動かした駒が自分の駒にチェックされるまで続く。\n"
]
}
],
"source": [
"generate(\"ステイルメイトとはなんですか?\")"
]
},
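{
"cell_type": "markdown",
"id": "added-merge-note",
"metadata": {},
"source": [
"To deploy without the PEFT wrapper, the adapter can be merged into the base weights. Depending on the installed `peft` version, merging into 8-bit quantized weights may not be supported, so this sketch reloads the base model in fp16 first (assuming enough memory; the output path is illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-merge-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Deployment sketch: reload the base model in fp16, re-attach the adapter,\n",
"# then merge and save. \"lora-rinna-3.6b-merged\" is an illustrative path.\n",
"base = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)\n",
"merged = PeftModel.from_pretrained(base, peft_name).merge_and_unload()\n",
"merged.save_pretrained(\"lora-rinna-3.6b-merged\")\n",
"tokenizer.save_pretrained(\"lora-rinna-3.6b-merged\")"
]
},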
{
"cell_type": "code",
"execution_count": null,
"id": "9d714cea-7b47-4378-b895-6749545f2259",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}