Skip to content

Instantly share code, notes, and snippets.

@kohya-ss
Last active September 1, 2024 22:02
Show Gist options
  • Save kohya-ss/e23fa9a321dba07fabd1ef61eab6863c to your computer and use it in GitHub Desktop.
Save kohya-ss/e23fa9a321dba07fabd1ef61eab6863c to your computer and use it in GitHub Desktop.
gradioでLLMを利用する簡易クライアント
# Apache License 2.0
# 使用法は gist のコメントを見てください
import argparse
from typing import List, Optional, Union, Iterator
import llama_cpp
from llama_cpp.llama_chat_format import _convert_completion_to_chat, register_chat_completion_handler
import llama_cpp.llama_types as llama_types
from llama_cpp.llama import LogitsProcessorList, LlamaGrammar
from llama_cpp import Llama, llama_chat_format
import gradio as gr
debug_flag = False
@register_chat_completion_handler("command-r")
def command_r_chat_handler(
llama: Llama,
messages: List[llama_types.ChatCompletionRequestMessage],
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
min_p: float = 0.05,
typical_p: float = 1.0,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
max_tokens: Optional[int] = None,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
repeat_penalty: float = 1.1,
tfs_z: float = 1.0,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
model: Optional[str] = None,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
**kwargs, # type: ignore
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
# bos_token = "<BOS_TOKEN>"
start_turn_token = "<|START_OF_TURN_TOKEN|>"
end_turn_token = "<|END_OF_TURN_TOKEN|>"
user_token = "<|USER_TOKEN|>"
chatbot_token = "<|CHATBOT_TOKEN|>"
system_token = "<|SYSTEM_TOKEN|>"
prompt = "" # bos_token # suppress warning
if len(messages) > 0 and messages[0]["role"] == "system":
prompt += start_turn_token + system_token + messages[0]["content"] + end_turn_token
messages = messages[1:]
for message in messages:
if message["role"] == "user":
prompt += start_turn_token + user_token + message["content"] + end_turn_token
elif message["role"] == "assistant":
prompt += start_turn_token + chatbot_token + message["content"] + end_turn_token
prompt += start_turn_token + chatbot_token
if debug_flag:
print(f"Prompt: {prompt}")
stop_tokens = [end_turn_token] # , bos_token]
return _convert_completion_to_chat(
llama.create_completion(
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
typical_p=typical_p,
stream=stream,
stop=stop_tokens,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repeat_penalty=repeat_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
logits_processor=logits_processor,
grammar=grammar,
),
stream=stream,
)
def generate_completion(llama, prompt, max_tokens, temperature, top_p, repeat_penalty, top_k):
global stop_generating
stop_generating = False
output = prompt
if debug_flag:
print(
f"temperature: {temperature}, top_p: {top_p}, top_k: {top_k}, repeat_penalty: {repeat_penalty}, max_tokens: {max_tokens}"
)
for chunk in llama(
prompt,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
repeat_penalty=repeat_penalty,
top_k=top_k,
stream=True,
):
if debug_flag:
print(chunk)
if stop_generating:
break
if "choices" in chunk and len(chunk["choices"]) > 0:
if "text" in chunk["choices"][0]:
text = chunk["choices"][0]["text"]
# check EOS_TOKEN
if text.endswith("<EOS_TOKEN>"): # llama.tokenizer.EOS_TOKEN):
output += text[: -len("<EOS_TOKEN>")]
yield output[len(prompt) :]
break
output += text
yield output[len(prompt) :] # remove prompt
def launch_completion(llama, listen=False):
# css = """
# .prompt textarea {font-size:1.0em !important}
# """
# with gr.Blocks(css=css) as demo:
with gr.Blocks() as demo:
with gr.Row():
# change font size
io_textbox = gr.Textbox(
label="Input/Output: Text may not be scrolled automatically. Shift+Enter to newline. テキストは自動スクロールしないことがあります。Shift+Enterで改行。",
placeholder="Enter your prompt here...",
interactive=True,
elem_classes=["prompt"],
autoscroll=True,
)
with gr.Row():
generate_button = gr.Button("Generate")
stop_button = gr.Button("Stop", visible=False)
with gr.Row():
max_tokens = gr.Slider(minimum=1, maximum=2048, value=128, step=1, label="Max Tokens")
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Temperature")
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top P")
repeat_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repeat Penalty")
top_k = gr.Slider(minimum=1, maximum=200, value=40, step=1, label="Top K")
def generate_and_display(prompt, max_tokens, temperature, top_p, repeat_penalty, top_k):
output_generator = generate_completion(llama, prompt, max_tokens, temperature, top_p, repeat_penalty, top_k)
for output in output_generator:
yield gr.update(value=prompt + output, autoscroll=True), gr.update(visible=False), gr.update(visible=True)
yield gr.update(value=prompt + output, autoscroll=True), gr.update(visible=True), gr.update(visible=False)
def stop_generation():
globals().update(stop_generating=True)
return gr.update(visible=True), gr.update(visible=False)
generate_button.click(
generate_and_display,
inputs=[io_textbox, max_tokens, temperature, top_p, repeat_penalty, top_k],
outputs=[io_textbox, generate_button, stop_button],
show_progress=True,
)
stop_button.click(
stop_generation,
outputs=[generate_button, stop_button],
)
# add event to textbox to add new line on enter
io_textbox.submit(
lambda x: x + "\n",
inputs=[io_textbox],
outputs=[io_textbox],
)
demo.launch(server_name="0.0.0.0" if listen else None)
def launch_chat(llama, handler_name, listen=False):
def chat(message, history, system_prompt, max_tokens, temperature, top_p, repeat_penalty, top_k):
user_input = message
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
for message in history:
messages.append({"role": "user", "content": message[0]})
messages.append({"role": "assistant", "content": message[1]})
messages.append({"role": "user", "content": user_input})
if debug_flag:
print(f"Messages: {messages}")
print(
f"System prompt: {system_prompt}, temperature: {temperature}, top_p: {top_p}, top_k: {top_k}, repeat_penalty: {repeat_penalty}, max_tokens: {max_tokens}"
)
handler = llama_chat_format.get_chat_completion_handler(handler_name)
chat_completion_chunks = handler(
llama=llama,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
repeat_penalty=repeat_penalty,
top_k=int(top_k),
stream=True,
)
response = ""
for chunk in chat_completion_chunks:
if debug_flag:
print(chunk)
if "choices" in chunk and len(chunk["choices"]) > 0:
if "delta" in chunk["choices"][0]:
if "content" in chunk["choices"][0]["delta"]:
response += chunk["choices"][0]["delta"]["content"]
yield response
system_prompt = gr.Textbox(label="System Prompt", placeholder="Enter system prompt here...")
max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max Tokens")
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Temperature")
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top P")
repeat_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repeat Penalty")
top_k = gr.Slider(minimum=1, maximum=200, value=40, step=1, label="Top K")
additional_inputs = [system_prompt, max_tokens, temperature, top_p, repeat_penalty, top_k]
chatbot = gr.ChatInterface(chat, additional_inputs=additional_inputs)
chatbot.launch(server_name="0.0.0.0" if listen else None)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default=None, help="Model file path")
parser.add_argument("-ngl", "--n_gpu_layers", type=int, default=0, help="Number of GPU layers")
parser.add_argument("-c", "--n_ctx", type=int, default=2048, help="Context length")
parser.add_argument(
"-ch",
"--chat_handler",
type=str,
default="command-r",
help="Chat handler, e.g. command-r, mistral-instruct, alpaca, llama-3 etc. default: command-r",
)
parser.add_argument("--chat", action="store_true", help="Chat mode")
parser.add_argument("--listen", action="store_true", help="Listen mode")
parser.add_argument(
"-ts", "--tensor_split", type=str, default=None, help="Tensor split, float values separated by comma for each gpu"
)
parser.add_argument("--disable_mmap", action="store_true", help="Disable mmap")
parser.add_argument("--kv_cache_q8", action="store_true", help="Use quantized kv cache (Q8)")
parser.add_argument("--flash_attn", action="store_true", help="Use flash attention")
parser.add_argument("--debug", action="store_true", help="Debug mode")
args = parser.parse_args()
# tokenizer initialization doesn't seem to be needed
# print("Initializing tokenizer")
# tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
print(f"Initializing Llama. Model ID: {args.model}, N_GPU_LAYERS: {args.n_gpu_layers}, N_CTX: {args.n_ctx}")
# llama_tokenizer = LlamaHFTokenizer(tokenizer)
tensor_split = None if args.tensor_split is None else [float(x) for x in args.tensor_split.split(",")]
llama = Llama(
model_path=args.model,
n_gpu_layers=args.n_gpu_layers,
tensor_split=tensor_split,
n_ctx=args.n_ctx,
use_mmap=not args.disable_mmap,
type_k=llama_cpp.GGML_TYPE_Q8_0 if args.kv_cache_q8 else None,
type_v=llama_cpp.GGML_TYPE_Q8_0 if args.kv_cache_q8 else None,
flash_attn=args.flash_attn,
# tokenizer=llama_tokenizer,
# n_threads=n_threads,
)
debug_flag = args.debug
if args.chat:
print(f"Launching chat with handler: {args.chat_handler}")
launch_chat(llama, args.chat_handler, args.listen)
else:
print("Launching completion")
launch_completion(llama, args.listen)
@kohya-ss
Copy link
Author

kohya-ss commented Apr 19, 2024

GGUF 形式のLLMを動かす簡易クライアントです。

最近の更新とか

kv cache の quantize と flash attention に対応しました。最新の llama-cpp-python が必要ですのでインストール済みなら pip install -U llama-cpp-python --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu124 (末尾は CUDA バージョンによって変更)で更新してください。

mmap を無効にできるようにしました。mmap 有効(デフォルト)だとVRAM だけでなくメイン RAM にもモデルをロードすることで再読み込みが速くなるらしいですが、その分メイン RAM を消費します。mmap 無効でメイン RAM の消費量が減ります。

インストール

  1. venv を作成し、activate します
  2. 依存ライブラリをインストールします
    • llama-cpp-python の extra-index の末尾はインストールされている CUDA バージョンに合わせてください。cu121, cu122, cu123, cu124 が公式リポジトリには書かれています。CUDA のマイナーバージョン(末尾一桁)が違ってもだいたい動くことが多いです。
    • 更新時は pip install -U llama-cpp-python --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu122 でバージョンアップできます。
pip install llama-cpp-python --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu122
pip install transformers
pip install gradio
  1. LLM の GGUF ファイルをダウンロードします。コメント後半にいくつか紹介しておきます。
  2. スクリプトを実行します。

スクリプトの使い方

たとえば以下のように指定します。

python gradio_cmdrp.py -m path/to/model-file-00001-of-00002.gguf -ngl 22 -c 2048 --chat

その後、ブラウザで http://localhost:7860/ を開いてください。

-m にダウンロードしたモデルを指定します。複数ファイルある場合は 000001 のファイル名を指定します。

-ngl でいくつのレイヤーを GPU で処理するか指定します。多くすると VRAM を使いますが、それだけ速くなります。また少なくするとメイン RAM の使用量が増え、CPU で処理する量が増えるので遅くなります。

-c でコンテキスト長を指定します。モデルが認識できる入力のトークン数です。長くするとそれだけメモリを使います。

-ngl-c は VRAM とメイン RAM の空きを見ながら決めてください。

--chat でチャットモードで起動します。デフォルトは文章の続きを生成する補完モードです。

--chat_handler でモデルのチャット形式を指定します。command-rmistral-instructllama-3 などが指定できます。省略時は command-r です。llama-cpp-python の chat handler が指定できます。

--listen オプションをつけると LAN 内の他の PC からアクセス可能になります。

--disable_mmap を指定すると mmap を使いません。メイン RAM に余裕がない場合は指定してください。

--kv_cache_q8 で KV cache を quantize することでメモリ使用量を減らします。やや品質が低下するようですが、ほぼ体感できないためメモリに余裕がない場合はお勧めです。

--flash_attn で Flash Attention を有効にします。生成が速くなるらしいです。

--debug でデバッグ出力を表示します。

--help で表示される使用法も参考にしてください。

-ngl について

モデルロード時に llm_load_tensors: offloaded 15/33 layers to GPU のように表示されます。/ のあとの数値がモデルの全レイヤー数ですので、その数値以下の範囲で指定できます。

モデルとダウンロード

各リポジトリには、量子化方法の違いでいくつかのバージョンが用意されています。Q3、Q4 などの数値は量子化ビット数で、K_M や K_S、IQ?_XS などは量子化手法の違いのようです。PPL Value が低いほど性能劣化が抑えられているようです。

000001-of-000002、000001-of-000002 などの連番のファイルをすべて落としてください。

合計ファイルサイズが VRAM サイズより小さいくらいのモデルなら、-ngl を多くするとほぼすべて VRAM に乗せられるのでかなり高速に動きます。

モデル例

学習やマージ、GGUF を提供されている各氏に感謝します。

モデル chat handler
pmysl 氏の Command R+ GGUF command-r
dranger003 氏の Command R+ GGUF command-r
Command-R 35B v1.0 - GGUF command-r
LightChatAssistant-4x7B-GGUF mistral-instruct
Japanese-Starling-ChatV-7B-GGUF mistral-instruct
Meta-Llama-3-70B-Instruct-gguf llama-3
Meta-Llama-3-8B-Instruct-GGUF llama-3

※Command R+ や LLama 3 は momonga 氏が量子化モデルを多数公開されています。iMatrix に日本語データが利用されているため精度が高いかもしれません。

設定例

※ ざっくり試した範囲のため、この値での動作を保証するものではありません。--disable_mmap でメイン RAM の使用量が VRAM の分くらい減らせる感じです。--flash_attn--kv_cache_q8 でコンテキスト長を倍くらいにできると思います。

(VRAM に乗ってるレイヤー数が少ないと GPU の性能差はあまり出ません。Command-R+ Q4_K_M を、VRAM 24GB + DDR4 メモリ 64GB で動かすのと、VRAM 12GB + DDR5 メモリ 96GB で動かすのは、恐らく同じくらいの速度です。)

  • Command-R+ GGUF
    • Q4_K_M, 24GB VRAM/64GB RAM -ngl 22 -c 2048 または -ngl 22 -c 4096 --disable_mmap --flash_attn --kv_cache_q8
    • Q4_K_M, 12GB VRAM/72GB RAM -ngl 8 -c 4096 --disable_mmap --kv_cache_q8 --flash_attn
      • VRAM 8GB なら -ngl 4 で動くかも。
    • IQ4_XS, 24GB VRAM/64GB RAM -ngl 24 -c 2048 または -ngl 24 -c 4096 --disable_mmap --flash_attn --kv_cache_q8
    • IQ4_XS, 12GB VRAM/64GB RAM -ngl 8 -c 4096 --disable_mmap --kv_cache_q8 --flash_attn
      • VRAM 8GB なら -ngl 4 で動くかも。
  • Command-R V01
    • Q8, 24GB VRAM/64GB RAM -ngl 20 -c 2048
    • Q4_K_M, 24GB VRAM/40GB RAM -ngl 38 -c 2048
    • Q4_K_M, 24GB VRAM/48GB RAM -ngl 22 -c 16384
    • Q4_K_M, 12GB VRAM/48GB RAM -ngl 16 -c 2048
  • LightChatAssistant-4x7B : --chat_handler mistral-instruct を追加
    • Q8, 24GB VRAM/40GB RAM -ngl 28 -c 2048
    • Q4, 24GB VRAM/16GB RAM -ngl 33 -c 16384
    • Q4, 16GB VRAM/16GB RAM -ngl 33 -c 2048
    • Q4, 12GB VRAM/24GB RAM -ngl 16 -c 16384
  • Japanese-Starling-ChatV-7B-GGUF : --chat_handler mistral-instruct を追加
    • f16, 24GB VRAM/32GB RAM -ngl 33 -c 16384
    • f16, 16GB VRAM/32GB RAM -ngl 33 -c 2048
    • Q8, 8GB VRAM/24GB RAM -ngl 33 -c 16384
  • Meta-Llama-3-70B-Instruct-gguf: --chat_handler llama-3 を追加
    • Q4_K_M/IQ4_NL, 24GB VRAM/64GB RAM -ngl 44 -c 2048
  • Meta-Llama-3-8B-Instruct-GGUF: --chat_handler llama-3 を追加
    • Q8, 10GB VRAM/24GB RAM -ngl 33 -c 2048
    • Q8, 8GB VRAM/24GB RAM -ngl 28 -c 2048

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment