Last active
September 1, 2024 22:02
-
-
Save kohya-ss/e23fa9a321dba07fabd1ef61eab6863c to your computer and use it in GitHub Desktop.
gradioでLLMを利用する簡易クライアント
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Apache License 2.0 | |
# 使用法は gist のコメントを見てください | |
import argparse | |
from typing import List, Optional, Union, Iterator | |
import llama_cpp | |
from llama_cpp.llama_chat_format import _convert_completion_to_chat, register_chat_completion_handler | |
import llama_cpp.llama_types as llama_types | |
from llama_cpp.llama import LogitsProcessorList, LlamaGrammar | |
from llama_cpp import Llama, llama_chat_format | |
import gradio as gr | |
debug_flag = False | |
@register_chat_completion_handler("command-r") | |
def command_r_chat_handler( | |
llama: Llama, | |
messages: List[llama_types.ChatCompletionRequestMessage], | |
functions: Optional[List[llama_types.ChatCompletionFunction]] = None, | |
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, | |
tools: Optional[List[llama_types.ChatCompletionTool]] = None, | |
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, | |
temperature: float = 0.2, | |
top_p: float = 0.95, | |
top_k: int = 40, | |
min_p: float = 0.05, | |
typical_p: float = 1.0, | |
stream: bool = False, | |
stop: Optional[Union[str, List[str]]] = [], | |
response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None, | |
max_tokens: Optional[int] = None, | |
presence_penalty: float = 0.0, | |
frequency_penalty: float = 0.0, | |
repeat_penalty: float = 1.1, | |
tfs_z: float = 1.0, | |
mirostat_mode: int = 0, | |
mirostat_tau: float = 5.0, | |
mirostat_eta: float = 0.1, | |
model: Optional[str] = None, | |
logits_processor: Optional[LogitsProcessorList] = None, | |
grammar: Optional[LlamaGrammar] = None, | |
**kwargs, # type: ignore | |
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]: | |
# bos_token = "<BOS_TOKEN>" | |
start_turn_token = "<|START_OF_TURN_TOKEN|>" | |
end_turn_token = "<|END_OF_TURN_TOKEN|>" | |
user_token = "<|USER_TOKEN|>" | |
chatbot_token = "<|CHATBOT_TOKEN|>" | |
system_token = "<|SYSTEM_TOKEN|>" | |
prompt = "" # bos_token # suppress warning | |
if len(messages) > 0 and messages[0]["role"] == "system": | |
prompt += start_turn_token + system_token + messages[0]["content"] + end_turn_token | |
messages = messages[1:] | |
for message in messages: | |
if message["role"] == "user": | |
prompt += start_turn_token + user_token + message["content"] + end_turn_token | |
elif message["role"] == "assistant": | |
prompt += start_turn_token + chatbot_token + message["content"] + end_turn_token | |
prompt += start_turn_token + chatbot_token | |
if debug_flag: | |
print(f"Prompt: {prompt}") | |
stop_tokens = [end_turn_token] # , bos_token] | |
return _convert_completion_to_chat( | |
llama.create_completion( | |
prompt=prompt, | |
temperature=temperature, | |
top_p=top_p, | |
top_k=top_k, | |
min_p=min_p, | |
typical_p=typical_p, | |
stream=stream, | |
stop=stop_tokens, | |
max_tokens=max_tokens, | |
presence_penalty=presence_penalty, | |
frequency_penalty=frequency_penalty, | |
repeat_penalty=repeat_penalty, | |
tfs_z=tfs_z, | |
mirostat_mode=mirostat_mode, | |
mirostat_tau=mirostat_tau, | |
mirostat_eta=mirostat_eta, | |
model=model, | |
logits_processor=logits_processor, | |
grammar=grammar, | |
), | |
stream=stream, | |
) | |
def generate_completion(llama, prompt, max_tokens, temperature, top_p, repeat_penalty, top_k): | |
global stop_generating | |
stop_generating = False | |
output = prompt | |
if debug_flag: | |
print( | |
f"temperature: {temperature}, top_p: {top_p}, top_k: {top_k}, repeat_penalty: {repeat_penalty}, max_tokens: {max_tokens}" | |
) | |
for chunk in llama( | |
prompt, | |
max_tokens=max_tokens, | |
temperature=temperature, | |
top_p=top_p, | |
repeat_penalty=repeat_penalty, | |
top_k=top_k, | |
stream=True, | |
): | |
if debug_flag: | |
print(chunk) | |
if stop_generating: | |
break | |
if "choices" in chunk and len(chunk["choices"]) > 0: | |
if "text" in chunk["choices"][0]: | |
text = chunk["choices"][0]["text"] | |
# check EOS_TOKEN | |
if text.endswith("<EOS_TOKEN>"): # llama.tokenizer.EOS_TOKEN): | |
output += text[: -len("<EOS_TOKEN>")] | |
yield output[len(prompt) :] | |
break | |
output += text | |
yield output[len(prompt) :] # remove prompt | |
def launch_completion(llama, listen=False): | |
# css = """ | |
# .prompt textarea {font-size:1.0em !important} | |
# """ | |
# with gr.Blocks(css=css) as demo: | |
with gr.Blocks() as demo: | |
with gr.Row(): | |
# change font size | |
io_textbox = gr.Textbox( | |
label="Input/Output: Text may not be scrolled automatically. Shift+Enter to newline. テキストは自動スクロールしないことがあります。Shift+Enterで改行。", | |
placeholder="Enter your prompt here...", | |
interactive=True, | |
elem_classes=["prompt"], | |
autoscroll=True, | |
) | |
with gr.Row(): | |
generate_button = gr.Button("Generate") | |
stop_button = gr.Button("Stop", visible=False) | |
with gr.Row(): | |
max_tokens = gr.Slider(minimum=1, maximum=2048, value=128, step=1, label="Max Tokens") | |
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Temperature") | |
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top P") | |
repeat_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repeat Penalty") | |
top_k = gr.Slider(minimum=1, maximum=200, value=40, step=1, label="Top K") | |
def generate_and_display(prompt, max_tokens, temperature, top_p, repeat_penalty, top_k): | |
output_generator = generate_completion(llama, prompt, max_tokens, temperature, top_p, repeat_penalty, top_k) | |
for output in output_generator: | |
yield gr.update(value=prompt + output, autoscroll=True), gr.update(visible=False), gr.update(visible=True) | |
yield gr.update(value=prompt + output, autoscroll=True), gr.update(visible=True), gr.update(visible=False) | |
def stop_generation(): | |
globals().update(stop_generating=True) | |
return gr.update(visible=True), gr.update(visible=False) | |
generate_button.click( | |
generate_and_display, | |
inputs=[io_textbox, max_tokens, temperature, top_p, repeat_penalty, top_k], | |
outputs=[io_textbox, generate_button, stop_button], | |
show_progress=True, | |
) | |
stop_button.click( | |
stop_generation, | |
outputs=[generate_button, stop_button], | |
) | |
# add event to textbox to add new line on enter | |
io_textbox.submit( | |
lambda x: x + "\n", | |
inputs=[io_textbox], | |
outputs=[io_textbox], | |
) | |
demo.launch(server_name="0.0.0.0" if listen else None) | |
def launch_chat(llama, handler_name, listen=False): | |
def chat(message, history, system_prompt, max_tokens, temperature, top_p, repeat_penalty, top_k): | |
user_input = message | |
messages = [] | |
if system_prompt: | |
messages.append({"role": "system", "content": system_prompt}) | |
for message in history: | |
messages.append({"role": "user", "content": message[0]}) | |
messages.append({"role": "assistant", "content": message[1]}) | |
messages.append({"role": "user", "content": user_input}) | |
if debug_flag: | |
print(f"Messages: {messages}") | |
print( | |
f"System prompt: {system_prompt}, temperature: {temperature}, top_p: {top_p}, top_k: {top_k}, repeat_penalty: {repeat_penalty}, max_tokens: {max_tokens}" | |
) | |
handler = llama_chat_format.get_chat_completion_handler(handler_name) | |
chat_completion_chunks = handler( | |
llama=llama, | |
messages=messages, | |
max_tokens=max_tokens, | |
temperature=temperature, | |
top_p=top_p, | |
repeat_penalty=repeat_penalty, | |
top_k=int(top_k), | |
stream=True, | |
) | |
response = "" | |
for chunk in chat_completion_chunks: | |
if debug_flag: | |
print(chunk) | |
if "choices" in chunk and len(chunk["choices"]) > 0: | |
if "delta" in chunk["choices"][0]: | |
if "content" in chunk["choices"][0]["delta"]: | |
response += chunk["choices"][0]["delta"]["content"] | |
yield response | |
system_prompt = gr.Textbox(label="System Prompt", placeholder="Enter system prompt here...") | |
max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max Tokens") | |
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Temperature") | |
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top P") | |
repeat_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repeat Penalty") | |
top_k = gr.Slider(minimum=1, maximum=200, value=40, step=1, label="Top K") | |
additional_inputs = [system_prompt, max_tokens, temperature, top_p, repeat_penalty, top_k] | |
chatbot = gr.ChatInterface(chat, additional_inputs=additional_inputs) | |
chatbot.launch(server_name="0.0.0.0" if listen else None) | |
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument("-m", "--model", type=str, default=None, help="Model file path") | |
parser.add_argument("-ngl", "--n_gpu_layers", type=int, default=0, help="Number of GPU layers") | |
parser.add_argument("-c", "--n_ctx", type=int, default=2048, help="Context length") | |
parser.add_argument( | |
"-ch", | |
"--chat_handler", | |
type=str, | |
default="command-r", | |
help="Chat handler, e.g. command-r, mistral-instruct, alpaca, llama-3 etc. default: command-r", | |
) | |
parser.add_argument("--chat", action="store_true", help="Chat mode") | |
parser.add_argument("--listen", action="store_true", help="Listen mode") | |
parser.add_argument( | |
"-ts", "--tensor_split", type=str, default=None, help="Tensor split, float values separated by comma for each gpu" | |
) | |
parser.add_argument("--disable_mmap", action="store_true", help="Disable mmap") | |
parser.add_argument("--kv_cache_q8", action="store_true", help="Use quantized kv cache (Q8)") | |
parser.add_argument("--flash_attn", action="store_true", help="Use flash attention") | |
parser.add_argument("--debug", action="store_true", help="Debug mode") | |
args = parser.parse_args() | |
# tokenizer initialization doesn't seem to be needed | |
# print("Initializing tokenizer") | |
# tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) | |
print(f"Initializing Llama. Model ID: {args.model}, N_GPU_LAYERS: {args.n_gpu_layers}, N_CTX: {args.n_ctx}") | |
# llama_tokenizer = LlamaHFTokenizer(tokenizer) | |
tensor_split = None if args.tensor_split is None else [float(x) for x in args.tensor_split.split(",")] | |
llama = Llama( | |
model_path=args.model, | |
n_gpu_layers=args.n_gpu_layers, | |
tensor_split=tensor_split, | |
n_ctx=args.n_ctx, | |
use_mmap=not args.disable_mmap, | |
type_k=llama_cpp.GGML_TYPE_Q8_0 if args.kv_cache_q8 else None, | |
type_v=llama_cpp.GGML_TYPE_Q8_0 if args.kv_cache_q8 else None, | |
flash_attn=args.flash_attn, | |
# tokenizer=llama_tokenizer, | |
# n_threads=n_threads, | |
) | |
debug_flag = args.debug | |
if args.chat: | |
print(f"Launching chat with handler: {args.chat_handler}") | |
launch_chat(llama, args.chat_handler, args.listen) | |
else: | |
print("Launching completion") | |
launch_completion(llama, args.listen) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
GGUF 形式のLLMを動かす簡易クライアントです。
最近の更新とか
kv cache の quantize と flash attention に対応しました。最新の llama-cpp-python が必要ですのでインストール済みなら
pip install -U llama-cpp-python --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu124
(末尾は CUDA バージョンによって変更)で更新してください。mmap を無効にできるようにしました。mmap 有効(デフォルト)だとVRAM だけでなくメイン RAM にもモデルをロードすることで再読み込みが速くなるらしいですが、その分メイン RAM を消費します。mmap 無効でメイン RAM の消費量が減ります。
インストール
extra-index
の末尾はインストールされている CUDA バージョンに合わせてください。cu121, cu122, cu123, cu124 が公式リポジトリには書かれています。CUDA のマイナーバージョン(末尾一桁)が違ってもだいたい動くことが多いです。pip install -U llama-cpp-python --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu122
でバージョンアップできます。スクリプトの使い方
たとえば以下のように指定します。
その後、ブラウザで http://localhost:7860/ を開いてください。
-m
にダウンロードしたモデルを指定します。複数ファイルある場合は000001
のファイル名を指定します。-ngl
でいくつのレイヤーを GPU で処理するか指定します。多くすると VRAM を使いますが、それだけ速くなります。また少なくするとメイン RAM の使用量が増え、CPU で処理する量が増えるので遅くなります。-c
でコンテキスト長を指定します。モデルが認識できる入力のトークン数です。長くするとそれだけメモリを使います。-ngl
と-c
は VRAM とメイン RAM の空きを見ながら決めてください。--chat
でチャットモードで起動します。デフォルトは文章の続きを生成する補完モードです。--chat_handler
でモデルのチャット形式を指定します。command-r
やmistral-instruct
、llama-3
などが指定できます。省略時はcommand-r
です。llama-cpp-python の chat handler が指定できます。--listen
オプションをつけると LAN 内の他の PC からアクセス可能になります。--disable_mmap
を指定すると mmap を使いません。メイン RAM に余裕がない場合は指定してください。--kv_cache_q8
で KV cache を quantize することでメモリ使用量を減らします。やや品質が低下するようですが、ほぼ体感できないためメモリに余裕がない場合はお勧めです。--flash_attn
で Flash Attention を有効にします。生成が速くなるらしいです。--debug
でデバッグ出力を表示します。--help
で表示される使用法も参考にしてください。-ngl
についてモデルロード時に
llm_load_tensors: offloaded 15/33 layers to GPU
のように表示されます。/
のあとの数値がモデルの全レイヤー数ですので、その数値以下の範囲で指定できます。モデルとダウンロード
各リポジトリには、量子化方法の違いでいくつかのバージョンが用意されています。Q3、Q4 などの数値は量子化ビット数で、K_M や K_S、IQ?_XS などは量子化手法の違いのようです。PPL Value が低いほど性能劣化が抑えられているようです。
000001-of-000002、000001-of-000002 などの連番のファイルをすべて落としてください。
合計ファイルサイズが VRAM サイズより小さいくらいのモデルなら、
-ngl
を多くするとほぼすべて VRAM に乗せられるのでかなり高速に動きます。モデル例
学習やマージ、GGUF を提供されている各氏に感謝します。
command-r
command-r
command-r
mistral-instruct
mistral-instruct
llama-3
llama-3
※Command R+ や LLama 3 は momonga 氏が量子化モデルを多数公開されています。iMatrix に日本語データが利用されているため精度が高いかもしれません。
設定例
※ ざっくり試した範囲のため、この値での動作を保証するものではありません。
--disable_mmap
でメイン RAM の使用量が VRAM の分くらい減らせる感じです。--flash_attn
と--kv_cache_q8
でコンテキスト長を倍くらいにできると思います。(VRAM に乗ってるレイヤー数が少ないと GPU の性能差はあまり出ません。Command-R+ Q4_K_M を、VRAM 24GB + DDR4 メモリ 64GB で動かすのと、VRAM 12GB + DDR5 メモリ 96GB で動かすのは、恐らく同じくらいの速度です。)
-ngl 22 -c 2048
または-ngl 22 -c 4096 --disable_mmap --flash_attn --kv_cache_q8
-ngl 8 -c 4096 --disable_mmap --kv_cache_q8 --flash_attn
-ngl 4
で動くかも。-ngl 24 -c 2048
または-ngl 24 -c 4096 --disable_mmap --flash_attn --kv_cache_q8
-ngl 8 -c 4096 --disable_mmap --kv_cache_q8 --flash_attn
-ngl 4
で動くかも。-ngl 20 -c 2048
-ngl 38 -c 2048
-ngl 22 -c 16384
-ngl 16 -c 2048
--chat_handler mistral-instruct
を追加-ngl 28 -c 2048
-ngl 33 -c 16384
-ngl 33 -c 2048
-ngl 16 -c 16384
--chat_handler mistral-instruct
を追加-ngl 33 -c 16384
-ngl 33 -c 2048
-ngl 33 -c 16384
--chat_handler llama-3
を追加-ngl 44 -c 2048
--chat_handler llama-3
を追加-ngl 33 -c 2048
-ngl 28 -c 2048