-
-
Save ochafik/26e8a31e716d8a474a73ba1f72110c06 to your computer and use it in GitHub Desktop.
Partial json parser w/ healing support
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| static std::optional<json> parse_json(std::string::const_iterator & it, const std::string::const_iterator & end) { | |
| // // https://json.nlohmann.me/features/parsing/sax_interface/ | |
| struct json_error_locator : public nlohmann::json_sax<json> { | |
| std::vector<json> container_stack; | |
| std::size_t position; | |
| bool found_error; | |
| std::string last_token; | |
| std::string exception_message; | |
| std::vector<std::optional<std::string>> name_stack; | |
| std::vector<std::string> closing_stack; | |
| json_error_locator() : position(0), found_error(false) {} | |
| bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT | |
| this->position = position - 1; | |
| this->found_error = true; | |
| this->last_token = last_token; | |
| this->exception_message = ex.what(); | |
| return false; | |
| } | |
| std::optional<std::string> dict_key; | |
| void on_value(const json & value) { | |
| if (container_stack.empty()) { | |
| container_stack.push_back(value); | |
| } else { | |
| if (dict_key) { | |
| container_stack.back()[*dict_key] = value; | |
| } else { | |
| container_stack.back().push_back(value); | |
| } | |
| } | |
| } | |
| bool null() override { | |
| on_value(nullptr); | |
| return true; | |
| } | |
| bool boolean(bool v) override { | |
| on_value(v); | |
| return true; | |
| } | |
| bool number_integer(number_integer_t v) override { | |
| on_value(v); | |
| return true; | |
| } | |
| bool number_unsigned(number_unsigned_t v) override { | |
| on_value(v); | |
| return true; | |
| } | |
| bool number_float(number_float_t v, const string_t &) override { | |
| on_value(v); | |
| return true; | |
| } | |
| bool string(string_t & v) override { | |
| on_value(v); | |
| return true; | |
| } | |
| bool binary(binary_t & v) override { | |
| on_value(v); | |
| return true; | |
| } | |
| bool start_object(std::size_t obj) override | |
| { | |
| dict_key = std::nullopt; | |
| container_stack.push_back(obj); | |
| closing_stack.push_back("}"); | |
| name_stack.emplace_back(std::nullopt); | |
| return true; | |
| } | |
| bool end_object() override { | |
| dict_key = std::nullopt; | |
| container_stack.pop_back(); | |
| GGML_ASSERT(closing_stack.back() == "}"); | |
| closing_stack.pop_back(); | |
| name_stack.pop_back(); | |
| return true; | |
| } | |
| bool key(string_t & key) override { // NOLINT | |
| dict_key = key; | |
| name_stack.back() = key; | |
| return true; | |
| } | |
| bool start_array(std::size_t) override { // NOLINT | |
| dict_key = std::nullopt; | |
| closing_stack.push_back("]"); | |
| name_stack.emplace_back(std::nullopt); | |
| return true; | |
| } | |
| bool end_array() override { | |
| dict_key = std::nullopt; | |
| GGML_ASSERT(closing_stack.back() == "]"); | |
| closing_stack.pop_back(); | |
| name_stack.pop_back(); | |
| return true; | |
| } | |
| }; | |
| json_error_locator err_loc; | |
| json::sax_parse(it, end, &err_loc); | |
| std::string::const_iterator temptative_end; | |
| if (err_loc.found_error) { | |
| std::cerr << "Error at position " << err_loc.position << ":\n"; | |
| std::cerr << " Exception: " << err_loc.exception_message << '\n'; | |
| std::cerr << " Last token: " << err_loc.last_token << '\n'; | |
| std::vector<std::string> closing_stack(err_loc.closing_stack.rbegin(), err_loc.closing_stack.rend()); | |
| std::cerr << " Closing: " << string_join(closing_stack, "") << '\n'; | |
| temptative_end = it + err_loc.position; | |
| } else { | |
| temptative_end = end; | |
| } | |
| std::string json_sub {it, temptative_end}; | |
| try { | |
| auto out = json::parse(json_sub); | |
| it = temptative_end; | |
| return out; | |
| } catch (const std::exception &) { | |
| return std::nullopt; | |
| } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| /* | |
| TODO: | |
| - failures should provide the best way to close the json (if a value is needed, w/ null, otherwise closing string / blocks / arrays) | |
| - e.g. close would be: | |
| - `{"a": "b` -> `"}` | |
| - `{"a":` -> `null}` | |
| - `{"a": "b` -> `"}` | |
| - `{"a": "b` -> `"}` | |
| - This will allow caller to finish the json w/ magics only in function call arguments | |
| - e.g. | |
| - `{"name": "fn", "parameters": {"code": "print(` | |
| -> close w/ `"}}` | |
| -> caller inserts magic, e.g. `$|$|$` before the closing string | |
| -> caller parses `{"name": "fn", "parameters": {"code": "print($|$|$"}}` | |
| -> arguments are JSON encoded, then truncated at the magic | |
| -> common_chat_tool_call { .name = "fn", .arguments = "{\"code\": \"print(\"}" } | |
| -> This makes the arguments properly diffable & streamable | |
| */ | |
| struct truncated_json_info { | |
| // Flags capture the context of the innermost enclosing array OR object, and of the value we may be in the middle of | |
| // flags: before value, after value, inside string, after string escape | |
| // before dict key/value, after dict key, before dict value, after dict value | |
| // before array value, after array value | |
| // 0 { 1 "...2...\3..." 4 : 5 "...6...\7..." 8 , 1 ... } | |
| // [ 10 ""] | |
| enum location_flags { | |
| TRUNCATED_JSON_LOCATION_FLAG_VALUE_INSIDE_IDENT = 1 << 0, | |
| TRUNCATED_JSON_LOCATION_FLAG_VALUE_INSIDE_STRING = 1 << 1, | |
| TRUNCATED_JSON_LOCATION_FLAG_VALUE_INSIDE_STRING_AFTER_ESCAPE = 1 << 2, | |
| TRUNCATED_JSON_LOCATION_FLAG_DICT_BEFORE_KEY = 1 << 3, | |
| TRUNCATED_JSON_LOCATION_FLAG_DICT_INSIDE_KEY = 1 << 4, | |
| TRUNCATED_JSON_LOCATION_FLAG_DICT_AFTER_KEY = 1 << 5, | |
| TRUNCATED_JSON_LOCATION_FLAG_DICT_BEFORE_VALUE = 1 << 6, | |
| TRUNCATED_JSON_LOCATION_FLAG_DICT_INSIDE_VALUE = 1 << 7, | |
| TRUNCATED_JSON_LOCATION_FLAG_DICT_AFTER_VALUE = 1 << 8, | |
| TRUNCATED_JSON_LOCATION_FLAG_ARRAY_BEFORE_VALUE = 1 << 9, | |
| TRUNCATED_JSON_LOCATION_FLAG_ARRAY_INSIDE_VALUE = 1 << 10, | |
| TRUNCATED_JSON_LOCATION_FLAG_ARRAY_AFTER_VALUE = 1 << 11, | |
| }; | |
| int location_flags; | |
| std::string truncated_source; | |
| std::string nesting_closure; | |
| std::vector<std::optional<std::string>> name_stack; | |
| bool can_heal_with_magic() const { | |
| return location_flags & ( | |
| TRUNCATED_JSON_LOCATION_FLAG_VALUE_INSIDE_STRING | | |
| TRUNCATED_JSON_LOCATION_FLAG_VALUE_INSIDE_STRING_AFTER_ESCAPE | |
| ); | |
| } | |
| /* | |
| Heals a truncated JSON string with a magic string, returning the healed JSON string and the updated magic string to look for. | |
| This can be used to heal a JSON, transform its values, then serialize them and truncating them at the updated magic string. | |
| (for instance many tool call syntaxes involve expressing function arguments as JSON objects, but are streamed back encoded as partial JSON strings) | |
| TODO: pick magic automagically (increment some random string until it's not in the source, can do in one linear pass | |
| TODO: check that a long json string can be healed from any truncation point (heal then jsonified then truncated at magic should be the same as the original truncation, except for keywords and string escapes) | |
| */ | |
| std::pair<json, std::string> heal_with_magic(const std::string::const_iterator it, const std::string::const_iterator end, const std::string & magic) const { | |
| // if (!(location_flags & (TRUNCATED_JSON_LOCATION_FLAG_VALUE_INSIDE_STRING | | |
| // TRUNCATED_JSON_LOCATION_FLAG_VALUE_INSIDE_STRING_AFTER_ESCAPE | |
| // ); | |
| std::string src(it, end); | |
| std::string healed_src; | |
| std::string actual_magic; | |
| auto flags = location_flags; | |
| auto move_out = [](int flag) { | |
| if (flag & TRUNCATED_JSON_LOCATION_FLAG_DICT_INSIDE_KEY) { | |
| flag &= ~TRUNCATED_JSON_LOCATION_FLAG_DICT_INSIDE_KEY; | |
| flag |= TRUNCATED_JSON_LOCATION_FLAG_DICT_AFTER_KEY; | |
| } else if (flag & TRUNCATED_JSON_LOCATION_FLAG_DICT_INSIDE_VALUE) { | |
| flag &= ~TRUNCATED_JSON_LOCATION_FLAG_DICT_INSIDE_VALUE; | |
| flag |= TRUNCATED_JSON_LOCATION_FLAG_DICT_AFTER_VALUE; | |
| } else if (flag & TRUNCATED_JSON_LOCATION_FLAG_ARRAY_INSIDE_VALUE) { | |
| flag &= ~TRUNCATED_JSON_LOCATION_FLAG_ARRAY_INSIDE_VALUE; | |
| flag |= TRUNCATED_JSON_LOCATION_FLAG_ARRAY_AFTER_VALUE; | |
| } else { | |
| throw std::runtime_error("Cannot move out of a location that is not inside a key, value or array value"); | |
| } | |
| return flag; | |
| }; | |
| if (flags & TRUNCATED_JSON_LOCATION_FLAG_VALUE_INSIDE_STRING) { | |
| healed_src = src + magic + "\""; | |
| actual_magic = magic; | |
| flags &= ~TRUNCATED_JSON_LOCATION_FLAG_VALUE_INSIDE_STRING; | |
| flags = move_out(flags); | |
| } else if (flags & TRUNCATED_JSON_LOCATION_FLAG_VALUE_INSIDE_STRING_AFTER_ESCAPE) { | |
| GGML_ASSERT(string_ends_with(src, "\\")); | |
| healed_src = src.substr(0, src.size() - 1) + magic + "\""; | |
| actual_magic = magic; | |
| flags &= ~TRUNCATED_JSON_LOCATION_FLAG_VALUE_INSIDE_STRING_AFTER_ESCAPE; | |
| flags = move_out(flags); | |
| } else if (flags & TRUNCATED_JSON_LOCATION_FLAG_VALUE_INSIDE_IDENT) { | |
| // TODO: move back out of the identifier, or complete it | |
| throw std::runtime_error("Cannot heal a truncated JSON that stopped inside a keyword / identifier"); | |
| } else { | |
| healed_src = src; | |
| } | |
| if (flags & TRUNCATED_JSON_LOCATION_FLAG_DICT_BEFORE_KEY) { | |
| if (actual_magic.empty()) { | |
| healed_src += "\"" + magic + "\": null"; | |
| actual_magic = "\"" + magic; | |
| } else { | |
| auto str = string_strip(healed_src); | |
| if (str.back() == ',') { | |
| healed_src += " \"\": null"; | |
| } else if (str.back() != '{') { | |
| throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location"); | |
| } | |
| } | |
| } else if (flags & TRUNCATED_JSON_LOCATION_FLAG_DICT_AFTER_KEY) { | |
| if (actual_magic.empty()) { | |
| healed_src += ": \"" + magic + "\""; | |
| actual_magic = ": \"" + magic; | |
| } else { | |
| healed_src += ": null"; | |
| } | |
| } else if (flags & TRUNCATED_JSON_LOCATION_FLAG_DICT_BEFORE_VALUE) { | |
| if (actual_magic.empty()) { | |
| healed_src += "\"" + magic + "\""; | |
| actual_magic = "\"" + magic; | |
| } else { | |
| healed_src += "null"; | |
| } | |
| } else if (flags & TRUNCATED_JSON_LOCATION_FLAG_DICT_AFTER_VALUE) { | |
| if (actual_magic.empty()) { | |
| healed_src += ", \"" + magic + "\": null"; | |
| actual_magic = ", \"" + magic; | |
| } | |
| } else if (flags & TRUNCATED_JSON_LOCATION_FLAG_ARRAY_BEFORE_VALUE) { | |
| if (actual_magic.empty()) { | |
| healed_src += "\"" + magic + "\""; | |
| actual_magic = "\"" + magic; | |
| } else { | |
| auto str = string_strip(healed_src); | |
| if (str.back() == ',') { | |
| healed_src += "\"\""; | |
| } else if (str.back() != '[') { | |
| throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location"); | |
| } | |
| } | |
| } else if (flags & TRUNCATED_JSON_LOCATION_FLAG_ARRAY_AFTER_VALUE) { | |
| if (actual_magic.empty()) { | |
| healed_src += ", \"" + magic + "\""; | |
| actual_magic = ", \"" + magic; | |
| } | |
| } | |
| healed_src += nesting_closure; | |
| return {json::parse(healed_src), actual_magic}; | |
| } | |
| }; | |
| static std::optional<json> parse_json(std::string::const_iterator & it, const std::string::const_iterator & end) { | |
| // // https://json.nlohmann.me/features/parsing/sax_interface/ | |
| struct json_error_locator : public nlohmann::json_sax<json> { | |
| std::size_t position; | |
| bool found_error; | |
| std::string last_token; | |
| std::string exception_message; | |
| std::vector<std::optional<std::string>> name_stack; | |
| std::vector<std::string> closing_stack; | |
| json_error_locator() : position(0), found_error(false) {} | |
| bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT | |
| this->position = position - 1; | |
| this->found_error = true; | |
| this->last_token = last_token; | |
| this->exception_message = ex.what(); | |
| return false; | |
| } | |
| bool null() override { return true; } // NOLINT | |
| bool boolean(bool) override { return true; } // NOLINT | |
| bool number_integer(number_integer_t) override { return true; } // NOLINT | |
| bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT | |
| bool number_float(number_float_t, const string_t &) override { return true; } // NOLINT | |
| bool string(string_t &) override { return true; } // NOLINT | |
| bool binary(binary_t &) override { return true; } // NOLINT | |
| bool start_object(std::size_t) override { // NOLINT | |
| closing_stack.push_back("}"); | |
| name_stack.emplace_back(std::nullopt); | |
| return true; | |
| } | |
| bool end_object() override { | |
| GGML_ASSERT(closing_stack.back() == "}"); | |
| closing_stack.pop_back(); | |
| name_stack.pop_back(); | |
| return true; | |
| } | |
| bool key(string_t & key) override { // NOLINT | |
| name_stack.back() = key; | |
| return true; | |
| } | |
| bool start_array(std::size_t) override { // NOLINT | |
| closing_stack.push_back("]"); | |
| name_stack.emplace_back(std::nullopt); | |
| return true; | |
| } | |
| bool end_array() override { | |
| GGML_ASSERT(closing_stack.back() == "]"); | |
| closing_stack.pop_back(); | |
| name_stack.pop_back(); | |
| return true; | |
| } | |
| }; | |
| json_error_locator err_loc; | |
| json::sax_parse(it, end, &err_loc); | |
| std::string::const_iterator temptative_end; | |
| if (err_loc.found_error) { | |
| std::cerr << "Error at position " << err_loc.position << ":\n"; | |
| std::cerr << " Exception: " << err_loc.exception_message << '\n'; | |
| std::cerr << " Last token: " << err_loc.last_token << '\n'; | |
| std::vector<std::string> closing_stack(err_loc.closing_stack.rbegin(), err_loc.closing_stack.rend()); | |
| std::cerr << " Closing: " << string_join(closing_stack, "") << '\n'; | |
| temptative_end = it + err_loc.position; | |
| } else { | |
| temptative_end = end; | |
| } | |
| std::string json_sub {it, temptative_end}; | |
| try { | |
| auto out = json::parse(json_sub); | |
| it = temptative_end; | |
| return out; | |
| } catch (const std::exception &) { | |
| return std::nullopt; | |
| } | |
| } | |
| static void test_json_sax() { | |
| auto parse = [](const std::string & str) { | |
| std::cerr << "# Parsing: " << str << '\n'; | |
| std::string::const_iterator it = str.begin(); | |
| const auto end = str.end(); | |
| return parse_json(it, end); | |
| }; | |
| auto parse_all = [&](const std::string & str) { | |
| for (size_t i = 1; i < str.size() - 1; i++) { | |
| parse(str.substr(0, i)); | |
| } | |
| }; | |
| parse_all("{\"a\": \"b\"}"); | |
| parse_all("{\"hey\": 1, \"ho\\\"ha\": [1]}"); | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment