Skip to content

Instantly share code, notes, and snippets.

@kohya-ss
Created April 18, 2024 13:09
Show Gist options
  • Save kohya-ss/37f4c5ef8171cbb2b6cc1f4fd7999b89 to your computer and use it in GitHub Desktop.
Save kohya-ss/37f4c5ef8171cbb2b6cc1f4fd7999b89 to your computer and use it in GitHub Desktop.
llama-cpp-python と gradio で command-r-plus を動かす
# Apache License 2.0
# 使用法は gist のコメントを見てください
import argparse
from typing import List, Optional, Union, Iterator
from llama_cpp import Llama
from llama_cpp.llama_tokenizer import LlamaHFTokenizer
from llama_cpp.llama_chat_format import _convert_completion_to_chat, register_chat_completion_handler
import llama_cpp.llama_types as llama_types
from llama_cpp.llama import LogitsProcessorList, LlamaGrammar
from transformers import AutoTokenizer
import gradio as gr
from llama_cpp import Llama
import gradio as gr
from transformers import AutoTokenizer
MODEL_ID = "CohereForAI/c4ai-command-r-plus"
MAX_TOKENS_IN_CHAT_MODE = 1024
@register_chat_completion_handler("command-r")
def command_r_chat_handler(
llama: Llama,
messages: List[llama_types.ChatCompletionRequestMessage],
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
min_p: float = 0.05,
typical_p: float = 1.0,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
max_tokens: Optional[int] = None,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
repeat_penalty: float = 1.1,
tfs_z: float = 1.0,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
model: Optional[str] = None,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
**kwargs, # type: ignore
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
# bos_token = "<BOS_TOKEN>" # not sure if this is needed
start_turn_token = "<|START_OF_TURN_TOKEN|>"
end_turn_token = "<|END_OF_TURN_TOKEN|>"
user_token = "<|USER_TOKEN|>"
chatbot_token = "<|CHATBOT_TOKEN|>"
# prompt = bos_token + start_turn_token
prompt = start_turn_token
for message in messages:
if message["role"] == "user":
prompt += user_token + message["content"] + end_turn_token + start_turn_token
elif message["role"] == "assistant":
prompt += chatbot_token + message["content"] + end_turn_token + start_turn_token
prompt += chatbot_token
stop_tokens = [end_turn_token] # , bos_token]
return _convert_completion_to_chat(
llama.create_completion(
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
typical_p=typical_p,
stream=stream,
stop=stop_tokens,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repeat_penalty=repeat_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
logits_processor=logits_processor,
grammar=grammar,
),
stream=stream,
)
def generate_completion(llama, prompt, max_tokens, temperature, top_p, repeat_penalty, top_k):
global stop_generating
stop_generating = False
output = prompt
for chunk in llama(
prompt,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
repeat_penalty=repeat_penalty,
top_k=top_k,
stream=True,
):
# print(chunk) # uncomment to show each chunk
if stop_generating:
break
if "choices" in chunk and len(chunk["choices"]) > 0:
if "text" in chunk["choices"][0]:
text = chunk["choices"][0]["text"]
# check EOS_TOKEN
if text.endswith("<EOS_TOKEN>"): # llama.tokenizer.EOS_TOKEN):
output += text[: -len("<EOS_TOKEN>")]
yield output[len(prompt) :]
break
output += text
yield output[len(prompt) :] # remove prompt
def launch_completion(llama, listen=False):
# css = """
# .prompt textarea {font-size:1.0em !important}
# """
# with gr.Blocks(css=css) as demo:
with gr.Blocks() as demo:
with gr.Row():
# change font size
io_textbox = gr.Textbox(
label="Input/Output",
placeholder="Enter your prompt here...",
interactive=True,
elem_classes=["prompt"],
)
with gr.Row():
generate_button = gr.Button("Generate")
stop_button = gr.Button("Stop", visible=False)
with gr.Row():
max_tokens = gr.Slider(minimum=1, maximum=2048, value=128, step=1, label="Max Tokens")
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Temperature")
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top P")
repeat_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repeat Penalty")
top_k = gr.Slider(minimum=1, maximum=200, value=40, step=1, label="Top K")
def generate_and_display(prompt, max_tokens, temperature, top_p, repeat_penalty, top_k):
output_generator = generate_completion(llama, prompt, max_tokens, temperature, top_p, repeat_penalty, top_k)
for output in output_generator:
yield gr.update(value=prompt + output, autoscroll=True), gr.update(visible=False), gr.update(visible=True)
yield gr.update(value=prompt + output, autoscroll=True), gr.update(visible=True), gr.update(visible=False)
def stop_generation():
globals().update(stop_generating=True)
return gr.update(visible=True), gr.update(visible=False)
generate_button.click(
generate_and_display,
inputs=[io_textbox, max_tokens, temperature, top_p, repeat_penalty, top_k],
outputs=[io_textbox, generate_button, stop_button],
show_progress=True,
)
stop_button.click(
stop_generation,
outputs=[generate_button, stop_button],
)
# add event to textbox to add new line on enter
io_textbox.submit(
lambda x: x + "\n",
inputs=[io_textbox],
outputs=[io_textbox],
)
demo.launch(server_name="0.0.0.0" if listen else None)
def launch_chat(llama, listen=False):
# GUI for model parameters is not implemented yet in chat mode
model_kwargs = {
# "temperature": 0.3,
# "top_p": 0.95,
# "top_k": 40,
# "min_p": 0.05,
# "typical_p": 1.0,
# "stream": False,
# "stop": [],
"max_tokens": MAX_TOKENS_IN_CHAT_MODE # max tokens for generation
}
def chat(message, history):
user_input = message
messages = []
for message in history:
messages.append({"role": "user", "content": message[0]})
messages.append({"role": "assistant", "content": message[1]})
messages.append({"role": "user", "content": user_input})
# print("debug: messages", messages)
chat_completion_chunks = command_r_chat_handler(llama=llama, messages=messages, stream=True, **model_kwargs)
response = ""
for chunk in chat_completion_chunks:
# print(chunk) # uncomment to show each chunk
if "choices" in chunk and len(chunk["choices"]) > 0:
if "delta" in chunk["choices"][0]:
if "content" in chunk["choices"][0]["delta"]:
response += chunk["choices"][0]["delta"]["content"]
yield response
chatbot = gr.ChatInterface(chat)
chatbot.launch(server_name="0.0.0.0" if listen else None)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default=None, help="Model file path")
parser.add_argument("-ngl", "--n_gpu_layers", type=int, default=0, help="Number of GPU layers")
parser.add_argument("-c", "--n_ctx", type=int, default=2048, help="Context length")
parser.add_argument("--chat", action="store_true", help="Chat mode")
parser.add_argument("--listen", action="store_true", help="Listen mode")
args = parser.parse_args()
print("Initializing tokenizer")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
print(f"Initializing Llama. Model ID: {args.model}, N_GPU_LAYERS: {args.n_gpu_layers}, N_CTX: {args.n_ctx}")
llama_tokenizer = LlamaHFTokenizer(tokenizer)
llama = Llama(
model_path=args.model,
n_gpu_layers=args.n_gpu_layers,
# tensor_split=tensor_split,
n_ctx=args.n_ctx,
tokenizer=llama_tokenizer,
# n_threads=n_threads,
)
print("Launching Gradio")
if args.chat:
launch_chat(llama, args.listen)
else:
launch_completion(llama, args.listen)
@kohya-ss
Copy link
Author

kohya-ss commented Apr 18, 2024

使い方

VRAM 24GB とメイン RAM 64GB で Command-R+ Q4 を動かす例です。

Command-R+ の Q3 と Q4 の間には少なくない性能差があるようなので Q4 を動かします。

  1. venv を作成し、activate します
  2. 依存ライブラリをインストールします
    • llama-cpp-python の extra-index-url の末尾はインストールされている CUDA バージョンに合わせてください。cu121, cu122, cu123 が公式リポジトリには書かれています(ただ、このコメント執筆時点では cu123 は存在しませんでした)。CUDA のマイナーバージョン(末尾一桁)が違ってもだいたい動くことが多いです。
    • ※最新の CUDA 版の wheel が存在せず、CPU 版がインストールされてしまうことがあるようです。https://abetlen.github.io/llama-cpp-python/whl/cu122 にアクセスして、存在するバージョンを llama-cpp-python==0.2.62 のように指定してください。
pip install llama-cpp-python==0.2.62 --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu122
pip install transformers
pip install gradio
  1. Command R+ の GGUF ファイルをダウンロードします。
  2. スクリプトを実行します。
python gradio_cmdrp.py --model path/to/model-file-00001-of-00002.gguf -ngl 22 -c 2048 --chat

その後、ブラウザで http://localhost:7860/ を開いてください。

-ngl 22 -c 2048 で、Q4_K_M がギリギリ動くと思います。IQ4_XS なら -ngl 24 -c 2048 でメイン RAM の使用量を減らすと良さそうです。。
-c 2048 はコンテキスト長 2048 で、2048 トークンまで入力を受け付けます。このくらいあればまあまあの分量、応答できると思います。

オプション

-ngl でいくつのレイヤーを GPU で処理するか指定します。多くすると VRAM を使いますが、それだけ速くなります。また少なくするとメイン RAM の使用量が増え、CPU で処理する量が増えるので遅くなります。
-c でコンテキスト長を指定します。長くするとそれだけメモリを使います。
--chat でチャットモードで起動します。デフォルトは文章の続きを生成する補完モードです。
--listen オプションをつけると LAN 内の他の PC からアクセス可能になります。

その他

Q3、Q4 などの数値は量子化ビット数で、K_M や K_S、IQ?_XS などは量子化手法の違いのようです。PPL Value が低いほど性能劣化が抑えられているようです。

RAM/VRAM とも余裕がなければ量子化ビット数の少ないモデル(サイズの小さいモデル)選んでください。

@kohya-ss
Copy link
Author

Command-R v01に対応したバージョンも作りましたので、そちらもご利用ください。
https://gist.github.com/kohya-ss/e23fa9a321dba07fabd1ef61eab6863c

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment