Skip to content

Instantly share code, notes, and snippets.

@lithdew
Last active January 27, 2021 08:49
Show Gist options
  • Save lithdew/be0be2dc480cfba98a4180be39bf804a to your computer and use it in GitHub Desktop.
Save lithdew/be0be2dc480cfba98a4180be39bf804a to your computer and use it in GitHub Desktop.
zig: trie for routing
const std = @import("std");
const mem = std.mem;
const fmt = std.fmt;
const sort = std.sort;
const testing = std.testing;
/// A path-segment trie for routing, populated entirely at compile time.
///
/// Keys are split into segments with `mem.tokenize(key, S)`, so `S` is the set of
/// separator bytes (e.g. "/"). `V` is the value type stored where a route ends.
/// A segment that starts with `:` or is exactly `*` is a wildcard: at lookup time it
/// matches any single segment, which is captured as a parameter. All wildcard
/// spellings at the same depth share one wildcard child.
///
/// Usage: call `put` on a comptime instance, then `compile` to obtain a router type.
pub fn Trie(comptime S: []const u8, comptime V: type) type {
    return struct {
        const Self = @This();

        /// Report whether a route segment is a wildcard: `:name` or exactly `*`.
        inline fn isWildcard(key: []const u8) bool {
            return mem.startsWith(u8, key, ":") or mem.eql(u8, key, "*");
        }

        const Node = struct {
            /// Segment this node matches (for a wildcard node, the first spelling registered).
            key: []const u8,
            /// Value stored when a route ends at this node.
            value: ?V = null,
            /// Length of the shortest key in `children`.
            min_len: usize = 0,
            /// Length of the longest key in `children`.
            max_len: usize = 0,
            /// `index[n]` is the position in `children` of the first child whose key is
            /// at least `n` bytes long, for every `n` from 0 through `max_len`.
            index: []usize = &[_]usize{},
            /// The single wildcard child shared by every wildcard spelling at this depth.
            wildcard: ?*Node = null,
            /// Static (non-wildcard) children, sorted by ascending key length.
            children: []*Node = &[_]*Node{},

            inline fn lessThan(context: void, a: *Node, b: *Node) bool {
                return a.key.len < b.key.len;
            }

            /// Attach `child` at compile time. A wildcard key goes into the wildcard
            /// slot; a static key is merged into `children` and the length index is
            /// rebuilt.
            fn add(comptime self: *Node, child: *Node) void {
                if (isWildcard(child.key)) {
                    self.wildcard = child;
                    return;
                }

                var children: [self.children.len + 1]*Node = undefined;
                mem.copy(*Node, &children, self.children ++ [_]*Node{child});
                sort.sort(*Node, &children, {}, Node.lessThan);

                self.min_len = children[0].key.len;
                self.max_len = children[children.len - 1].key.len;

                // Children are sorted by length and the last one is `max_len` long,
                // so `i` never runs past the end while the index is filled in.
                var index: [self.max_len + 1]usize = undefined;
                var len: usize = 0;
                var i: usize = 0;
                while (len <= self.max_len) : (len += 1) {
                    while (len > children[i].key.len) {
                        i += 1;
                    }
                    index[len] = i;
                }

                self.index = &index;
                self.children = &children;
            }

            /// Find the child whose key equals `key`. A wildcard-looking `key` yields
            /// the wildcard child. Returns null when nothing matches.
            fn get(self: *const Node, key: []const u8) ?*Node {
                if (isWildcard(key)) return self.wildcard;
                // A childless node has an empty `index`, so bail out before indexing it.
                if (self.children.len == 0 or key.len < self.min_len or key.len > self.max_len) {
                    return null;
                }
                // Only the run of children whose key has exactly `key.len` bytes can match.
                var i = self.index[key.len];
                while (i < self.children.len) : (i += 1) {
                    const node = self.children[i];
                    if (node.key.len != key.len) return null;
                    if (mem.eql(u8, node.key, key)) return node;
                }
                return null;
            }
        };

        root: Node = .{ .key = "" },
        /// Largest number of wildcard segments in any route passed to `put`.
        max_params: usize = 0,

        /// Register `value` under the route `key`. This must run at compile time.
        /// An empty key is a compile error. Putting an existing route again
        /// overwrites its value.
        pub fn put(comptime self: *Self, key: []const u8, value: V) void {
            if (key.len == 0) {
                @compileError("Key is empty.");
            }

            var params_count: usize = 0;
            var child: *Node = &self.root;
            var it = mem.tokenize(key, S);
            while (it.next()) |segment| {
                child = child.get(segment) orelse new_child: {
                    var new_node = Node{ .key = segment };
                    child.add(&new_node);
                    break :new_child &new_node;
                };
                if (isWildcard(child.key)) {
                    params_count += 1;
                }
            }
            child.value = value;

            if (self.max_params < params_count) {
                self.max_params = params_count;
            }
        }

        /// Freeze the trie into a router type whose instances perform runtime lookups.
        pub fn compile(comptime self: *const Self) type {
            return struct {
                /// Capacity the `params` buffer passed to `get` must have.
                pub const max_params = self.max_params;

                pub const Result = struct {
                    value: V,
                    params: []const []const u8,

                    pub fn format(
                        result: *const Result,
                        comptime layout: []const u8,
                        options: fmt.FormatOptions,
                        writer: anytype,
                    ) !void {
                        try fmt.format(writer, "(result: {}, params: {s})", .{
                            result.value,
                            result.params,
                        });
                    }
                };

                /// Iterator type produced by `mem.tokenize` for this separator set.
                const Segments = @TypeOf(mem.tokenize("", S));

                root: Node = self.root,

                /// Look up `key`, writing captured wildcard segments into `params`.
                ///
                /// At each depth an exact static match is tried first. If that branch
                /// cannot complete a route, the wildcard child at the same depth is
                /// tried instead. `params` must hold at least `max_params` entries.
                /// The returned `Result.params` is a prefix of `params`.
                pub fn get(
                    trie: *const @This(),
                    key: []const u8,
                    params: [][]const u8,
                ) ?Result {
                    return matchFrom(&trie.root, mem.tokenize(key, S), params, 0);
                }

                /// Match the segments remaining in `segments` starting at `node`.
                /// `params_count` wildcard captures are already stored in `params`.
                fn matchFrom(
                    node: *const Node,
                    segments: Segments,
                    params: [][]const u8,
                    params_count: usize,
                ) ?Result {
                    // Copy the iterator so a failed branch leaves the caller's position intact.
                    var rest = segments;
                    const segment = rest.next() orelse {
                        const value = node.value orelse return null;
                        return Result{
                            .params = params[0..params_count],
                            .value = value,
                        };
                    };

                    // Prefer a static child. Wildcard-looking segments can only ever
                    // resolve to the wildcard child, which is handled below.
                    if (!isWildcard(segment)) {
                        if (node.get(segment)) |child| {
                            if (matchFrom(child, rest, params, params_count)) |result| {
                                return result;
                            }
                        }
                    }

                    // Backtrack into the wildcard child. The capture count on any trie
                    // path is bounded by `max_params`, so this slot is in range.
                    const wildcard = node.wildcard orelse return null;
                    params[params_count] = segment;
                    return matchFrom(wildcard, rest, params, params_count + 1);
                }
            };
        }
    };
}
test "Trie" {
    // Build the routing table at compile time; `compile` freezes it into a type.
    const Router = comptime build: {
        var routes: Trie("/", usize) = .{};
        routes.put("/posts/list", 0);
        routes.put("/*/hello", 99);
        // Registering the same path again replaces the earlier value.
        routes.put("/posts/list", 1);
        routes.put("/other/list", 2);
        routes.put("/posts/:id/delete", 3);
        routes.put("/posts/:id/comments", 4);
        routes.put("/posts/:id", 5);
        break :build routes.compile();
    };

    // Scratch space for captured wildcard segments, sized for the deepest route.
    var captured: [Router.max_params][]const u8 = undefined;
    const router: Router = .{};

    const Case = struct {
        path: []const u8,
        // `null` means the lookup must miss.
        expected: ?struct {
            value: usize,
            params: []const []const u8 = &[_][]const u8{},
        },
    };

    inline for (.{
        // Paths with no registered route.
        Case{ .path = "", .expected = null },
        Case{ .path = "/", .expected = null },
        Case{ .path = "/posts", .expected = null },
        Case{ .path = "/posts/", .expected = null },
        Case{ .path = "/others", .expected = null },
        Case{ .path = "/others/", .expected = null },
        // Fully static routes.
        Case{ .path = "/posts/list", .expected = .{ .value = 1 } },
        Case{ .path = "/other/list", .expected = .{ .value = 2 } },
        // Routes that capture a wildcard segment.
        Case{ .path = "/posts/ok", .expected = .{ .value = 5, .params = &[_][]const u8{"ok"} } },
        Case{ .path = "/posts/lis", .expected = .{ .value = 5, .params = &[_][]const u8{"lis"} } },
        Case{ .path = "/world/hello", .expected = .{ .value = 99, .params = &[_][]const u8{"world"} } },
        Case{ .path = "/posts/xd/comments", .expected = .{ .value = 4, .params = &[_][]const u8{"xd"} } },
        Case{ .path = "/posts/test/delete", .expected = .{ .value = 3, .params = &[_][]const u8{"test"} } },
    }) |case| {
        if (case.expected) |want| {
            const got = router.get(case.path, &captured).?;
            testing.expectEqual(want.value, got.value);
            testing.expectEqual(want.params.len, got.params.len);
            for (want.params) |param, idx| {
                testing.expectEqualStrings(param, got.params[idx]);
            }
        } else {
            testing.expectEqual(@as(?Router.Result, null), router.get(case.path, &captured));
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment