-
-
Save ochafik/a3d4a5b9e52390544b205f37fb5a0df3 to your computer and use it in GitHub Desktop.
llama.cpp OpenAI-compatible server w/ tools support for Functionary, Nous-Hermes-2-Pro and any model really
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing import Annotated | |
| from fastapi import Body | |
| from pydantic import BaseModel | |
class Item(BaseModel):
    """Request body for the update_item example endpoint."""
    name: str
    description: str | None = None
    price: float
    tax: float | None = None
# @app.put("/items/{item_id}")
async def update_item(
    item_id: int,
    item: Annotated[
        Item,
        Body(
            examples=[
                {
                    "name": "Foo",
                    "description": "A very nice Item",
                    "price": 35.4,
                    "tax": 3.2,
                }
            ],
        ),
    ],
):
    """Echo back the path item_id together with the parsed Item body."""
    return {"item_id": item_id, "item": item}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import secrets | |
| import string | |
| import sys | |
| from typing import Annotated, Type | |
| import importlib.util | |
| from fastapi import Body, FastAPI | |
| from pydantic import BaseModel | |
def load_source_as_module(source):
    """Import the Python file at *source* under a random, collision-free module name.

    Adapted from https://medium.com/@david.bonn.2010/dynamic-loading-of-python-code-2617c04e5f3f
    """
    def random_name():
        chars = string.ascii_uppercase + string.ascii_lowercase + string.digits
        return 'mod_' + ''.join(secrets.choice(chars) for _ in range(32))

    # Retry until the generated name is free in sys.modules (collisions are
    # astronomically unlikely with 32 random chars, but cheap to guard against).
    module_name = random_name()
    while module_name in sys.modules:
        module_name = random_name()

    spec = importlib.util.spec_from_file_location(module_name, source)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module
def bind_functions(app, module):
    """Register every public, annotated function of *module* as a POST route on *app*.

    Skipped: names starting with '_', Capitalized names (heuristic for
    classes/models), non-callables, classes, and callables without
    __annotations__.
    """
    for k in dir(module):
        if k.startswith('_'):
            continue
        if k == k.capitalize():
            # Heuristic: Capitalized names are types (e.g. pydantic models), not endpoints.
            continue
        v = getattr(module, k)
        # BUG FIX: was `isinstance(v, Type)` (typing.Type); plain `type` is the
        # correct runtime check for "is a class".
        if not callable(v) or isinstance(v, type):
            continue
        if not hasattr(v, '__annotations__'):
            continue
        print(f'INFO: Binding /{k}')
        # BUG FIX: Starlette asserts that route paths start with '/'; the bare
        # `app.post(k)` would crash at registration time.
        app.post(f'/{k}')(v)
if __name__ == '__main__':
    app = FastAPI()
    # Each CLI argument is either a path to a .py file or an importable module name.
    for source in sys.argv[1:]:
        module = (
            load_source_as_module(source)
            if source.endswith('.py')
            else importlib.import_module(source)
        )
        bind_functions(app, module)
    # run app
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash
# Build a docker image that serves the given Python scripts through fastify.py,
# then run it with the source and data directories bind-mounted.
#
# Usage: SRC_DIR=... DATA_DIR=... ./this.sh script1.py some.module ...
set -euo pipefail

scripts=( "$@" )

# Generate the Dockerfile on the fly. BUG FIXES vs. original:
# - `docker build -b` is not a flag; `-f -` reads the Dockerfile from stdin
#   while keeping `.` as the build context (required for the COPY lines).
# - `CWD` / `MAIN` are not Dockerfile instructions; use WORKDIR / CMD.
# - `$scripts` only expands the first array element; use ${scripts[@]}.
docker build -f - -t llama.cpp/tools-base . <<EOF
FROM python:3.10-slim
COPY requirements.txt /root
COPY fastify.py /root
RUN pip install -r /root/requirements.txt
WORKDIR /data
CMD PYTHONPATH=/src python /root/fastify.py ${scripts[@]}
EOF

# BUG FIXES vs. original: the --mount values had unbalanced quotes, and options
# must precede the image name (anything after it becomes the container command).
docker run -it \
    --mount type=bind,source="$SRC_DIR",target=/src,readonly \
    --mount type=bind,source="$DATA_DIR",target=/data \
    llama.cpp/tools-base
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # pip install "fastapi[all]" "uvicorn[all]" sse-starlette jsonargparse jinja2 pydantic | |
| from enum import StrEnum | |
| import json, sys, subprocess, atexit | |
| from pathlib import Path | |
| sys.path.insert(0, str(Path(__file__).parent.parent.parent / "gguf-py")) | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from json_schema_to_grammar import SchemaConverter | |
| from gguf.gguf_reader import GGUFReader | |
| from gguf.constants import Keys | |
| # from functools import lru_cache | |
| from typing import Annotated, Any, Callable, List, Optional, Set, Tuple, Union | |
| import httpx | |
| from fastapi import Depends, FastAPI, Request, Response | |
| from starlette.responses import StreamingResponse | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel, Json | |
| # from pydantic_settings import BaseSettings | |
| from jsonargparse import CLI | |
| import jinja2 | |
def raise_exception(msg: str):
    """Raise Exception(msg); injected into Jinja renders so templates can signal errors."""
    raise Exception(msg)
| class ToolStyle(StrEnum): | |
| # https://cookbook.openai.com/examples/how_to_call_functions_with_chat_models | |
| DEFAULT="Default", | |
| # https://github.com/MeetKai/functionary | |
| # TODO: look at https://github.com/ggerganov/llama.cpp/pull/5695 | |
| # https://github.com/MeetKai/functionary/blob/main/functionary/prompt_template/prompt_template_v2.py | |
| FUNCTIONARY_V2="Functionary V2", | |
| # https://github.com/NousResearch/Hermes-Function-Calling | |
| NOUS_RESEARCH_HERMES="Nous-Research-Hermes-Function-Calling", | |
| def _add_system_prompt(messages: list['Message'], system_prompt: str): | |
| # TODO: add to last system message, or create a new one just before the last user message | |
| system_message = next(((i, m) for i, m in enumerate(messages) if m.role == "system"), None) | |
| if system_message is not None: | |
| (i, m) = system_message | |
| messages[i].content = m.content + '\n' + system_prompt | |
| else: | |
| messages.insert(0, Message(role="system", content=system_prompt)) | |
| return messages | |
class ChatHandler: #(BaseModel):
    """Wraps a GGUF model's Jinja chat template and special tokens, and sniffs
    which tool-calling style the template expects from characteristic substrings."""
    def __init__(self, template: str, eos_token: str, bos_token: str):
        env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)
        self.template = env.from_string(template)
        self.eos_token = eos_token
        self.bos_token = bos_token
        # Fingerprint of templates (e.g. Mistral-style) that raise_exception
        # unless user/assistant roles strictly alternate.
        self.strict_user_assistant_alternation = "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception" in template
        # Functionary v2 templates are recognized by their '<|recipient|>' tool-call marker.
        if "<|recipient|>' + tool_call['function']['name']" in template:
            self.tool_style = ToolStyle.FUNCTIONARY_V2
        else:
            self.tool_style = ToolStyle.DEFAULT
    def __str__(self):
        return f"ChatHandler(template={self.template}, eos_token={self.eos_token}, bos_token={self.bos_token})"
    @staticmethod
    def from_gguf(model: Path):
        """Build a ChatHandler from the GGUF file's metadata fields.

        NOTE(review): BOS_ID/EOS_ID look like token *IDs*, not token strings —
        confirm that `.read()` yields the literal token text the template needs,
        not an integer id.
        """
        reader = GGUFReader(model.as_posix())
        return ChatHandler(
            template = reader.fields[Keys.Tokenizer.CHAT_TEMPLATE].read(),
            bos_token = reader.fields[Keys.Tokenizer.BOS_ID].read(),
            eos_token = reader.fields[Keys.Tokenizer.EOS_ID].read())
    def render(self, messages: list[dict], add_generation_prompt: bool, omit_bos: bool = False):
        # Renders the chat template; raise_exception is injected so the
        # template can signal unsupported message sequences.
        return self.template.render(
            messages=messages,
            eos_token=self.eos_token,
            bos_token='' if omit_bos else self.bos_token,
            raise_exception=raise_exception,
            add_generation_prompt=add_generation_prompt,
        )
class Message(BaseModel):
    """A single chat message."""
    # Chat role; this file branches on "system" and "assistant".
    role: str
    content: str
class ToolFunction(BaseModel):
    """OpenAI-style function declaration carried inside a Tool."""
    name: str
    description: str
    # JSON schema of the function's arguments (arbitrary structure, hence Any).
    parameters: Any
class Tool(BaseModel):
    """OpenAI-style tool wrapper."""
    # Nominally "function"; this file never inspects it and always uses .function.
    type: str
    function: ToolFunction
class ResponseFormat(BaseModel):
    """response_format request field; only type == "json_object" is accepted downstream."""
    type: str
    # Optional JSON schema constraining the JSON output (non-standard extension).
    json_schema: Optional[Any] = None
class ChatCompletionRequest(BaseModel):
    """Subset of the OpenAI /v1/chat/completions request body handled by this proxy."""
    model: str
    tools: Optional[list[Tool]] = None
    messages: list[Message]
    response_format: Optional[ResponseFormat] = None
    # NOTE(review): accepted but never forwarded to the backend request.
    temperature: float = 1.0
    stream: bool = False
class SchemaToTypeScriptConverter:
    """Best-effort converter from JSON schema fragments to TypeScript type
    expressions (used to describe tools to Functionary-style models).

    BUG FIX: both internal calls to _build_object_rule passed the
    additional-properties value in the position of a dead `name` parameter —
    one call was missing an argument entirely and raised TypeError. The unused
    private `name` parameter is dropped so the signature matches the call
    sites, and `_refs` is now initialized (it was referenced but never defined).
    """

    def __init__(self):
        # $ref -> schema cache consulted by the allOf branch.
        # NOTE(review): nothing populates this yet — $ref resolution is
        # incomplete and will raise KeyError until it is wired up.
        self._refs = {}

    def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], additional_properties: Union[bool, Any, None] = None):
        """Render an inline TS object type; optional props get a '?' suffix."""
        # additional_properties is accepted but not yet rendered (e.g. as an
        # index signature) — kept for call-site symmetry.
        return "{" + ', '.join(
            f'{prop_name}{"" if prop_name in required else "?"}: {self.visit(prop_schema)}'
            for prop_name, prop_schema in properties
        ) + "}"

    def visit(self, schema: dict):
        """Return a TypeScript type expression for *schema*."""
        def print_constant(v):
            return json.dumps(v)

        schema_type = schema.get('type')
        schema_format = schema.get('format')

        if 'oneOf' in schema or 'anyOf' in schema:
            return '|'.join(self.visit(s) for s in schema.get('oneOf') or schema.get('anyOf'))
        elif isinstance(schema_type, list):
            return '|'.join(self.visit({'type': t}) for t in schema_type)
        elif 'const' in schema:
            return print_constant(schema['const'])
        elif 'enum' in schema:
            return '|'.join(print_constant(v) for v in schema['enum'])
        elif schema_type in (None, 'object') and \
                ('properties' in schema or
                 ('additionalProperties' in schema and schema['additionalProperties'] is not True)):
            required = set(schema.get('required', []))
            properties = list(schema.get('properties', {}).items())
            return self._build_object_rule(properties, required, schema.get('additionalProperties'))
        elif schema_type in (None, 'object') and 'allOf' in schema:
            # Flatten allOf components into one object type; members reached
            # through anyOf are treated as optional.
            required = set()
            properties = []

            def add_component(comp_schema, is_required):
                if (ref := comp_schema.get('$ref')) is not None:
                    comp_schema = self._refs[ref]
                if 'properties' in comp_schema:
                    for prop_name, prop_schema in comp_schema['properties'].items():
                        properties.append((prop_name, prop_schema))
                        if is_required:
                            required.add(prop_name)

            for t in schema['allOf']:
                if 'anyOf' in t:
                    for tt in t['anyOf']:
                        add_component(tt, is_required=False)
                else:
                    add_component(t, is_required=True)
            return self._build_object_rule(properties, required, additional_properties=[])
        elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
            items = schema.get('items') or schema['prefixItems']
            if isinstance(items, list):
                # Tuple schema -> fixed-shape tuple array.
                return '[' + ', '.join(self.visit(item) for item in items) + '][]'
            else:
                return self.visit(items) + '[]'
        elif schema_type in (None, 'string') and schema_format == 'date-time':
            return 'Date'
        elif (schema_type == 'object') or (len(schema) == 0):
            return 'any'
        else:
            # Primitives: JSON 'integer' maps to TS 'number'; others pass through.
            return 'number' if schema_type == 'integer' else schema_type
def main(
    model: Path = Path("/Users/ochafik/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf"),
    host: str = "localhost",
    port: int = 8080,
    main_server_endpoint: Optional[str] = None,
    main_server_host: str = "localhost",
    main_server_port: Optional[int] = 8081,
):
    """Run an OpenAI-compatible /v1/chat/completions proxy with tool-call support.

    Reads the chat template from the GGUF file, constrains the backend's output
    with a GBNF grammar derived from the request's tools / response_format, and
    forwards the request to a llama.cpp completion server (spawning one locally
    if main_server_endpoint is not given).
    """
    import uvicorn
    chat_handler = ChatHandler.from_gguf(model)
    print(chat_handler)
    if not main_server_endpoint:
        # Spawn a local llama.cpp ./server as the completion backend; kill it on exit.
        server_process = subprocess.Popen([
            "./server", "-m", model,
            "--host", main_server_host, "--port", f'{main_server_port}',
        ])
        atexit.register(server_process.kill)
        main_server_endpoint = f"http://{main_server_host}:{main_server_port}"
    app = FastAPI()

    @app.post("/v1/chat/completions")
    async def chat_completions(request: Request, chat_request: ChatCompletionRequest):
        # Forward the caller's auth header to the backend server.
        headers = {
            "Content-Type": "application/json",
            "Authorization": request.headers.get("Authorization"),
        }
        # NOTE(review): assert is stripped under `python -O`; input validation
        # should raise an HTTP error instead.
        if chat_request.response_format is not None:
            assert chat_request.response_format.type == "json_object", f"Unsupported response format: {chat_request.response_format.type}"
            response_schema = chat_request.response_format.json_schema or {}
        else:
            response_schema = None
        messages = chat_request.messages
        # NOTE(review): `parser` is assigned below but never applied to the
        # backend response — tool-call parsing is effectively unused.
        parser=None
        grammar=None
        converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
        response_rule = converter.visit(response_schema, "response") if response_schema else None
        # Render an empty prompt and one containing a marker to discover the
        # literal prefix/suffix the template wraps around an assistant turn.
        delimiter = '<%$[SAMPLE]$%>'
        empty_prompt = chat_handler.render([], add_generation_prompt=True)
        planted_prompt = chat_handler.render([{"role": "assistant", "content": delimiter}], add_generation_prompt=False)
        assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
        [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
        if chat_request.tools:
            if chat_handler.tool_style in (ToolStyle.DEFAULT, ToolStyle.NOUS_RESEARCH_HERMES):
                # Hermes-style: advertise tools in a <tools> system block and
                # expect <tool_call>{json}</tool_call> back.
                messages = _add_system_prompt(messages, '\n'.join([
                    'Here are the tools available:',
                    '<tools>',
                    *(tool.model_dump_json() for tool in chat_request.tools),
                    '</tools>',
                ]))
                # One grammar rule per tool: {"name": <const>, "arguments": <schema>}.
                tool_rules = [
                    converter.visit(
                        dict(
                            type="object",
                            properties=dict(
                                name=dict(const=tool.function.name),
                                arguments=tool.function.parameters,
                            ),
                            required=['name', 'arguments']
                        ),
                        f'{tool.function.name}-tool-call'
                    )
                    for tool in chat_request.tools
                ]
                # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
                # OR a tool-call message respecting the schema of any of the tools
                converter._add_rule(
                    "root",
                    converter._format_literal(prefix) + " (" +
                    (response_rule or converter.not_literal("<tool_call>")) + " | " +
                    converter._format_literal("<tool_call>") + " (" +
                    ' | '.join(tool_rules) +
                    ") " + converter._format_literal("</tool_call>") +
                    ") " + converter._format_literal(suffix))
                grammar = converter.format_grammar()
                def parse(s: str):
                    # NOTE(review): `'<tool_call>'.startswith(s)` is only True while
                    # s is a *prefix* of the tag, so the inner startswith/endswith
                    # branch is unreachable for a full tool call — this looks
                    # inverted (probably meant `s.startswith('<tool_call>')`).
                    if '<tool_call>'.startswith(s):
                        if s.startswith('<tool_call>') and s.endswith('</tool_call>'):
                            s = s[len('<tool_call>'):-len('</tool_call>')]
                            return {"role": "assistant", "tool_calls": [json.loads(s)]}
                        return None
                    else:
                        return {"role": "assistant", "content": s}
                parser = parse
            elif chat_handler.tool_style == ToolStyle.FUNCTIONARY_V2:
                ts_converter = SchemaToTypeScriptConverter()
                # NOTE(review): the missing commas after the '...necessary.' line and
                # after `'\n' + ''` trigger implicit string concatenation — adjacent
                # entries are glued together without the intended '\n' separators.
                messages = _add_system_prompt(messages, '\n'.join([
                    '// Supported function definitions that should be called when necessary.'
                    'namespace functions {',
                    *[
                        '// ' + tool.function.description.replace('\n', '\n// ') + '\n' + ''
                        'type ' + tool.function.name + ' = (_: ' + ts_converter.visit(tool.function.parameters) + ") => any;\n"
                        for tool in chat_request.tools
                    ],
                    # TODO: typescript conversion of schemas!
                    # // Get the price of a particular car model
                    # type get_car_price = (_: {
                    # // The name of the car model.
                    # car_name: string,
                    # }) => any;
                    # // get the weather of a location
                    # type get_weather = (_: {
                    # // where to get weather.
                    # location: string,
                    # }) => any;
                    '} // namespace functions',
                ]))
                # Only allowing a single tool call at a time for now.
                # Note that if there were more, they'd be separated by a '<|from|>assistant' literal
                # NOTE(review): two '") "' closers but only one '" ("' opener below —
                # the generated grammar likely has unbalanced parentheses.
                converter._add_rule(
                    "root",
                    converter._format_literal(prefix) + " (" +
                    (response_rule or converter.not_literal("<|recipient|>")) + " | " +
                    (' | '.join(
                        converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
                        converter.visit(tool.function.parameters, tool.function.name + '-args')
                        for tool in chat_request.tools
                    )) +
                    ") " +
                    ") " + converter._format_literal(suffix))
                grammar = converter.format_grammar()
        elif response_schema:
            # JSON-mode without tools: constrain the whole reply to the response schema.
            converter._add_rule('root', response_rule)
            grammar = converter.format_grammar()
        if chat_handler.strict_user_assistant_alternation:
            print("TODO: merge system messages into user messages")
            # new_messages = []
        # TODO: Test whether the template supports formatting tool_calls
        prompt = chat_handler.render(messages, add_generation_prompt=True)
        print(prompt)
        print(dict(
            prompt=prompt,
            stream=chat_request.stream,
            grammar=grammar,
        ))
        # Proxy to the backend's /completions endpoint, streaming SSE if requested.
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{main_server_endpoint}/completions",
                # json=chat_request.model_dump(),
                json=dict(
                    prompt=prompt,
                    stream=chat_request.stream,
                    grammar=grammar,
                ),
                headers=headers,
                timeout=None)
            return StreamingResponse(generate_chunks(response), media_type="text/event-stream") if chat_request.stream \
                else JSONResponse(response.json())

    async def generate_chunks(response):
        # Pass upstream bytes through unchanged.
        async for chunk in response.aiter_bytes():
            yield chunk

    uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
    # jsonargparse builds a CLI from main()'s signature, defaults and type hints.
    CLI(main)
| # @lru_cache | |
| # def get_settings(): | |
| # return Settings(), | |
| # settings: Annotated[Settings, Depends(get_settings)] | |
| # class Settings(BaseSettings): | |
| # # "https://api.openai.com/v1" | |
| # llama_cpp_server_endpoint: str = "http://localhost:8000/v1" | |
| # def render(self, messages: list['Message'], add_generation_prompt: bool, tools: Optional[list['Tool']] = None) -> ChatRendering: | |
| # parser=None | |
| # grammar=None | |
| # converter = SchemaConverter() | |
| # if self.tool_style in (ToolStyle.DEFAULT, ToolStyle.NOUS_RESEARCH_HERMES): | |
| # messages = self._add_system_prompt(messages, '\n'.join([ | |
| # 'Here are the tools available:', | |
| # '<tools>', | |
| # *(tool.model_dump_json() for tool in tools), | |
| # '</tools>', | |
| # ])) | |
| # elif self.tool_style == ToolStyle.FUNCTIONARY_V2: | |
| # messages = self._add_system_prompt(messages, '\n'.join([ | |
| # '// Supported function definitions that should be called when necessary.' | |
| # 'namespace functions {', | |
| # # TODO: typescript conversion of schemas! | |
| # # // Get the price of a particular car model | |
| # # type get_car_price = (_: { | |
| # # // The name of the car model. | |
| # # car_name: string, | |
| # # }) => any; | |
| # # // get the weather of a location | |
| # # type get_weather = (_: { | |
| # # // where to get weather. | |
| # # location: string, | |
| # # }) => any; | |
| # '} // namespace functions', | |
| # ])) | |
| # # <|recipient|>get_car_price | |
| # # <|content|>{"car_name": "Song"} | |
| # sample = "<$$SAMPLE$$>" | |
| # sample_fn = "<$$SAMPLE_FN$$>" | |
| # if self.tool_style == ToolStyle.FUNCTIONARY_V2: | |
| # sample_out = self._raw_render([{ | |
| # "role": "assistant", | |
| # "recipient": sample_fn, | |
| # "content": sample, | |
| # }]) | |
| # else: | |
| # sample_out = self._raw_render([{ | |
| # "role": "assistant", | |
| # "content": sample, | |
| # }]) | |
| # else: | |
| # raise NotImplementedError(f'Unsupported tool_style: {self.tool_style}') | |
| # if self.strict_user_assistant_alternation: | |
| # print("TODO: merge system messages into user messages") | |
| # # new_messages = [] | |
| # prompt = self._raw_render(messages=messages, add_generation_prompt=add_generation_prompt) | |
| # return ChatRendering( | |
| # prompt=prompt, | |
| # grammar=grammar, | |
| # parser=parser, | |
| # ) | |
| # @staticmethod | |
| # def from_gguf(model: Path): | |
| # reader = GGUFReader(model.as_posix()) | |
| # # def get_list(key): | |
| # # field = reader.fields[key] | |
| # # return [field.parts[i] for i in field.data] | |
| # # def get_int(key): | |
| # # field = reader.fields[key] | |
| # # return int(field.parts[field.data[0]]) | |
| # # assert Keys.Tokenizer.CHAT_TEMPLATE in reader.fields, f"Key {Keys.Tokenizer.CHAT_TEMPLATE} not found in GGUF file" | |
| # # token_list = get_list(Keys.Tokenizer.LIST) | |
| # # def get_token(i): | |
| # # return bytes(token_list[i]).decode("utf-8") | |
| # return ChatHandler( | |
| # template = reader.fields[Keys.Tokenizer.CHAT_TEMPLATE].read(), | |
| # bos_token = reader.fields[Keys.Tokenizer.BOS_ID].read(), | |
| # eos_token = reader.fields[Keys.Tokenizer.EOS_ID].read(), | |
| # # template = ''.join(bytes(b).decode("utf-8") for b in get_list(Keys.Tokenizer.CHAT_TEMPLATE)), | |
| # # bos_token = get_token(get_int(Keys.Tokenizer.BOS_ID)), | |
| # # eos_token = get_token(get_int(Keys.Tokenizer.EOS_ID)), | |
| # ) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment