Skip to content

Instantly share code, notes, and snippets.

@ochafik
Last active March 25, 2024 19:11
Show Gist options
  • Select an option

  • Save ochafik/a3d4a5b9e52390544b205f37fb5a0df3 to your computer and use it in GitHub Desktop.

Select an option

Save ochafik/a3d4a5b9e52390544b205f37fb5a0df3 to your computer and use it in GitHub Desktop.
llama.cpp OpenAI-compatible server w/ tools support for Functionary, Nous-Hermes-2-Pro and any model really
from typing import Annotated
from fastapi import Body
from pydantic import BaseModel
class Item(BaseModel):
    """Example payload model for the FastAPI ``Body(examples=...)`` demo."""
    name: str
    # Optional free-text description of the item.
    description: str | None = None
    price: float
    # Optional tax amount; omitted when not applicable.
    tax: float | None = None
# @app.put("/items/{item_id}")
async def update_item(
    item_id: int,
    item: Annotated[
        Item,
        Body(
            # The example shown in the generated OpenAPI docs.
            examples=[
                {
                    "name": "Foo",
                    "description": "A very nice Item",
                    "price": 35.4,
                    "tax": 3.2,
                }
            ],
        ),
    ],
):
    """Echo back the path item id together with the parsed ``Item`` payload."""
    return {"item_id": item_id, "item": item}
import secrets
import string
import sys
from typing import Annotated, Type
import importlib.util
from fastapi import Body, FastAPI
from pydantic import BaseModel
def load_source_as_module(source):
    """Import the Python file at ``source`` under a random, unused module name.

    Adapted from https://medium.com/@david.bonn.2010/dynamic-loading-of-python-code-2617c04e5f3f

    Returns the executed module object (also registered in ``sys.modules``).
    """
    alphabet = string.ascii_uppercase + string.ascii_lowercase + string.digits

    def random_module_name():
        return 'mod_' + ''.join(secrets.choice(alphabet) for _ in range(32))

    # Draw fresh names until one is unused (a collision is astronomically rare).
    module_name = random_module_name()
    while module_name in sys.modules:
        module_name = random_module_name()

    spec = importlib.util.spec_from_file_location(module_name, source)
    module = importlib.util.module_from_spec(spec)
    # Register before exec so the module can self-reference during execution.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module
def bind_functions(app, module):
    """Register every public, annotated function of ``module`` as a POST route on ``app``.

    Skipped: names with a leading underscore, capitalized names (classes by
    convention, e.g. pydantic models), non-callables, classes, and callables
    without ``__annotations__`` (FastAPI needs annotations to build the route).
    """
    for name in dir(module):
        if name.startswith('_'):
            continue
        if name == name.capitalize():
            # Capitalized names are types/models by convention, not endpoints.
            continue
        value = getattr(module, name)
        # Fix: the original checked `isinstance(v, Type)` with `typing.Type`,
        # which raises TypeError under isinstance; `type` is the correct check.
        if not callable(value) or isinstance(value, type):
            continue
        if not hasattr(value, '__annotations__'):
            continue
        print(f'INFO: Binding /{name}')
        # Fix: Starlette route paths must start with '/'; the original passed
        # the bare name even though the log line above advertises '/{name}'.
        app.post(f'/{name}')(value)
if __name__ == '__main__':
    app = FastAPI()
    # Each CLI argument is either a path to a .py file (loaded dynamically)
    # or an importable module name; every module's functions become routes.
    for arg in sys.argv[1:]:
        if arg.endswith('.py'):
            mod = load_source_as_module(arg)
        else:
            mod = importlib.import_module(arg)
        bind_functions(app, mod)
    # run app
#!/bin/bash
#
# Build a Docker image wrapping fastify.py plus its requirements, then run it
# with the tool sources and data directories bind-mounted.
#
# Usage: SRC_DIR=/path/to/tools DATA_DIR=/path/to/data ./this.sh script1.py module2 ...
set -euo pipefail

# Fail early with a clear message instead of tripping `set -u` mid-script.
: "${SRC_DIR:?SRC_DIR must point at the directory containing the tool scripts}"
: "${DATA_DIR:?DATA_DIR must point at the working data directory}"

scripts=( "$@" )

# TODO: COPY each *.py script into the image instead of bind-mounting /src:
# copy_statements=(
#     "COPY requirements.txt"
#     "COPY fastify.py"
# )
# for script in "${scripts[@]}"; do
#     if [[ $script =~ .*\.py$ ]]; then   # (original used `~=`, an invalid operator)
#         copy_statements+=( "COPY $script" )
#     fi
# done
# $( printf "%s\n" "${copy_statements[@]}" )

# Dockerfile is fed on stdin (-f -) with the CWD as build context so the COPY
# statements can resolve. Fixes vs. the original: `docker build -b` is not a
# flag; `MAIN`/`CWD` are not Dockerfile instructions (CMD/WORKDIR are); and
# `$scripts` only expands the first array element (use ${scripts[*]}).
echo "
FROM python:3.10-slim
COPY requirements.txt /root
COPY fastify.py /root
RUN pip install -r /root/requirements.txt
# CMD uvicorn fastify:app --reload
WORKDIR /data
CMD PYTHONPATH=/src python /root/fastify.py ${scripts[*]}
" | docker build -t llama.cpp/tools-base -f - .

# Mount flags must come before the image name (anything after it becomes
# container args), and each source="..." quote must be closed.
docker run -it \
    --mount type=bind,source="$SRC_DIR",target=/src,readonly \
    --mount type=bind,source="$DATA_DIR",target=/data \
    llama.cpp/tools-base
# pip install "fastapi[all]" "uvicorn[all]" sse-starlette jsonargparse jinja2 pydantic
from enum import StrEnum
import json, sys, subprocess, atexit
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "gguf-py"))
sys.path.insert(0, str(Path(__file__).parent.parent))
from json_schema_to_grammar import SchemaConverter
from gguf.gguf_reader import GGUFReader
from gguf.constants import Keys
# from functools import lru_cache
from typing import Annotated, Any, Callable, List, Optional, Set, Tuple, Union
import httpx
from fastapi import Depends, FastAPI, Request, Response
from starlette.responses import StreamingResponse
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Json
# from pydantic_settings import BaseSettings
from jsonargparse import CLI
import jinja2
def raise_exception(msg: str):
    """Raise ``Exception(msg)``; passed into Jinja templates so they can call it."""
    raise Exception(msg)
class ToolStyle(StrEnum):
# https://cookbook.openai.com/examples/how_to_call_functions_with_chat_models
DEFAULT="Default",
# https://github.com/MeetKai/functionary
# TODO: look at https://github.com/ggerganov/llama.cpp/pull/5695
# https://github.com/MeetKai/functionary/blob/main/functionary/prompt_template/prompt_template_v2.py
FUNCTIONARY_V2="Functionary V2",
# https://github.com/NousResearch/Hermes-Function-Calling
NOUS_RESEARCH_HERMES="Nous-Research-Hermes-Function-Calling",
def _add_system_prompt(messages: list['Message'], system_prompt: str):
# TODO: add to last system message, or create a new one just before the last user message
system_message = next(((i, m) for i, m in enumerate(messages) if m.role == "system"), None)
if system_message is not None:
(i, m) = system_message
messages[i].content = m.content + '\n' + system_prompt
else:
messages.insert(0, Message(role="system", content=system_prompt))
return messages
class ChatHandler:  # (BaseModel)
    """Wraps a model's Jinja chat template plus its special tokens, and sniffs
    which tool-calling dialect the template speaks."""

    def __init__(self, template: str, eos_token: str, bos_token: str):
        env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)
        self.template = env.from_string(template)
        self.eos_token = eos_token
        self.bos_token = bos_token
        # Templates that hard-fail on non-alternating user/assistant turns
        # contain this exact raise_exception guard.
        self.strict_user_assistant_alternation = "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception" in template
        # Functionary v2 templates route tool calls through '<|recipient|>'.
        if "<|recipient|>' + tool_call['function']['name']" in template:
            self.tool_style = ToolStyle.FUNCTIONARY_V2
        else:
            self.tool_style = ToolStyle.DEFAULT

    def __str__(self):
        return f"ChatHandler(template={self.template}, eos_token={self.eos_token}, bos_token={self.bos_token})"

    @staticmethod
    def from_gguf(model: Path):
        """Build a ChatHandler from the GGUF file's embedded chat template.

        NOTE(review): BOS_ID/EOS_ID fields look like token *ids* rather than
        token strings (the commented-out variant later in this file decodes
        them via the token list) — confirm what ``.read()`` yields here.
        """
        reader = GGUFReader(model.as_posix())
        return ChatHandler(
            template=reader.fields[Keys.Tokenizer.CHAT_TEMPLATE].read(),
            bos_token=reader.fields[Keys.Tokenizer.BOS_ID].read(),
            eos_token=reader.fields[Keys.Tokenizer.EOS_ID].read())

    def render(self, messages: list[dict], add_generation_prompt: bool, omit_bos: bool = False):
        """Render ``messages`` through the chat template, optionally dropping the BOS token."""
        return self.template.render(
            messages=messages,
            eos_token=self.eos_token,
            bos_token='' if omit_bos else self.bos_token,
            raise_exception=raise_exception,
            add_generation_prompt=add_generation_prompt,
        )
class Message(BaseModel):
    """One chat turn: a role (e.g. "system", "user", "assistant") and its text."""
    role: str
    content: str
class ToolFunction(BaseModel):
    """A callable tool's spec: name, human-readable description, JSON-schema parameters."""
    name: str
    description: str
    # JSON schema describing the function's arguments.
    parameters: Any
class Tool(BaseModel):
    """An OpenAI-style tool entry: a type tag (e.g. "function") plus the function spec."""
    type: str
    function: ToolFunction
class ResponseFormat(BaseModel):
    """Requested response format; only type "json_object" is handled downstream."""
    type: str
    # Optional JSON schema to constrain the response with.
    json_schema: Optional[Any] = None
class ChatCompletionRequest(BaseModel):
    """Subset of the OpenAI chat-completions request body handled by this proxy."""
    model: str
    tools: Optional[list[Tool]] = None
    messages: list[Message]
    response_format: Optional[ResponseFormat] = None
    # Sampling temperature (OpenAI default of 1.0).
    temperature: float = 1.0
    # When true, the upstream response is streamed back as SSE chunks.
    stream: bool = False
class SchemaToTypeScriptConverter:
    """Converts a JSON schema into a TypeScript type expression
    (used to declare tools Functionary-style as ``namespace functions``).

    Fixes vs. the original:
    - both ``_build_object_rule`` call sites omitted the required ``name``
      parameter and would raise TypeError at runtime; ``name`` now defaults
      to ``''`` and the positional call site passes by keyword.
    - ``self._refs`` was read (for ``$ref`` resolution inside ``allOf``)
      but never initialized; it now starts as an empty dict.
    """

    def __init__(self):
        # $ref id -> resolved schema; must be populated before visiting
        # schemas that contain '$ref' entries.
        self._refs = {}

    def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str = '', additional_properties: Union[bool, Any] = None):
        # NOTE: `name` and `additional_properties` are currently unused here;
        # optional properties are marked with the TS '?' suffix.
        return "{" + ', '.join(
            f'{prop_name}{"" if prop_name in required else "?"}: {self.visit(prop_schema)}'
            for prop_name, prop_schema in properties
        ) + "}"

    def visit(self, schema: dict):
        """Return the TypeScript type expression for ``schema``."""
        def print_constant(v):
            return json.dumps(v)

        schema_type = schema.get('type')
        schema_format = schema.get('format')

        if 'oneOf' in schema or 'anyOf' in schema:
            return '|'.join(self.visit(s) for s in schema.get('oneOf') or schema.get('anyOf'))
        elif isinstance(schema_type, list):
            # e.g. {"type": ["string", "null"]} -> string|null
            return '|'.join(self.visit({'type': t}) for t in schema_type)
        elif 'const' in schema:
            return print_constant(schema['const'])
        elif 'enum' in schema:
            return '|'.join((print_constant(v) for v in schema['enum']))
        elif schema_type in (None, 'object') and \
                ('properties' in schema or
                 ('additionalProperties' in schema and schema['additionalProperties'] is not True)):
            required = set(schema.get('required', []))
            properties = list(schema.get('properties', {}).items())
            return self._build_object_rule(properties, required, additional_properties=schema.get('additionalProperties'))
        elif schema_type in (None, 'object') and 'allOf' in schema:
            # Merge all component objects (resolving $refs) into one literal.
            required = set()
            properties = []

            def add_component(comp_schema, is_required):
                if (ref := comp_schema.get('$ref')) is not None:
                    comp_schema = self._refs[ref]
                if 'properties' in comp_schema:
                    for prop_name, prop_schema in comp_schema['properties'].items():
                        properties.append((prop_name, prop_schema))
                        if is_required:
                            required.add(prop_name)

            for t in schema['allOf']:
                if 'anyOf' in t:
                    for tt in t['anyOf']:
                        add_component(tt, is_required=False)
                else:
                    add_component(t, is_required=True)
            return self._build_object_rule(properties, required, additional_properties=[])
        elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
            items = schema.get('items') or schema['prefixItems']
            if isinstance(items, list):
                # Tuple schema -> TS tuple array, e.g. [string, number][]
                return '[' + ', '.join(self.visit(item) for item in items) + '][]'
            else:
                return self.visit(items) + '[]'
        elif schema_type in (None, 'string') and schema_format == 'date-time':
            return 'Date'
        elif (schema_type == 'object') or (len(schema) == 0):
            return 'any'
        else:
            # Scalar types map directly, except JSON's integer -> TS number.
            return 'number' if schema_type == 'integer' else schema_type
def main(
    model: Path = Path("/Users/ochafik/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf"),
    host: str = "localhost",
    port: int = 8080,
    main_server_endpoint: Optional[str] = None,
    main_server_host: str = "localhost",
    main_server_port: Optional[int] = 8081,
):
    """OpenAI-compatible /v1/chat/completions proxy with tool-call support.

    Loads the chat template from the GGUF `model`, optionally spawns the
    llama.cpp `./server` binary as the backing completion server, then serves
    a FastAPI app on `host:port` that translates chat requests (messages +
    tools + response_format) into a grammar-constrained /completions call
    against `main_server_endpoint`.
    """
    import uvicorn
    chat_handler = ChatHandler.from_gguf(model)
    print(chat_handler)
    if not main_server_endpoint:
        # No upstream endpoint given: launch llama.cpp's ./server locally
        # and make sure it is killed when this process exits.
        server_process = subprocess.Popen([
            "./server", "-m", model,
            "--host", main_server_host, "--port", f'{main_server_port}',
        ])
        atexit.register(server_process.kill)
        main_server_endpoint = f"http://{main_server_host}:{main_server_port}"
    app = FastAPI()

    @app.post("/v1/chat/completions")
    async def chat_completions(request: Request, chat_request: ChatCompletionRequest):
        # Forward the caller's auth header to the backing server.
        # NOTE(review): when no Authorization header is present this puts
        # None as a header value, which httpx may reject — verify.
        headers = {
            "Content-Type": "application/json",
            "Authorization": request.headers.get("Authorization"),
        }
        if chat_request.response_format is not None:
            assert chat_request.response_format.type == "json_object", f"Unsupported response format: {chat_request.response_format.type}"
            response_schema = chat_request.response_format.json_schema or {}
        else:
            response_schema = None
        messages = chat_request.messages
        parser=None
        grammar=None
        converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
        response_rule = converter.visit(response_schema, "response") if response_schema else None
        # Render the template twice — empty vs. with a planted assistant turn —
        # to recover the exact prefix/suffix the template wraps around an
        # assistant message; the grammar below is anchored on those.
        delimiter = '<%$[SAMPLE]$%>'
        empty_prompt = chat_handler.render([], add_generation_prompt=True)
        planted_prompt = chat_handler.render([{"role": "assistant", "content": delimiter}], add_generation_prompt=False)
        assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
        [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
        if chat_request.tools:
            if chat_handler.tool_style in (ToolStyle.DEFAULT, ToolStyle.NOUS_RESEARCH_HERMES):
                # Hermes style: advertise the tools in a <tools> system prompt
                # and constrain output to free text OR a <tool_call> JSON blob.
                messages = _add_system_prompt(messages, '\n'.join([
                    'Here are the tools available:',
                    '<tools>',
                    *(tool.model_dump_json() for tool in chat_request.tools),
                    '</tools>',
                ]))
                # One grammar rule per tool: {"name": <const>, "arguments": <params schema>}.
                tool_rules = [
                    converter.visit(
                        dict(
                            type="object",
                            properties=dict(
                                name=dict(const=tool.function.name),
                                arguments=tool.function.parameters,
                            ),
                            required=['name', 'arguments']
                        ),
                        f'{tool.function.name}-tool-call'
                    )
                    for tool in chat_request.tools
                ]
                # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
                # OR a tool-call message respecting the schema of any of the tools
                converter._add_rule(
                    "root",
                    converter._format_literal(prefix) + " (" +
                    (response_rule or converter.not_literal("<tool_call>")) + " | " +
                    converter._format_literal("<tool_call>") + " (" +
                    ' | '.join(tool_rules) +
                    ") " + converter._format_literal("</tool_call>") +
                    ") " + converter._format_literal(suffix))
                grammar = converter.format_grammar()
                def parse(s: str):
                    # NOTE(review): `'<tool_call>'.startswith(s)` is only True
                    # while `s` is a *prefix* of the literal '<tool_call>', so
                    # the inner startswith/endswith branch looks unreachable
                    # for complete outputs — probably meant to be
                    # `s.startswith('<tool_call>')`; verify.
                    if '<tool_call>'.startswith(s):
                        if s.startswith('<tool_call>') and s.endswith('</tool_call>'):
                            s = s[len('<tool_call>'):-len('</tool_call>')]
                            return {"role": "assistant", "tool_calls": [json.loads(s)]}
                        return None
                    else:
                        return {"role": "assistant", "content": s}
                parser = parse
            elif chat_handler.tool_style == ToolStyle.FUNCTIONARY_V2:
                # Functionary v2: declare tools as a TypeScript `namespace functions`.
                ts_converter = SchemaToTypeScriptConverter()
                # NOTE(review): the two pairs of adjacent string literals below
                # (after '...when necessary.' and after the replace(...) + '')
                # have no separating comma, so Python concatenates them into a
                # single element and '\n'.join() inserts no newline between
                # them — likely missing commas; verify intended output.
                messages = _add_system_prompt(messages, '\n'.join([
                    '// Supported function definitions that should be called when necessary.'
                    'namespace functions {',
                    *[
                        '// ' + tool.function.description.replace('\n', '\n// ') + '\n' + ''
                        'type ' + tool.function.name + ' = (_: ' + ts_converter.visit(tool.function.parameters) + ") => any;\n"
                        for tool in chat_request.tools
                    ],
                    # TODO: typescript conversion of schemas!
                    # // Get the price of a particular car model
                    # type get_car_price = (_: {
                    # // The name of the car model.
                    # car_name: string,
                    # }) => any;
                    # // get the weather of a location
                    # type get_weather = (_: {
                    # // where to get weather.
                    # location: string,
                    # }) => any;
                    '} // namespace functions',
                ]))
                # Only allowing a single tool call at a time for now.
                # Note that if there were more, they'd be separated by a '<|from|>assistant' literal
                converter._add_rule(
                    "root",
                    converter._format_literal(prefix) + " (" +
                    (response_rule or converter.not_literal("<|recipient|>")) + " | " +
                    (' | '.join(
                        converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
                        converter.visit(tool.function.parameters, tool.function.name + '-args')
                        for tool in chat_request.tools
                    )) +
                    ") " +
                    # NOTE(review): two consecutive '") "' closers here emit an
                    # extra ')' into the grammar text — looks like one too
                    # many; verify against the grammar syntax.
                    ") " + converter._format_literal(suffix))
                grammar = converter.format_grammar()
        elif response_schema:
            # No tools: only constrain the response to the requested JSON schema.
            converter._add_rule('root', response_rule)
            grammar = converter.format_grammar()
        if chat_handler.strict_user_assistant_alternation:
            print("TODO: merge system messages into user messages")
            # new_messages = []
        # TODO: Test whether the template supports formatting tool_calls
        prompt = chat_handler.render(messages, add_generation_prompt=True)
        print(prompt)
        print(dict(
            prompt=prompt,
            stream=chat_request.stream,
            grammar=grammar,
        ))
        # Proxy the constrained completion request to the backing server,
        # streaming the SSE bytes through unchanged when stream=True.
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{main_server_endpoint}/completions",
                # json=chat_request.model_dump(),
                json=dict(
                    prompt=prompt,
                    stream=chat_request.stream,
                    grammar=grammar,
                ),
                headers=headers,
                timeout=None)
            return StreamingResponse(generate_chunks(response), media_type="text/event-stream") if chat_request.stream \
                else JSONResponse(response.json())

    async def generate_chunks(response):
        # Pass-through of the upstream byte stream, chunk by chunk.
        async for chunk in response.aiter_bytes():
            yield chunk

    uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
    # jsonargparse builds the CLI from main()'s signature and defaults.
    CLI(main)
# @lru_cache
# def get_settings():
# return Settings(),
# settings: Annotated[Settings, Depends(get_settings)]
# class Settings(BaseSettings):
# # "https://api.openai.com/v1"
# llama_cpp_server_endpoint: str = "http://localhost:8000/v1"
# def render(self, messages: list['Message'], add_generation_prompt: bool, tools: Optional[list['Tool']] = None) -> ChatRendering:
# parser=None
# grammar=None
# converter = SchemaConverter()
# if self.tool_style in (ToolStyle.DEFAULT, ToolStyle.NOUS_RESEARCH_HERMES):
# messages = self._add_system_prompt(messages, '\n'.join([
# 'Here are the tools available:',
# '<tools>',
# *(tool.model_dump_json() for tool in tools),
# '</tools>',
# ]))
# elif self.tool_style == ToolStyle.FUNCTIONARY_V2:
# messages = self._add_system_prompt(messages, '\n'.join([
# '// Supported function definitions that should be called when necessary.'
# 'namespace functions {',
# # TODO: typescript conversion of schemas!
# # // Get the price of a particular car model
# # type get_car_price = (_: {
# # // The name of the car model.
# # car_name: string,
# # }) => any;
# # // get the weather of a location
# # type get_weather = (_: {
# # // where to get weather.
# # location: string,
# # }) => any;
# '} // namespace functions',
# ]))
# # <|recipient|>get_car_price
# # <|content|>{"car_name": "Song"}
# sample = "<$$SAMPLE$$>"
# sample_fn = "<$$SAMPLE_FN$$>"
# if self.tool_style == ToolStyle.FUNCTIONARY_V2:
# sample_out = self._raw_render([{
# "role": "assistant",
# "recipient": sample_fn,
# "content": sample,
# }])
# else:
# sample_out = self._raw_render([{
# "role": "assistant",
# "content": sample,
# }])
# else:
# raise NotImplementedError(f'Unsupported tool_style: {self.tool_style}')
# if self.strict_user_assistant_alternation:
# print("TODO: merge system messages into user messages")
# # new_messages = []
# prompt = self._raw_render(messages=messages, add_generation_prompt=add_generation_prompt)
# return ChatRendering(
# prompt=prompt,
# grammar=grammar,
# parser=parser,
# )
# @staticmethod
# def from_gguf(model: Path):
# reader = GGUFReader(model.as_posix())
# # def get_list(key):
# # field = reader.fields[key]
# # return [field.parts[i] for i in field.data]
# # def get_int(key):
# # field = reader.fields[key]
# # return int(field.parts[field.data[0]])
# # assert Keys.Tokenizer.CHAT_TEMPLATE in reader.fields, f"Key {Keys.Tokenizer.CHAT_TEMPLATE} not found in GGUF file"
# # token_list = get_list(Keys.Tokenizer.LIST)
# # def get_token(i):
# # return bytes(token_list[i]).decode("utf-8")
# return ChatHandler(
# template = reader.fields[Keys.Tokenizer.CHAT_TEMPLATE].read(),
# bos_token = reader.fields[Keys.Tokenizer.BOS_ID].read(),
# eos_token = reader.fields[Keys.Tokenizer.EOS_ID].read(),
# # template = ''.join(bytes(b).decode("utf-8") for b in get_list(Keys.Tokenizer.CHAT_TEMPLATE)),
# # bos_token = get_token(get_int(Keys.Tokenizer.BOS_ID)),
# # eos_token = get_token(get_int(Keys.Tokenizer.EOS_ID)),
# )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment