// rule_stream.hpp (@CoffeeVampir3, March 25, 2025)
#ifndef RULE_STREAM_HPP
#define RULE_STREAM_HPP

#include <multisampler.hpp>
#include <slot.hpp>
#include "llama.h"
#include "sampling.h"

#include <functional>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

class TriggerContext;
class RuleStream;
/*
*
* =========================================== ACTIONS ===========================================
*
*/
// Installs a llguidance ("lark") grammar as the slot's constraint sampler.
// The applied flag makes the action idempotent: a rule that keeps firing
// only builds the sampler once.
struct ActionSetGrammar {
    std::string grammar;
    bool applied = false;

    explicit ActionSetGrammar(std::string g) : grammar(std::move(g)) {}

    void apply(const llama_model* model, const llama_context*, Slot& slot) {
        if (applied) return;
        applied = true;
        // Free any existing constraint sampler before installing the new one.
        if (slot.multi_sampler.constraint_sampler) {
            llama_sampler_free(slot.multi_sampler.constraint_sampler);
        }
        slot.multi_sampler.constraint_sampler = llama_sampler_init_llg(
            llama_model_get_vocab(model),
            "lark",
            grammar.c_str()
        );
    }
};
// Frees the slot's constraint sampler, returning it to unconstrained sampling.
struct ActionClearGrammar {
    void apply(const llama_model*, const llama_context*, Slot& slot) {
        if (slot.multi_sampler.constraint_sampler) {
            llama_sampler_free(slot.multi_sampler.constraint_sampler);
            slot.multi_sampler.constraint_sampler = nullptr;
        }
    }
};
// Marker action: apply() is intentionally a no-op. The caller inspects the
// actions returned by RuleStream::apply_engine and uses stop_reason to decide
// how to terminate generation.
struct ActionEndGeneration {
    std::string stop_reason; // non-const so the variant stays copy/move-assignable

    void apply(const llama_model*, const llama_context*, Slot&) {
        // Intentionally empty; termination is handled by the caller.
    }
};

using Action = std::variant<ActionSetGrammar, ActionClearGrammar, ActionEndGeneration>;
/*
*
* =========================================== TRIGGERS ===========================================
*
*/
// Per-token state handed to every trigger when a new token has been sampled.
class TriggerContext {
public:
    llama_token current_token;

    explicit TriggerContext(const llama_token token) : current_token(token) {}
};

// Fires when the sampled token matches a specific token id.
struct TriggerOnToken {
    llama_token token;

    explicit TriggerOnToken(const llama_token t) : token(t) {}

    bool should_apply(const llama_model*, const llama_context*, const Slot&, const TriggerContext& context) const {
        return context.current_token == token;
    }
};
// Fires once the slot has generated at least n_tokens tokens.
struct TriggerOnSlotTokensGenerator {
    int n_tokens;

    explicit TriggerOnSlotTokensGenerator(const int n) : n_tokens(n) {}

    bool should_apply(const llama_model*, const llama_context*, const Slot& slot, const TriggerContext&) const {
        return slot.tokens_generated >= n_tokens;
    }
};

using Trigger = std::variant<TriggerOnToken, TriggerOnSlotTokensGenerator>;
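/*
 * Extension sketch (illustrative): any type with a matching should_apply()
 * becomes a trigger by adding it to the Trigger variant above. For example,
 * a trigger that fires on an end-of-generation token, using llama.cpp's
 * llama_vocab_is_eog():
 *
 *   struct TriggerOnEOG {
 *       bool should_apply(const llama_model* model, const llama_context*,
 *                         const Slot&, const TriggerContext& context) const {
 *           return llama_vocab_is_eog(llama_model_get_vocab(model), context.current_token);
 *       }
 *   };
 *   // using Trigger = std::variant<TriggerOnToken, TriggerOnSlotTokensGenerator, TriggerOnEOG>;
 */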
/*
*
* =========================================== ABSTRACT RULE ===========================================
*
*/
// A rule pairs one trigger with one action. execute() evaluates the trigger
// against the current token/slot and, if it fires, applies the action and
// returns a reference to it so the caller can see which action ran.
struct Rule {
    Trigger trigger;
    Action action;

    Rule(Trigger t, Action a) : trigger(std::move(t)), action(std::move(a)) {}

    std::optional<std::reference_wrapper<const Action>> execute(const llama_model* model, const llama_context* ctx, Slot& slot, const TriggerContext& context) {
        const bool should_apply = std::visit([&](const auto& t) -> bool {
            return t.should_apply(model, ctx, slot, context);
        }, trigger);
        if (should_apply) {
            std::visit([&](auto& a) {
                a.apply(model, ctx, slot);
            }, action);
            return {action};
        }
        return std::nullopt;
    }
};
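/*
 * Consumption sketch (illustrative): the Action reference returned by
 * execute() can be matched with std::get_if, e.g. to detect a stop request:
 *
 *   if (auto fired = rule.execute(model, ctx, slot, context)) {
 *       if (const auto* end = std::get_if<ActionEndGeneration>(&fired->get())) {
 *           // stop decoding; end->stop_reason says why (e.g. "MaxNewTokens")
 *       }
 *   }
 */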
/*
*
* =========================================== RULE STREAM ===========================================
*
*/
// Owns groups of rules keyed by an id, so related rules can be added and
// removed together. apply_engine() runs every rule against the new token
// and reports which actions fired.
class RuleStream {
    std::unordered_map<unsigned, std::vector<Rule>> rules_by_id;
    unsigned current_id = 0;

public:
    // Registers a group of rules and returns the id used to remove them later.
    unsigned add_rules(std::vector<Rule> rules) {
        const unsigned rule_id = current_id++;
        rules_by_id[rule_id] = std::move(rules);
        return rule_id;
    }

    void remove_id(const unsigned id) {
        rules_by_id.erase(id);
    }

    const std::vector<Rule>* get_rules(const unsigned id) const {
        const auto it = rules_by_id.find(id);
        if (it != rules_by_id.end()) {
            return &it->second;
        }
        return nullptr;
    }

    // Evaluates every registered rule against the freshly sampled token and
    // returns references to the actions that fired. The references stay valid
    // until the owning rules are removed or the stream is reset.
    std::vector<std::reference_wrapper<const Action>> apply_engine(const llama_token token, const llama_model* model, const llama_context* ctx, Slot& slot) {
        const TriggerContext context(token);
        std::vector<std::reference_wrapper<const Action>> triggered_actions;
        for (auto& [id, rule_list] : rules_by_id) {
            for (auto& rule : rule_list) {
                if (auto result = rule.execute(model, ctx, slot, context)) {
                    triggered_actions.push_back(result.value());
                }
            }
        }
        return triggered_actions;
    }

    void reset() {
        rules_by_id.clear();
        current_id = 0;
    }
};
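/*
 * Usage sketch (illustrative; slot, model, and ctx are assumed to come from
 * the surrounding server's llama.cpp setup, and sample_next_token() is a
 * hypothetical stand-in for the slot's sampling step):
 *
 *   RuleStream stream;
 *   RuleEngine::rule_max_tokens(stream, 256);
 *
 *   bool done = false;
 *   while (!done) {
 *       const llama_token token = sample_next_token(slot);
 *       for (const auto& action_ref : stream.apply_engine(token, model, ctx, slot)) {
 *           if (const auto* end = std::get_if<ActionEndGeneration>(&action_ref.get())) {
 *               done = true; // e.g. end->stop_reason == "MaxNewTokens"
 *           }
 *       }
 *   }
 */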
/*
*
* =========================================== CONCRETE RULES ===========================================
*
*/
namespace RuleEngine {
    // Ends generation with stop reason "MaxNewTokens" once the slot has
    // produced num_tokens tokens.
    inline unsigned rule_max_tokens(
        RuleStream& stream,
        const int num_tokens
    ) {
        return stream.add_rules({
            {TriggerOnSlotTokensGenerator(num_tokens), ActionEndGeneration{"MaxNewTokens"}},
        });
    }

    // Demo rule: after num_tokens tokens, constrains further output to a
    // fixed JSON shape via a lark grammar.
    inline unsigned rule_test(
        RuleStream& stream,
        const int num_tokens
    ) {
        const auto lark_grammar = R"(
// Define the start rule
start: json_string
// The exact JSON string with fixed format
json_string: "{\n \"action\" : [\"" ACTION_CONTENT "\"],\n \"mood\" : \"" EMOTION "\",\n \"magazine capacity\" : \"" CAPACITY_CONTENT "\"\n}"
// Content restrictions
ACTION_CONTENT: /[a-zA-Z0-9 ,]{1,15}/
CAPACITY_CONTENT: /[0-9]+( rounds| bullets| shots)?/
EMOTION: "happy" | "sad" | "angry" | "excited" | "bored" | "anxious" | "calm" | "confused"
       | "curious" | "depressed" | "ecstatic" | "fearful" | "grateful" | "hopeful"
       | "irritated" | "jealous" | "peaceful" | "proud" | "surprised" | "tired"
)";
        return stream.add_rules({
            {TriggerOnSlotTokensGenerator(num_tokens), ActionSetGrammar(lark_grammar)},
        });
    }

    // Turns a grammar constraint on when apply_token is sampled and off again
    // when remove_token is sampled. Note: because of ActionSetGrammar's
    // applied flag, the grammar is installed at most once per call to this
    // function; a second apply_token after a clear will not reinstall it.
    inline unsigned rule_constrain_grammar(
        RuleStream& stream,
        const std::string& grammar,
        const llama_token apply_token,
        const llama_token remove_token
    ) {
        return stream.add_rules({
            {TriggerOnToken(apply_token), ActionSetGrammar(grammar)},
            {TriggerOnToken(remove_token), ActionClearGrammar()}
        });
    }
}
#endif //RULE_STREAM_HPP