- I want to write a domain-specific language for producing LLM constraints, specifically for code generation.
- My main motivation at the moment is to write a framework for constrained generation that operates at the semantic level.
- I am curious to see how deep an integration we can build between a programming language's type system and constrained generation.
This is an example program for constraining the LLM output to be one of the following alternatives:
<div>Hi!</div>;
<div>{inventory.fruit[1]}</div>;
<div>{inventory.fruit[2]}</div>;
<div>{inventory.fruit[3]}</div>;
I've been thinking about this problem for a few days. Right now, I am just exploring how to develop this with the cheapest LLM I can find (one that performs well enough, of course) to learn more about the fascinating inner workings of LLMs.
I chose Mistral 7B Instruct v0.1, running inference on my Mac with MLX, using their very cool mlx-examples repo.
T(
    Statement(
        OneOf(
            JSXElement(
                "div",
                OneOf(
                    JSXText("Hi!"),
                    JSXExpression(
                        MembersOf(
                            "inventory",
                            Property(
                                "fruit",
                                ComputedProperty("1"),
                                ComputedProperty("2"),
                                ComputedProperty("3"),
                            ),
                        ),
                    ),
                ),
            ),
        ),
    ),
    close=tokenizer.eos_token,
)
The output of the outer-most function T is a prefix tree (or TokenTrie).
I chose it because it helps me efficiently store and retrieve strings that have common prefixes. To me, this seems like a no-brainer. With LLMs, text is generated from start to finish after all!
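As a quick illustration of why (a character-level sketch of my own; the real TokenTrie keys on whole strings and token IDs), strings that share a prefix share a single path in the trie:

```python
def insert(trie: dict, text: str) -> dict:
    """Insert a string character-by-character into a nested-dict trie."""
    node = trie
    for ch in text:
        node = node.setdefault(ch, {})
    return trie


trie: dict = {}
insert(trie, "<div>Hi!</div>")
insert(trie, "<div>Bye</div>")
# Both strings travel one shared "<div>" path before branching at "H"/"B".
```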
I decided not to care too much about looking backwards yet, but I immediately stumbled upon the need to append text after closing a specific semantic subtree in the TokenTrie.
I decided to handle closing tokens using a stack that gets maintained in a whole different semantic layer. The operations in this stack are encoded into the TokenTrie using the command names @push and @pop.
{
"": {
"": {
"": {
"<div>": {
"": {
"Hi!": {
"@pop": {}
},
"{": {
"inventory": {
".fruit": {
"[1": {
"]": {
"@pop": {}
}
},
"[2": {
"]": {
"@pop": {}
}
},
"[3": {
"]": {
"@pop": {}
}
},
"@push": {
"@pop": {}
}
},
"@push": {
"@pop": {}
}
},
"@push": {
"}": {
"@pop": {}
}
}
},
"@push": {
"@pop": {}
}
},
"@push": {
"</div>": {
"@pop": {}
}
}
},
"@push": {
"@pop": {}
}
},
"@push": {
";": {
"@pop": {}
}
}
},
"@push": {
"</s>": {
"@pop": {}
}
}
}
}
That is... very hard to read. Let's remove the noise of all those extra layers so that it's easier on the eyes.
Unfortunately, you'll also see a flaw in my design for the closing mechanics: these keys may collide! I could define a program that outputs TokenTrie('@push'), and it would override the behavior of @push. No! I'll show you how I tackled this later.
For now, let's focus on the tree. Here's a nicer version to look at, after removing uninteresting stuff like prefix-less TokenTries.
{
"<div>": {
"Hi!": {
"@pop": {}
},
"{": {
"inventory": {
".fruit": {
"[1": { "]": { "@pop": {} } },
"[2": { "]": { "@pop": {} } },
"[3": { "]": { "@pop": {} } },
"@push": { "@pop": {} }
},
"@push": { "@pop": {} }
},
"@push": { "}": { "@pop": {} } }
},
"@push": { "@pop": {} }
},
"@push": { "</div>": { "@pop": {} } }
}
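The @push/@pop mechanics can be exercised with a small walker of my own (not part of the DSL) that enumerates every string a TokenTrie licenses, resolving the closing stack the same way the compiler later will:

```python
def expand(tree: dict, stack: tuple = ()):
    """Yield every complete string a TokenTrie licenses."""
    push = tree.get("@push")
    if push is not None and "@pop" not in push:
        stack = stack + (push,)  # defer the closing subtree
    for key, sub in tree.items():
        if key == "@push":
            continue
        if key == "@pop":
            if stack:
                # resume from the most recently pushed closing subtree
                yield from expand(stack[-1], stack[:-1])
            else:
                yield ""
            continue
        for rest in expand(sub, stack):
            yield key + rest


tree = {
    "<div>": {
        "Hi!": {"@pop": {}},
        "@push": {"</div>": {"@pop": {}}},
    }
}
# expand(tree) → "<div>Hi!</div>"
```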
I got those TokenTries by implementing T, the simplest primitive of my toy DSL.
from typing import Callable, Optional, Union

# The token decision tree assembly process uses these @pop and @push
# commands. They allow me to maintain a stack of tokens that will be
# appended to the trie when we reach the leaf nodes.
END = {"@pop": {}}

TokenTrie = dict[Union[str, int], Optional["TokenTrie"]]
PrimitiveOrTokenTrie = Union[str, int, TokenTrie]
CallableClosing = Callable[..., TokenTrie]
Closing = Union[CallableClosing, PrimitiveOrTokenTrie]


def Token(value: PrimitiveOrTokenTrie) -> TokenTrie:
    """
    A utility function.
    Produces a new `TokenTrie` from a `str`, or forwards a `TokenTrie`.
    """
    return T(value) if isinstance(value, str) else value


def ClosingTrie(close: Closing, *args) -> TokenTrie:
    """
    Trait for producing closing mechanics.
    Calls a callable closing, or produces/forwards a `TokenTrie`.
    """
    return close(*args) if callable(close) else Token(close)


def T(*next: TokenTrie, close: Closing = lambda _: END) -> TokenTrie:
    """
    Produces a prefix tree, or trie, data structure.
    It is the most primitive data structure in the DSL, representing a branch
    in a decision tree.
    It should have at least a head, which is a string key, and the next branches.
    Through the `close` keyword argument, it is possible to define behaviors
    for when the branch ends.
    """
    key = next[0]
    if isinstance(key, int):
        return {key: ClosingTrie(close, next)}
    if not isinstance(key, str):
        return T("", *next, close=close)
    if len(next) == 1:
        return {key: ClosingTrie(close, next)}
    result = {k: v for item in next[1:] for k, v in Token(item).items()}
    if close is not None:
        result["@push"] = ClosingTrie(close, next)
    return {key: result}
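Assuming the definitions above, here is a quick sanity check of my own showing how T assembles the closing mechanics (the primitives are inlined so the snippet runs standalone):

```python
END = {"@pop": {}}


def Token(value):
    return T(value) if isinstance(value, str) else value


def ClosingTrie(close, *args):
    return close(*args) if callable(close) else Token(close)


def T(*next, close=lambda _: END):
    # the T primitive from above, inlined for self-containment
    key = next[0]
    if isinstance(key, int):
        return {key: ClosingTrie(close, next)}
    if not isinstance(key, str):
        return T("", *next, close=close)
    if len(next) == 1:
        return {key: ClosingTrie(close, next)}
    result = {k: v for item in next[1:] for k, v in Token(item).items()}
    if close is not None:
        result["@push"] = ClosingTrie(close, next)
    return {key: result}


# A head, one child branch, and a string closing:
tree = T("<div>", T("Hi!"), close="</div>")
# → {"<div>": {"Hi!": {"@pop": {}},
#              "@push": {"</div>": {"@pop": {}}}}}
```

The string closing "</div>" becomes a @push subtree, which is exactly the shape visible in the pretty-printed trees above.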
Those are the two higher-level semantic primitives that I'm working with for now:
- T -> TokenTrie["*"]: a prefix tree with arbitrary closing mechanics.
- TokenTrie["@*"]: a data structure holding arbitrary properties for the compiler to use while traversing the trie.
I finally had my TokenTrie that represents the code I want to generate. Now it's time to compile it into an even lower-level primitive, internal to the inference system, called a Trie.
What we want to achieve is a more mechanical data structure; this is the Trie.
- Every string is broken down into tokens, and tokens are mapped to token IDs.
- I construct a Trie that we will be able to traverse at inference time.
- Right now, all of the branches are terminal, but having "generative segments" may be possible through more closing stacks and special commands.
{
"523": {
"1538": {
"28767": {
"15359": {
"28808": {
"1867": {
"1538": {
"28767": {
"2753": {
"2": {}
}
}
}
}
}
},
"371": {
"19930": {
"842": {
"28722": {
"6424": {
"733": {
"28740": {
"4709": {
"443": {
"1867": {
"1538": {
"28767": {
"2753": {
"2": {}
}
}
}
}
}
}
},
"28750": {
"4709": {
"443": {
"1867": {
"1538": {
"28767": {
"2753": {
"2": {}
}
}
}
}
}
}
},
"28770": {
"4709": {
"443": {
"1867": {
"1538": {
"28767": {
"2753": {
"2": {}
}
}
}
}
}
}
}
}
}
}
}
}
}
}
}
}
}
All those integers mean very little to me at a glance; they are token IDs for our tokenizer. In this case, I'm using Mistral, which uses SentencePiece (a 32,000-token vocabulary) to encode/decode strings, so this particular Trie would only work for models that use that tokenizer. The compiler itself, though, only expects a callable tokenize function, which means we are tokenizer agnostic!
Here's my compile function, which takes the TokenTrie and turns it into a Trie.
from typing import Optional

Trie = dict[int, Optional["Trie"]]

Sorry, I had to do it.

def compile(
    tree: TokenTrie,
    tokenize: Callable[[str], list[int]],
    end: Optional[list[TokenTrie]] = None,
) -> Trie:
    """
    Tokenize the keys of a tree recursively.
    """
    # A fresh stack per top-level call avoids Python's shared
    # mutable-default-argument trap.
    end = [] if end is None else end
    new_tree: Trie = {}
    if "@push" in tree:
        if isinstance(tree["@push"], dict) and "@pop" not in tree["@push"]:
            end.append(tree["@push"])
    for key, value in tree.items():
        if key == "@push":
            continue
        current_dict: Trie = new_tree
        if key == "@pop":
            if len(end) > 0:
                current_dict.update(compile(end[-1], tokenize, end[:-1]))
            continue
        tokens = tokenize(key) if isinstance(key, str) else [key]
        for token in tokens:
            if token not in current_dict:
                current_dict[token] = {}
            current_dict = current_dict[token]
        current_dict.update(compile(value, tokenize, end))
    return new_tree
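To see compile in action without a real tokenizer, here's a toy run of my own using ord as a stand-in tokenize, so every character becomes its own "token ID" (compile is reproduced below, with end defaulting to a fresh list per call):

```python
def compile(tree, tokenize, end=None):
    """Tokenize the keys of a TokenTrie recursively (as defined above)."""
    end = [] if end is None else end
    new_tree: dict = {}
    if "@push" in tree:
        if isinstance(tree["@push"], dict) and "@pop" not in tree["@push"]:
            end.append(tree["@push"])
    for key, value in tree.items():
        if key == "@push":
            continue
        current_dict = new_tree
        if key == "@pop":
            if len(end) > 0:
                current_dict.update(compile(end[-1], tokenize, end[:-1]))
            continue
        tokens = tokenize(key) if isinstance(key, str) else [key]
        for token in tokens:
            if token not in current_dict:
                current_dict[token] = {}
            current_dict = current_dict[token]
        current_dict.update(compile(value, tokenize, end))
    return new_tree


# A tiny TokenTrie: emit "Hi", then the pushed closing "!"
tree = {"Hi": {"@pop": {}}, "@push": {"!": {"@pop": {}}}}

# ord() as a stand-in tokenizer: every character is its own "token ID"
trie = compile(tree, tokenize=lambda s: [ord(c) for c in s])
# → {72: {105: {33: {}}}}   ("H", "i", "!")
```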
- In a real-world use case, the TokenTrie may need to be compiled in an isolated process. This is why I have two data structures, TokenTrie and Trie: the output of T is the raw tree, and the output of compile is the temporarily final target (in semantics, nothing is final).
- I wanted to run this code as close as possible to the part where mlx_lm.generate samples the next token, so I could see the results of my updates as fast as possible.
- The code is well written, well typed, and has a very accessible API.
- It's so small that its surface area is not intimidating.
- Even though its main value proposition is "(Just an) array framework for Apple silicon", it has so many good examples of how to perform common general-purpose ML tasks!
I wrote a small higher-order constrainer function that returns a simpler constraint function, which is what we pass to the LLM.
- The outer constrainer function is in charge of producing the pre-compiled decision tree.
- The inner constraint function applies the constraints to the generation process depending on the current state of the path. It has access to the previously sampled token, a list of the top-K probabilities, and the logits at that particular step in the transformer.
- path keeps track of the "prefix" that we are currently looking at, so that we can present the possible choices to the LLM.
import json

import mlx.core as mx


def constrainer():
    decision_tree = T(
        Statement(
            OneOf(
                JSXElement(
                    "div",
                    OneOf(
                        JSXText("Hi!"),
                        JSXExpression(
                            MembersOf(
                                "inventory",
                                Property(
                                    "fruit",
                                    ComputedProperty("1"),
                                    ComputedProperty("2"),
                                    ComputedProperty("3"),
                                ),
                            ),
                        ),
                    ),
                ),
            ),
        ),
        close=tokenizer.eos_token,
    )
    tok_tree = compile(
        decision_tree,
        tokenize=lambda text: tokenizer.encode(
            text=text, add_special_tokens=False, verbose=True
        ),
    )
    print("## Raw")
    print("```json")
    print(json.dumps(decision_tree, indent=2))
    print("```")
    print("## Compiled")
    print("```json")
    print(json.dumps(tok_tree, indent=2))
    print("```")
    path = tok_tree

    def constraint(
        logits: mx.array, token: mx.array, top_k: list[tuple[int, float]]
    ) -> mx.array:
        nonlocal path
        if token.size == 1:
            tok = token.item()
            if tok in path:
                path = path[tok]
        if len(path.keys()) > 0:
            for key in path.keys():
                logits[0, key] += top_k[0][1]
        return logits

    return constraint
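The path traversal can be simulated without a model at all. Here is a minimal sketch of my own, with a character-level single-branch Trie and the first allowed key standing in for the boosted sample:

```python
def make_chain(tokens: list[int]) -> dict:
    """Build a single-branch Trie from a token sequence."""
    trie: dict = {}
    node = trie
    for t in tokens:
        node[t] = {}
        node = node[t]
    return trie


def greedy_constrained_decode(trie: dict) -> list[int]:
    """Walk the Trie, emitting one allowed token per step.

    In the real constraint function, the logit boost makes an allowed
    token win the sampling; here we just pick the first allowed key.
    """
    path, out = trie, []
    while path:
        tok = next(iter(path))
        out.append(tok)
        path = path[tok]
    return out


tokens = greedy_constrained_decode(make_chain([ord(c) for c in "<div>Hi!</div>"]))
# Decoding the character-code "tokens" recovers the constrained string.
```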
> means "next token" and {}| marks the end of a branch. The format follows this little grammar:

constraint: trie+
trie: int operator
operator: next | stop
next: ">"
stop: "{}|"
523>1538>28767>15359>28808>1867>1538>28767>2753>2{}|371>19930>842>28722>6424>733>28740>4709>443>1867>1538>28767>2753>2{}|28750>4709>443>1867>1538>28767>2753>2{}|28770>4709>443>1867>1538>28767>2753>2{}
Let's format it a bit, using newlines and some padding to denote branches, so it's easier to see.
523>1538>28767>
15359>28808>1867>1538>28767>2753>2
371>19930>842>28722>6424>733>
28740>4709>443>1867>1538>28767>2753>2
28750>4709>443>1867>1538>28767>2753>2
28770>4709>443>1867>1538>28767>2753>2
- I hate this; see how, in the last major trie branch, the only thing that differs is the prefix.
- Every branch ends with 2 because that's the token ID for </s> in SentencePiece, the end of the generation.
- Further compression could be applied by taking common patterns and abstracting them:
A: 4709>443>1867>1538>28767>2753>2
- I compressed it using zlib and got it from 200 chars down to 129. Thanks, GPT-4!
import zlib
import pickle

data = """523>1538>28767>15359>28808>1867>1538>28767>2753>2{}|371>19930>842>28722>6424>733>28740>4709>443>1867>1538>28767>2753>2{}|28750>4709>443>1"""

patterns = {
}
pattern_keys = {v: k for k, v in patterns.items()}

# Replace patterns in data
compressed_data = data
for pattern, key in pattern_keys.items():
    compressed_data = compressed_data.replace(pattern, key)

# Step 2: Serialize the compressed data
serialized_data = {
    "data": compressed_data,
    "patterns": patterns,
}
serialized = pickle.dumps(serialized_data)

# Step 3: Apply standard compression
final_compressed = zlib.compress(serialized)

print(f"Original size: {len(data)}")
print(f"Compressed size: {len(final_compressed)}")

# Decompression
decompressed = zlib.decompress(final_compressed)
deserialized = pickle.loads(decompressed)

# Replace keys with patterns
decompressed_data = deserialized["data"]
for key, pattern in deserialized["patterns"].items():
    decompressed_data = decompressed_data.replace(key, pattern)

# Verify
print(f"Is decompressed data same as original: {decompressed_data == data}")
Original size: 200
Compressed size: 129
Is decompressed data same as original: True
I then implemented a DSL on top of it, specifically geared towards JavaScript constraints.
- Super incomplete, but cool.
- Semantic layers, baby!
from .base import T, TokenTrie


def Statement(*next: TokenTrie):
    return T(*next, close=";")


def OneOf(*next: TokenTrie):
    return T(*next)


def MembersOf(key: str, *next: TokenTrie):
    return T(key, *next)


def Property(key: str, *next: TokenTrie):
    return T("." + key, *next)


def ComputedProperty(key: str, *next: TokenTrie):
    return T("[" + key, *next, close="]")


def CallOf(key: str, *next: TokenTrie):
    return T(key + "(", *next, close=")")


def StringLiteral(value: str):
    return T('"', value, close='"')


def JSXElement(tag: str, *next: TokenTrie):
    return T("<" + tag + ">", *next, close="</" + tag + ">")


def JSXText(value: str):
    return T(value)


def JSXExpression(*next: TokenTrie):
    return T("{", *next, close="}")
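As a sanity check of my own, inlining the base primitives shows what a small combinator chain expands to:

```python
END = {"@pop": {}}


def Token(value):
    return T(value) if isinstance(value, str) else value


def ClosingTrie(close, *args):
    return close(*args) if callable(close) else Token(close)


def T(*next, close=lambda _: END):
    # the base T primitive, inlined for self-containment
    key = next[0]
    if isinstance(key, int):
        return {key: ClosingTrie(close, next)}
    if not isinstance(key, str):
        return T("", *next, close=close)
    if len(next) == 1:
        return {key: ClosingTrie(close, next)}
    result = {k: v for item in next[1:] for k, v in Token(item).items()}
    if close is not None:
        result["@push"] = ClosingTrie(close, next)
    return {key: result}


def JSXText(value):
    return T(value)


def JSXElement(tag, *next):
    return T("<" + tag + ">", *next, close="</" + tag + ">")


def Statement(*next):
    return T(*next, close=";")


tree = Statement(JSXElement("div", JSXText("Hi!")))
# → {"": {"<div>": {"Hi!": {"@pop": {}},
#                   "@push": {"</div>": {"@pop": {}}}},
#         "@push": {";": {"@pop": {}}}}}
```

Each combinator is just a thin naming layer over T; the closing strings become @push subtrees, and the statement's ";" sits one level above the element's "</div>".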