Skip to content

Instantly share code, notes, and snippets.

@yuntan
Created March 3, 2024 13:42
Show Gist options
  • Save yuntan/a35a2d21625e883388f1548cbbd88da2 to your computer and use it in GitHub Desktop.
Save yuntan/a35a2d21625e883388f1548cbbd88da2 to your computer and use it in GitHub Desktop.
LLM fine-tuning
from datetime import datetime
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
Trainer,
TrainingArguments,
DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType
from datasets import Dataset
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-sft-v2")
if torch.cuda.is_available():
model = model.to("cuda")
# LoRAのパラメータ
lora_config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=["query_key_value"],
lora_dropout=0.05,
bias="none",
task_type=TaskType.CAUSAL_LM
)
# モデルの前処理
# model = prepare_model_for_int8_training(model)
# LoRAモデルの準備
model = get_peft_model(model, lora_config)
model.config.use_cache = False
# 学習可能パラメータの確認
model.print_trainable_parameters()
tokenizer = AutoTokenizer.from_pretrained(
"rinna/japanese-gpt-neox-3.6b-instruction-sft-v2",
use_fast=False,
add_eos_token=True
)
def tokenize(prompt: str, max_length=256):
result = tokenizer(
prompt,
# truncation=True,
# max_length=max_length,
padding=False,
)
return {
"input_ids": result["input_ids"],
"attention_mask": result["attention_mask"],
}
def load_dataset():
with open("train/LknA8PMvmB4.txt") as f:
lines = f.readlines()
lines = [line.replace(":", ": ") for line in lines]
text = "<NL>".join(lines)
token = tokenize(text)
return [token]
eval_steps = 200
save_steps = 200
logging_steps = 20
# トレーナーの準備
trainer = Trainer(
model=model,
args=TrainingArguments(
output_dir="models",
num_train_epochs=32,
learning_rate=3e-4,
logging_steps=logging_steps,
evaluation_strategy="steps",
save_strategy="steps",
eval_steps=eval_steps,
save_steps=save_steps,
report_to=["none"],
save_total_limit=3,
push_to_hub=False,
auto_find_batch_size=True
),
data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
train_dataset=load_dataset(),
# eval_dataset=val_data,
)
# 学習の実行
trainer.train()
# LoRAモデルの保存
name = datetime.now().strftime("%Y-%m-%d %H%M%S")
trainer.model.save_pretrained(f"models/{name}")
ユーザー:<START_OF_STREAM>
ユーザー:あ、メルさんだメルさんだ、来た来た来た
ユーザー:それでは、夜空メルさんへの質問がある方は挙手にてお願いしまーす。
システム:なんでしょうか!
ユーザー:はいはいはい!じゃあ白上いいですか!それじゃあ、はいはい行きますえっとですね、アセロラジュースがお好きとのことなんですけども、
システム:はい
ユーザー:実は赤い飲み物なら何でもいいという噂を聞いたんですが本当でしょうか?
システム:トマトジュースオッケー、イチゴジュースオッケー、ザクロオッケー
ユーザー:おーなんでもいい
システム:オッケーだね!
ユーザー:アセロラジュースである必要性はないと?
システム:でもアセロラがいちばん好き!
ユーザー:なるほどアセロラがいちばんね、でも赤いものならなんでもオッケーと。
システム:なんでもいいねー、なんでもいいかも!
ユーザー:わりとなんでもオッケー、ふんふん
ユーザー:はい!大神ミオです!
システム:はい
ユーザー:天才バンパイアとのことなんですが、IQで言うとどれぐらいなんでしょうか?
システム:えーっと、IQですか?ま、その課程を説明しますと、
ユーザー:課程?
システム:あのムスカ大佐いるじゃないですか
ユーザー:いますね
システム:ムスカ300くらいあるらしいんですけど、あいつは3分しか待てないんですよ。でも、メルは、あの友達の遅刻2時間待ったことあって、だからあいつを超えています。そして、えーま大体1031、つまり天才ってこと!!!
ユーザー:ああー、なるほどー。メモメモ
システム:メモしてメモ!
ユーザー:ちなみにあのりんご何個分のIQとなりますか
システム:あ、りんご、つまり、えそうIQを持っていないということは、えー、1億31個分です!
ユーザー:IQないのにですか、りんご
システム:はい!
ユーザー:無は無ですよ
システム:りんごよりも偉い!あ、りんごジュースも飲めます!
ユーザー:あ、りんごジュースも飲める、りんごジュースもオッケー、メモメモメモメモ
システム:メモメモメモメモ
ユーザー:なるほど、じゃあ白上いいですか?
システム:なんでしょう!
ユーザー:あの今年1年でバンパイアとしてやってきたと思うんですけども、いちばんバンパイアらしかったなと思った行動は何でしょうか?
システム:バンパイアやらしかったこと?バンパイアらしくなかったことといえば、日光を浴びたことです
ユーザー:らしくないですね、大変らしくないですね。克服したってことですかね
システム:克服、ですかね、やっぱ日光を浴びることでセロトニンが分泌されて、健康になれると
ユーザー:なるほどとえじゃつまりあのジョジョのリオも超えたということですか?
システム:そういうことだ!つまりメルがバンパイア界にて最強っていうこと
ユーザー:最強、そういうことですか
システム:これがバンパイアらしいことです
ユーザー:ああ素晴らしいですね。なるほどー、なんかじゃあいろんな様々なことをもう克服してるってことですか
システム:そういうことです
ユーザー:逆に克服できてないことは何でしょうか?
システム:それは血を飲むことです。
ユーザー:血を飲むのはちょっとね、怖い
システム:怖いでしょ?
ユーザー:グロいもんね
システム:グロいよね?飲まれたくないよね?
ユーザー:飲まれたくないですね
システム:このままでいいってことで
ユーザー:なるほどー、人畜無害ですね。
システム:人畜無害バンパイア代表なんで
ユーザー:あ、はいお時間となりましたので以上で夜空メルさんの質問は終了となります。ありがとうございました。
システム:あざしたー!
ユーザー:<END_OF_STREAM>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment