Run command-r-plus with llama-cpp-python and gradio
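For reference, a minimal way to run the script (the filename command_r_plus_gradio.py, the GGUF path, and the -ngl value are hypothetical example values; the flags correspond to the argparse options defined in the script, and the Hugging Face tokenizer for CohereForAI/c4ai-command-r-plus is downloaded at startup):

    # completion UI (default)
    python command_r_plus_gradio.py -m ./command-r-plus.gguf -ngl 40 -c 4096

    # chat UI, listening on all interfaces
    python command_r_plus_gradio.py -m ./command-r-plus.gguf -ngl 40 -c 4096 --chat --listen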
# Apache License 2.0
# See the comments on this gist for usage.
import argparse
from typing import List, Optional, Union, Iterator

from llama_cpp import Llama
from llama_cpp.llama_tokenizer import LlamaHFTokenizer
from llama_cpp.llama_chat_format import _convert_completion_to_chat, register_chat_completion_handler
import llama_cpp.llama_types as llama_types
from llama_cpp.llama import LogitsProcessorList, LlamaGrammar
from transformers import AutoTokenizer
import gradio as gr

MODEL_ID = "CohereForAI/c4ai-command-r-plus"
MAX_TOKENS_IN_CHAT_MODE = 1024


@register_chat_completion_handler("command-r")
def command_r_chat_handler(
    llama: Llama,
    messages: List[llama_types.ChatCompletionRequestMessage],
    functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
    function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
    tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    stream: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    **kwargs,  # type: ignore
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
    # Build a Command R style prompt from the chat history and convert the
    # resulting completion into an OpenAI-style chat completion (or chunk stream).
    # bos_token = "<BOS_TOKEN>"  # not sure if this is needed
    start_turn_token = "<|START_OF_TURN_TOKEN|>"
    end_turn_token = "<|END_OF_TURN_TOKEN|>"
    user_token = "<|USER_TOKEN|>"
    chatbot_token = "<|CHATBOT_TOKEN|>"

    # prompt = bos_token + start_turn_token
    prompt = start_turn_token
    for message in messages:
        if message["role"] == "user":
            prompt += user_token + message["content"] + end_turn_token + start_turn_token
        elif message["role"] == "assistant":
            prompt += chatbot_token + message["content"] + end_turn_token + start_turn_token
    prompt += chatbot_token

    stop_tokens = [end_turn_token]  # , bos_token]

    return _convert_completion_to_chat(
        llama.create_completion(
            prompt=prompt,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            typical_p=typical_p,
            stream=stream,
            stop=stop_tokens,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            repeat_penalty=repeat_penalty,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            model=model,
            logits_processor=logits_processor,
            grammar=grammar,
        ),
        stream=stream,
    )
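
# For illustration, given a (hypothetical) message history of
#   [user: "Hello", assistant: "Hi!", user: "How are you?"]
# the handler above concatenates the prompt string:
#   <|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi!<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>How are you?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
# and generation stops at the next <|END_OF_TURN_TOKEN|>.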


def generate_completion(llama, prompt, max_tokens, temperature, top_p, repeat_penalty, top_k):
    # Stream a raw completion and yield the accumulated generated text
    # (with the prompt removed) after each chunk.
    global stop_generating
    stop_generating = False

    output = prompt
    for chunk in llama(
        prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        repeat_penalty=repeat_penalty,
        top_k=top_k,
        stream=True,
    ):
        # print(chunk)  # uncomment to show each chunk
        if stop_generating:
            break
        if "choices" in chunk and len(chunk["choices"]) > 0:
            if "text" in chunk["choices"][0]:
                text = chunk["choices"][0]["text"]
                # check EOS_TOKEN
                if text.endswith("<EOS_TOKEN>"):  # llama.tokenizer.EOS_TOKEN):
                    output += text[: -len("<EOS_TOKEN>")]
                    yield output[len(prompt) :]
                    break
                output += text
                yield output[len(prompt) :]  # remove prompt


def launch_completion(llama, listen=False):
    # Gradio UI for plain completion: the textbox holds both the prompt and the streamed output.
    # css = """
    # .prompt textarea {font-size:1.0em !important}
    # """
    # with gr.Blocks(css=css) as demo:
    with gr.Blocks() as demo:
        with gr.Row():
            # change font size
            io_textbox = gr.Textbox(
                label="Input/Output",
                placeholder="Enter your prompt here...",
                interactive=True,
                elem_classes=["prompt"],
            )

        with gr.Row():
            generate_button = gr.Button("Generate")
            stop_button = gr.Button("Stop", visible=False)

        with gr.Row():
            max_tokens = gr.Slider(minimum=1, maximum=2048, value=128, step=1, label="Max Tokens")
            temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Temperature")
            top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top P")
            repeat_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repeat Penalty")
            top_k = gr.Slider(minimum=1, maximum=200, value=40, step=1, label="Top K")

        def generate_and_display(prompt, max_tokens, temperature, top_p, repeat_penalty, top_k):
            output_generator = generate_completion(llama, prompt, max_tokens, temperature, top_p, repeat_penalty, top_k)
            for output in output_generator:
                yield gr.update(value=prompt + output, autoscroll=True), gr.update(visible=False), gr.update(visible=True)
            yield gr.update(value=prompt + output, autoscroll=True), gr.update(visible=True), gr.update(visible=False)

        def stop_generation():
            globals().update(stop_generating=True)
            return gr.update(visible=True), gr.update(visible=False)

        generate_button.click(
            generate_and_display,
            inputs=[io_textbox, max_tokens, temperature, top_p, repeat_penalty, top_k],
            outputs=[io_textbox, generate_button, stop_button],
            show_progress=True,
        )

        stop_button.click(
            stop_generation,
            outputs=[generate_button, stop_button],
        )

        # add event to textbox to add new line on enter
        io_textbox.submit(
            lambda x: x + "\n",
            inputs=[io_textbox],
            outputs=[io_textbox],
        )

    demo.launch(server_name="0.0.0.0" if listen else None)


def launch_chat(llama, listen=False):
    # GUI for model parameters is not implemented yet in chat mode
    model_kwargs = {
        # "temperature": 0.3,
        # "top_p": 0.95,
        # "top_k": 40,
        # "min_p": 0.05,
        # "typical_p": 1.0,
        # "stream": False,
        # "stop": [],
        "max_tokens": MAX_TOKENS_IN_CHAT_MODE,  # max tokens for generation
    }

    def chat(message, history):
        user_input = message
        messages = []
        for message in history:
            messages.append({"role": "user", "content": message[0]})
            messages.append({"role": "assistant", "content": message[1]})
        messages.append({"role": "user", "content": user_input})
        # print("debug: messages", messages)

        chat_completion_chunks = command_r_chat_handler(llama=llama, messages=messages, stream=True, **model_kwargs)

        response = ""
        for chunk in chat_completion_chunks:
            # print(chunk)  # uncomment to show each chunk
            if "choices" in chunk and len(chunk["choices"]) > 0:
                if "delta" in chunk["choices"][0]:
                    if "content" in chunk["choices"][0]["delta"]:
                        response += chunk["choices"][0]["delta"]["content"]
                        yield response

    chatbot = gr.ChatInterface(chat)
    chatbot.launch(server_name="0.0.0.0" if listen else None)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default=None, help="Model file path")
    parser.add_argument("-ngl", "--n_gpu_layers", type=int, default=0, help="Number of GPU layers")
    parser.add_argument("-c", "--n_ctx", type=int, default=2048, help="Context length")
    parser.add_argument("--chat", action="store_true", help="Chat mode")
    parser.add_argument("--listen", action="store_true", help="Listen mode")
    args = parser.parse_args()

    print("Initializing tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    print(f"Initializing Llama. Model ID: {args.model}, N_GPU_LAYERS: {args.n_gpu_layers}, N_CTX: {args.n_ctx}")
    llama_tokenizer = LlamaHFTokenizer(tokenizer)
    llama = Llama(
        model_path=args.model,
        n_gpu_layers=args.n_gpu_layers,
        # tensor_split=tensor_split,
        n_ctx=args.n_ctx,
        tokenizer=llama_tokenizer,
        # n_threads=n_threads,
    )

    print("Launching Gradio")
    if args.chat:
        launch_chat(llama, args.listen)
    else:
        launch_completion(llama, args.listen)
I have also made a version that supports Command-R v01, so please feel free to use that as well:
https://gist.github.com/kohya-ss/e23fa9a321dba07fabd1ef61eab6863c