Created
March 25, 2025 12:07
-
-
Save CoffeeVampir3/d0e907090aae0af0c6147807885b4fe5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef RULE_STREAM_HPP | |
#define RULE_STREAM_HPP | |
#include <multisampler.hpp> | |
#include <slot.hpp> | |
#include "llama.h" | |
#include "sampling.h" | |
#include <unordered_map> | |
#include <utility> | |
#include <vector> | |
#include <string> | |
#include <optional> | |
#include <functional> | |
#include <variant> | |
// Forward declarations: TriggerContext (per-token trigger input) and
// RuleStream (the rule registry) are both defined later in this header.
class TriggerContext;
class RuleStream;
/* | |
* | |
* =========================================== ACTIONS =========================================== | |
* | |
*/ | |
struct ActionSetGrammar { | |
std::string grammar; | |
bool applied = false; | |
explicit ActionSetGrammar(std::string g) : grammar(std::move(g)) {} | |
void apply(const llama_model* model, const llama_context*, Slot& slot) { | |
if (applied) return; | |
applied = true; | |
if (slot.multi_sampler.constraint_sampler) { | |
llama_sampler_free(slot.multi_sampler.constraint_sampler); | |
} | |
slot.multi_sampler.constraint_sampler = llama_sampler_init_llg( | |
llama_model_get_vocab(model), | |
"lark", | |
grammar.c_str() | |
); | |
} | |
}; | |
struct ActionClearGrammar { | |
void apply(const llama_model*, const llama_context*, Slot& slot) { | |
if (slot.multi_sampler.constraint_sampler) { | |
llama_sampler_free(slot.multi_sampler.constraint_sampler); | |
slot.multi_sampler.constraint_sampler = nullptr; | |
} | |
} | |
}; | |
struct ActionEndGeneration { | |
const std::string stop_reason; | |
void apply(const llama_model*, const llama_context*, Slot&) { | |
// Empty | |
} | |
}; | |
using Action = std::variant<ActionSetGrammar, ActionClearGrammar, ActionEndGeneration>; | |
/* | |
* | |
* =========================================== TRIGGERS =========================================== | |
* | |
*/ | |
class TriggerContext { | |
public: | |
llama_token current_token; | |
explicit TriggerContext(const llama_token token) : current_token(token) {} | |
}; | |
struct TriggerOnToken { | |
llama_token token; | |
explicit TriggerOnToken(const llama_token t) : token(t) {} | |
bool should_apply(const llama_model*, const llama_context*, Slot&, const TriggerContext& context) const { | |
return context.current_token == token; | |
} | |
}; | |
struct TriggerOnSlotTokensGenerator { | |
int n_tokens; | |
explicit TriggerOnSlotTokensGenerator(const int n) : n_tokens(n) {} | |
bool should_apply(const llama_model*, const llama_context*, const Slot& slot, const TriggerContext&) const { | |
return slot.tokens_generated >= n_tokens; | |
} | |
}; | |
using Trigger = std::variant<TriggerOnToken, TriggerOnSlotTokensGenerator>; | |
/* | |
* | |
* =========================================== ABSTRACT RULE =========================================== | |
* | |
*/ | |
struct Rule { | |
Trigger trigger; | |
Action action; | |
Rule(const Trigger t, Action a) : trigger(t), action(std::move(a)) {} | |
std::optional<std::reference_wrapper<const Action>> execute(const llama_model* model, const llama_context* ctx, Slot& slot, const TriggerContext& context) { | |
const bool should_apply = std::visit([&](const auto& t) -> bool { | |
return t.should_apply(model, ctx, slot, context); | |
}, trigger); | |
if (should_apply) { | |
std::visit([&](auto& a) { | |
a.apply(model, ctx, slot); | |
}, action); | |
return {action}; | |
} | |
return std::nullopt; | |
} | |
}; | |
/* | |
* | |
* =========================================== RULE STREAM =========================================== | |
* | |
*/ | |
class RuleStream { | |
std::unordered_map<unsigned, std::vector<Rule>> rules_by_id; | |
unsigned current_id = 0; | |
public: | |
unsigned add_rules(std::vector<Rule> rules) { | |
const unsigned rule_id = current_id++; | |
rules_by_id[rule_id] = std::move(rules); | |
return rule_id; | |
} | |
void remove_id(const unsigned id) { | |
rules_by_id.erase(id); | |
} | |
const std::vector<Rule>* get_rules(const unsigned id) const { | |
const auto it = rules_by_id.find(id); | |
if (it != rules_by_id.end()) { | |
return &it->second; | |
} | |
return nullptr; | |
} | |
std::vector<std::reference_wrapper<const Action>> apply_engine(const llama_token token, const llama_model* model, const llama_context* ctx, Slot& slot) { | |
const TriggerContext context(token); | |
std::vector<std::reference_wrapper<const Action>> triggered_actions; | |
for (auto& [id, rule_list] : rules_by_id) { | |
for (auto& rule : rule_list) { | |
if (auto result = rule.execute(model, ctx, slot, context)) { | |
triggered_actions.push_back(result.value()); | |
} | |
} | |
} | |
return triggered_actions; | |
} | |
void reset() { | |
rules_by_id.clear(); | |
current_id = 0; | |
} | |
}; | |
/* | |
* | |
* =========================================== CONCRETE RULES =========================================== | |
* | |
*/ | |
namespace RuleEngine {
    // Registers a single rule that stops generation once `num_tokens`
    // tokens have been produced (stop reason "MaxNewTokens"). Returns the
    // group id usable with RuleStream::remove_id()/get_rules().
    inline unsigned rule_max_tokens(
        RuleStream& stream,
        const int num_tokens
    ) {
        return stream.add_rules({
            {TriggerOnSlotTokensGenerator(num_tokens), ActionEndGeneration{"MaxNewTokens"}},
        });
    }

    // Demo/test rule: once `num_tokens` tokens have been generated,
    // installs a lark grammar constraining output to a fixed JSON object
    // shape (action / mood / magazine capacity). Returns the group id.
    inline unsigned rule_test(
        RuleStream& stream,
        const int num_tokens
    ) {
        // Runtime data: this raw string is a lark grammar fed verbatim to
        // the llg sampler — the `//` lines inside it are lark comments,
        // not C++ comments, and its exact bytes matter.
        const auto lark_grammar = R"(
// Define the start rule
start: json_string
// The exact JSON string with fixed format
json_string: "{\n \"action\" : [\"" ACTION_CONTENT "\"],\n \"mood\" : \"" EMOTION "\",\n \"magazine capacity\" : \"" CAPACITY_CONTENT "\"\n}"
// Content restrictions
ACTION_CONTENT: /[a-zA-Z0-9 ,]{1,15}/
CAPACITY_CONTENT: /[0-9]+( rounds| bullets| shots)?/
EMOTION: "happy" | "sad" | "angry" | "excited" | "bored" | "anxious" | "calm" | "confused"
| "curious" | "depressed" | "ecstatic" | "fearful" | "grateful" | "hopeful"
| "irritated" | "jealous" | "peaceful" | "proud" | "surprised" | "tired"
)";
        return stream.add_rules({
            {TriggerOnSlotTokensGenerator(num_tokens), ActionSetGrammar(lark_grammar)},
        });
    }

    // Registers a toggle pair: installs `grammar` when `apply_token` is
    // sampled and clears the constraint when `remove_token` is sampled.
    // Returns the group id.
    inline unsigned rule_constrain_grammar(
        RuleStream& stream,
        const std::string& grammar,
        const llama_token apply_token,
        const llama_token remove_token
    ) {
        return stream.add_rules({
            {TriggerOnToken(apply_token), ActionSetGrammar(grammar)},
            {TriggerOnToken(remove_token), ActionClearGrammar()}
        });
    }
}
#endif //RULE_STREAM_HPP |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.