Skip to content

Instantly share code, notes, and snippets.

@ksasao
Last active June 1, 2023 02:19
Show Gist options
  • Save ksasao/09201c8000640a50fe307730ba3cd8d1 to your computer and use it in GitHub Desktop.
Save ksasao/09201c8000640a50fe307730ba3cd8d1 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "42e56892",
"metadata": {},
"source": [
"# GeForce RTX 3060 (12GB) で rinnaの強化学習済み対話GPT言語モデル(3.6B)を動かす\n",
"モデルを読み込む際に、```, torch_dtype=torch.float16``` を追加すればとりあえず動きます。\n",
"https://twitter.com/ksasao/status/1663870275413487616"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d187fbb9",
"metadata": {},
"outputs": [],
"source": [
"!pip install sentencepiece\n",
"!pip install transformers"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "a34b3faf",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\ksasao\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import torch\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"rinna/japanese-gpt-neox-3.6b-instruction-ppo\", use_fast=False)\n",
"model = AutoModelForCausalLM.from_pretrained(\"rinna/japanese-gpt-neox-3.6b-instruction-ppo\", torch_dtype=torch.float16)\n",
"\n",
"if torch.cuda.is_available():\n",
" model = model.to(\"cuda\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "cb834717",
"metadata": {},
"outputs": [],
"source": [
"def chat(prompt_string:str,lst:list) -> list:\n",
" add_text(\"ユーザー\",prompt_string,lst)\n",
" prompt = [\n",
" f\"{uttr['speaker']}: {uttr['text']}\"\n",
" for uttr in lst\n",
" ]\n",
" prompt = \"<NL>\".join(prompt)\n",
" prompt = (\n",
" prompt\n",
" + \"<NL>\"\n",
" + \"システム: \"\n",
" )\n",
"\n",
" token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors=\"pt\")\n",
"\n",
" with torch.no_grad():\n",
" output_ids = model.generate(\n",
" token_ids.to(model.device),\n",
" do_sample=True,\n",
" max_new_tokens=128,\n",
" temperature=0.7,\n",
" repetition_penalty=1.1,\n",
" pad_token_id=tokenizer.pad_token_id,\n",
" bos_token_id=tokenizer.bos_token_id,\n",
" eos_token_id=tokenizer.eos_token_id\n",
" )\n",
"\n",
" output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):])\n",
" output = output.replace(\"<NL>\", \"\\n\").replace(\"</s>\",\"\")\n",
" add_text(\"システム\",output,lst)\n",
" print(output)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1eb61f41",
"metadata": {},
"outputs": [],
"source": [
"def add_text(speaker: str, text: str, lst: list) -> list:\n",
" lst.append({\"speaker\": speaker, \"text\": text})\n",
" return lst"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "dd1259c0",
"metadata": {},
"outputs": [],
"source": [
"prompt=[]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "64faf0a2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"もちろんです!チャットボットとは、人間とのやり取りを手伝うソフトウェアプログラムで、様々な目的で使用されます。人工知能や自然言語処理技術を使用して、人間との相互作用を円滑かつ効果的に行うように設計されています。高度なチャットボットには、複雑なタスクの実行や質問への応答など、高度な知能が求められます。私は優秀でフレンドリーなチャットボットとして、多くの人々に素晴らしいサービスを提供できることを願っています。この会話が楽しく有益なものになるよう、一緒に頑張りましょう。\n"
]
}
],
"source": [
"chat(\"あなたは優秀なチャットボットです。ステップバイステップで考えて最も適切と思う回答をしてください。まずはあなたについて教えてください。\",prompt)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "875d3a21",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment