Skip to content

Instantly share code, notes, and snippets.

@ochafik
Created April 23, 2024 12:55
Show Gist options
  • Select an option

  • Save ochafik/b6d2db998ec6db3db58f1cc8a968fdc3 to your computer and use it in GitHub Desktop.

Select an option

Save ochafik/b6d2db998ec6db3db58f1cc8a968fdc3 to your computer and use it in GitHub Desktop.
llama.cpp: detect left recursion in grammars
https://github.com/ggerganov/llama.cpp/issues/6492
parse_state parse(const char * src) {
...
// Detect left recursion.
std::unordered_set<const llama_grammar_element *> tested_rules;
std::function<void(const std::vector<llama_grammar_element> &, std::vector<llama_grammar_element *> &)> detect_left_recursion =
[&](const std::vector<llama_grammar_element> & rule, std::vector<llama_grammar_element *> & stack) {
auto elem = rule.data();
if (tested_rules.find(elem) != tested_rules.end()) {
return;
}
if (std::find(stack.begin(), stack.end(), elem) != stack.end()) {
std::vector<std::string> rule_names(state.rules.size());
for (const auto & kv : state.symbol_ids) {
rule_names[kv.second] = kv.first;
}
std::ostringstream out;
out << "Left recursion detected: ";
for (size_t i = 0; i < stack.size(); i++) {
if (i > 0) {
out << " -> ";
}
GGML_ASSERT(stack[i]->type == LLAMA_GRETYPE_RULE_REF);
out << rule_names[stack[i]->value];
}
throw std::runtime_error(out.str());
}
stack.push_back(const_cast<llama_grammar_element *>(elem));
foreach_alt(elem, [&](const llama_grammar_element * alt) {
if (alt->type == LLAMA_GRETYPE_RULE_REF) {
detect_left_recursion(state.rules[alt->value], stack);
}
});
stack.pop_back();
tested_rules.insert(elem);
};
std::vector<llama_grammar_element *> stack;
for (const auto & rule : state.rules) {
detect_left_recursion(rule, stack);
GGML_ASSERT(stack.empty());
}
...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment