-
-
Save ochafik/a3870e95c3f07d57471b05eaf30917eb to your computer and use it in GitHub Desktop.
Minja normalization logic as template
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| /* | |
| Copyright 2024 Google LLC | |
| Use of this source code is governed by an MIT-style | |
| license that can be found in the LICENSE file or at | |
| https://opensource.org/licenses/MIT. | |
| */ | |
| // SPDX-License-Identifier: MIT | |
| #pragma once | |
#include "minja.hpp"
#include <json.hpp>
#include <exception>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
| using json = nlohmann::ordered_json; | |
| namespace minja { | |
| class chat_template { | |
| public: | |
| private: | |
| bool supports_tools_ = true; | |
| // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object. | |
| // Most other templates (and OpenAI's API) expect the arguments object to be stringified. | |
| bool requires_object_arguments_ = false; | |
| bool supports_system_role_ = true; | |
| bool supports_parallel_tool_calls_ = false; | |
| bool messages_need_fixes_ = true; | |
| std::string source_; | |
| std::string bos_token_; | |
| std::string eos_token_; | |
| std::shared_ptr<minja::TemplateNode> template_root_; | |
| std::string try_render( | |
| const nlohmann::ordered_json & messages, | |
| const nlohmann::ordered_json & tools, | |
| bool add_generation_prompt, | |
| const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const | |
| { | |
| try { | |
| auto prompt = apply(messages, tools, add_generation_prompt, extra_context); | |
| // fprintf(stderr, "Prompt: %s\n", prompt.c_str()); | |
| return prompt; | |
| } catch (const std::exception & e) { | |
| // fprintf(stderr, "Error: %s\n", e.what()); | |
| return ""; | |
| } | |
| } | |
| public: | |
| chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token) | |
| : source_(source), bos_token_(bos_token), eos_token_(eos_token) | |
| { | |
| template_root_ = minja::Parser::parse(source_, { | |
| /* .trim_blocks = */ true, | |
| /* .lstrip_blocks = */ true, | |
| /* .keep_trailing_newline = */ false, | |
| }); | |
| supports_tools_ = source.find("tools") != std::string::npos; | |
| auto renders_string_arguments = | |
| try_render({ | |
| { | |
| {"role", "user"}, | |
| {"content", "Hey"} | |
| }, | |
| { | |
| {"role", "assistant"}, | |
| {"tool_calls", json::array({ | |
| { | |
| {"id", "call_1___"}, | |
| {"type", "function"}, | |
| {"function", { | |
| {"arguments", "{\"code\": \"print('Hello, World!')\"}"}, | |
| {"name", "ipython"}, | |
| }}, | |
| }, | |
| })}, | |
| } | |
| }, {}, false).find("{\"code\": \"print") != std::string::npos; | |
| if (!renders_string_arguments) { | |
| auto renders_object_arguments = | |
| try_render({ | |
| { | |
| {"role", "user"}, | |
| {"content", "Hey"} | |
| }, | |
| { | |
| {"role", "assistant"}, | |
| {"tool_calls", json::array({ | |
| { | |
| {"id", "call_1___"}, | |
| {"type", "function"}, | |
| {"function", { | |
| {"arguments", { | |
| {"code", "print('Hello, World!')"}, | |
| }}, | |
| {"name", "ipython"}, | |
| }}, | |
| }, | |
| })}, | |
| } | |
| }, {}, false).find("{\"code\": \"print") != std::string::npos; | |
| requires_object_arguments_ = renders_object_arguments; | |
| } | |
| supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos; | |
| supports_system_role_ = try_render({ | |
| {{"role", "system"}, {"content", "<System Needle>"}}, | |
| {{"role", "user"}, {"content", "Hey"}} | |
| }, {}, false).find("<System Needle>") != std::string::npos; | |
| messages_need_fixes_ = requires_object_arguments_ || !supports_system_role_ || !supports_tools_; | |
| } | |
| const std::string & source() const { return source_; } | |
| bool supports_tools() const { return supports_tools_; } | |
| bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; } | |
| std::string apply( | |
| const nlohmann::ordered_json & messages, | |
| const nlohmann::ordered_json & tools, | |
| bool add_generation_prompt, | |
| const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const | |
| { | |
| json actual_messages; | |
| if (messages_need_fixes_) { | |
| // Fix messages so they have a chance to be rendered correctly by the template | |
| static auto fixer_root = minja::Parser::parse(R"( | |
| {% set pending_system = [] %} | |
| {% set actual_messages = [] %} | |
| {% for message in messages %} | |
| {% if not supports_tools or requires_object_arguments %} | |
| {% for tool_call in message.tool_calls %} | |
| {% if tool_call.type == "function" %} | |
| {% set tool_call.function.arguments = tool_call.function.arguments | from_json %} | |
| {% endif %} | |
| {% endfor %} | |
| {% endif %} | |
| {% if not supports_tools %} | |
| {% if message.tool_calls %} | |
| {% set tool_calls = [] %} | |
| {% for tool_call in message.tool_calls %} | |
| {% if tool_call.type == "function" %} | |
| {% set _ = tool_calls.append({ | |
| "name": tool_call.function.name, | |
| "arguments": tool_call.function.arguments, | |
| }) %} | |
| {% if tool_call.id is not None %} | |
| {% set tool_calls[-1]["id"] = tool_call.id %} | |
| {% endif %} | |
| {% endif %} | |
| {% endfor %} | |
| {% set obj = {"tool_calls": tool_calls} %} | |
| {% if message.content is not None %} | |
| {% set _ = obj.update({"content": message.content}) %} | |
| {% endif %} | |
| {% set message.content = obj | to_json %} | |
| {% set _ = message.pop("tool_calls") %} | |
| {% elif message.role == 'tool' %} | |
| {% set obj = {"tool_response": {"tool": message.name, "content": message.content}} %} | |
| {% if message.tool_call_id is not None %} | |
| {% set obj.tool_response.tool_call_id = message.tool_call_id %} | |
| {% endif %} | |
| {% set message.role = 'user' %} | |
| {% set message.content = obj | to_json %} | |
| {% set _ = message.pop("name") %} | |
| {% endif %} | |
| {% endif %} | |
| {% set skip_message = False %} | |
| {% if message.content is not None and not supports_system_role %} | |
| {% if message.role == 'system' %} | |
| {% set _ = pending_system.append(message.content) %} | |
| {% set skip_message = True %} | |
| {% elif message.role == "user" %} | |
| {% if pending_system %} | |
| {% set message.content = [*pending_system, message.content] | join("\n") %} | |
| {% set pending_system = [] %} | |
| {% endif %} | |
| {% endif %} | |
| {% endif %} | |
| {% if not skip_message %} | |
| {% set _ = actual_messages.append(message) %} | |
| {% endif %} | |
| {% endfor %} | |
| {% if pending_system %} | |
| {% set _ = actual_messages.append({"role": "user", "content": pending_system | join("\n") }) %} | |
| {% endif %} | |
| {% set messages = actual_messages %} | |
| )", {}); | |
| auto fixer_context = minja::Context::make(json({ | |
| {"messages", messages}, | |
| {"requires_object_arguments", requires_object_arguments_}, | |
| {"supports_system_role", supports_system_role_}, | |
| {"supports_tools", supports_tools_}, | |
| })); | |
| fixer_root->render(fixer_context); | |
| actual_messages = fixer_context->get("messages").get<json>(); | |
| } else { | |
| actual_messages = messages; | |
| } | |
| // if (requires_object_arguments_ || !supports_system_role_ || !supports_tools_) { | |
| // actual_messages = json::array(); | |
| // std::string pending_system; | |
| // auto flush_sys = [&]() { | |
| // if (!pending_system.empty()) { | |
| // actual_messages.push_back({ | |
| // {"role", "user"}, | |
| // {"content", pending_system}, | |
| // }); | |
| // pending_system.clear(); | |
| // } | |
| // }; | |
| // for (const auto & message_ : messages) { | |
| // auto message = message_; | |
| // if (!message.contains("role") || !message.contains("content")) { | |
| // throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump()); | |
| // } | |
| // std::string role = message.at("role"); | |
| // if (message.contains("tool_calls")) { | |
| // if (requires_object_arguments_ || !supports_tools_) { | |
| // for (auto & tool_call : message.at("tool_calls")) { | |
| // if (tool_call["type"] == "function") { | |
| // auto & function = tool_call.at("function"); | |
| // std::string arguments = function.at("arguments"); | |
| // function["arguments"] = json::parse(arguments); | |
| // } | |
| // } | |
| // } | |
| // if (!supports_tools_) { | |
| // auto content = message.at("content"); | |
| // auto tool_calls = json::array(); | |
| // for (const auto & tool_call : message.at("tool_calls")) { | |
| // if (tool_call.at("type") != "function") { | |
| // continue; | |
| // } | |
| // const auto & function = tool_call.at("function"); | |
| // auto tc = json { | |
| // {"name", function.at("name")}, | |
| // {"arguments", function.at("arguments")}, | |
| // }; | |
| // if (tool_call.contains("id")) { | |
| // tc["id"] = tool_call["id"]; | |
| // } | |
| // tool_calls.push_back(tc); | |
| // } | |
| // auto obj = json { | |
| // {"tool_calls", tool_calls}, | |
| // }; | |
| // if (!content.is_null() && content != "") { | |
| // obj["content"] = content; | |
| // } | |
| // message["content"] = obj.dump(2); | |
| // message.erase("tool_calls"); | |
| // } | |
| // } | |
| // if (!supports_tools_ && role == "tool") { | |
| // message["role"] = "user"; | |
| // auto obj = json { | |
| // {"tool_response", { | |
| // {"tool", message.at("name")}, | |
| // {"content", message.at("content")}, | |
| // }}, | |
| // }; | |
| // if (message.contains("tool_call_id")) { | |
| // obj["tool_response"]["tool_call_id"] = message.at("tool_call_id"); | |
| // } | |
| // message["content"] = obj.dump(2); | |
| // message.erase("name"); | |
| // } | |
| // if (!message["content"].is_null() && !supports_system_role_) { | |
| // std::string content = message.at("content"); | |
| // if (role == "system") { | |
| // if (!pending_system.empty()) pending_system += "\n"; | |
| // pending_system += content; | |
| // continue; | |
| // } else { | |
| // if (role == "user") { | |
| // if (!pending_system.empty()) { | |
| // message["content"] = pending_system + (content.empty() ? "" : "\n" + content); | |
| // pending_system.clear(); | |
| // } | |
| // } else { | |
| // flush_sys(); | |
| // } | |
| // } | |
| // } | |
| // actual_messages.push_back(message); | |
| // } | |
| // flush_sys(); | |
| // } else { | |
| // actual_messages = messages; | |
| // } | |
| auto context = minja::Context::make(json({ | |
| {"messages", actual_messages}, | |
| {"add_generation_prompt", add_generation_prompt}, | |
| {"bos_token", bos_token_}, | |
| {"eos_token", eos_token_}, | |
| })); | |
| if (!tools.is_null()) { | |
| auto tools_val = minja::Value(tools); | |
| context->set("tools", tools_val); | |
| } | |
| if (!extra_context.is_null()) { | |
| for (auto & kv : extra_context.items()) { | |
| minja::Value val(kv.value()); | |
| context->set(kv.key(), val); | |
| } | |
| } | |
| return template_root_->render(context); | |
| } | |
| }; | |
| } // namespace minja |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment