Last active
April 13, 2023 17:26
-
-
Save davidbalbert/2d326ae2827fc948a47b5eaa15c1620d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// commonPrefixLen returns the number of leading bytes that a and b share.
// The comparison is byte-wise, so the result may fall in the middle of a
// multi-byte UTF-8 sequence. It is 0 when either string is empty or the
// first bytes differ.
func commonPrefixLen(a, b string) int {
	i := 0
	for i < len(a) && i < len(b) && a[i] == b[i] {
		i++
	}
	return i
}
// edge is a labelled link from a node to one of its children. The label is
// the (non-empty) run of key bytes consumed when following the edge.
type edge struct {
	label string
	node  *node
}

// node is a vertex of the radix tree. A node optionally carries a value
// (for the key spelled out by the edge labels on the path leading to it)
// and owns a set of outgoing edges.
type node struct {
	// hasValue reports whether value was explicitly stored; it distinguishes
	// a stored nil from "no value at this node".
	hasValue bool
	value    any

	// edgeIndex[i] is the first byte of edges[i].label. store keeps both
	// slices the same length and ordered by ascending first byte so that an
	// edge can be located with a binary search over edgeIndex.
	edgeIndex []byte
	edges     []*edge
}
func (n *node) store(key string, value any) { | |
for { | |
if len(key) == 0 { | |
n.hasValue = true | |
n.value = value | |
return | |
} | |
i := sort.Search(len(n.edgeIndex), func(i int) bool { | |
return n.edgeIndex[i] >= key[0] | |
}) | |
if i < len(n.edgeIndex) && n.edgeIndex[i] == key[0] { | |
// edge found | |
e := n.edges[i] | |
prefixLen := commonPrefixLen(e.label, key) | |
if prefixLen == len(e.label) && prefixLen == len(key) { | |
// exact match, overwrite | |
e.node.hasValue = true | |
e.node.value = value | |
return | |
} else if prefixLen == len(e.label) { | |
// e.label is a prefix of key | |
key = key[prefixLen:] | |
n = e.node | |
} else { | |
// prefixLen < len(n.label) && prefixLen < len(key) | |
// split | |
intermediateNode := &node{ | |
edgeIndex: []byte{e.label[prefixLen]}, | |
edges: []*edge{{label: e.label[prefixLen:], node: e.node}}, | |
} | |
e.label = e.label[:prefixLen] | |
e.node = intermediateNode | |
key = key[prefixLen:] | |
n = intermediateNode | |
} | |
} else if i < len(n.edgeIndex) { | |
// insert edge | |
n.edgeIndex = append(n.edgeIndex, 0) | |
copy(n.edgeIndex[i+1:], n.edgeIndex[i:]) | |
n.edgeIndex[i] = key[0] | |
n.edges = append(n.edges, nil) | |
copy(n.edges[i+1:], n.edges[i:]) | |
n.edges[i] = &edge{label: key, node: &node{hasValue: true, value: value}} | |
return | |
} else { | |
// append edge | |
n.edgeIndex = append(n.edgeIndex, key[0]) | |
n.edges = append(n.edges, &edge{label: key, node: &node{hasValue: true, value: value}}) | |
return | |
} | |
} | |
} | |
func (n *node) load(key string) (any, bool) { | |
if len(key) == 0 { | |
return n.value, n.hasValue | |
} | |
for { | |
i := sort.Search(len(n.edgeIndex), func(i int) bool { | |
return n.edgeIndex[i] >= key[0] | |
}) | |
if i < len(n.edgeIndex) && n.edgeIndex[i] == key[0] { | |
// edge found | |
e := n.edges[i] | |
prefixLen := radixTreeCommonPrefixLen(e.label, key) | |
if prefixLen == len(e.label) && prefixLen == len(key) { | |
// exact match | |
return e.node.value, e.node.hasValue | |
} else if prefixLen == len(e.label) { | |
// e.label is a prefix of key | |
key = key[prefixLen:] | |
n = e.node | |
} else { | |
// prefixLen < len(n.label) && prefixLen < len(key) | |
return nil, false | |
} | |
} else { | |
// no edge found | |
return nil, false | |
} | |
} | |
} | |
// walkPartialTokensFunc is the callback invoked by walkPartialTokens for each
// matching entry, receiving the full key and the value stored under it.
// Returning a non-nil error stops the walk; that error is returned to the
// caller of walkPartialTokens.
type walkPartialTokensFunc func(key string, value any) error
// walkPartialTokens tokenizes keys in the tree using sep as a separator, and calls fn for each | |
// key that matches the query. The query is tokenized using the same separator, and each token | |
// in the query must be a prefix of a corresponding token in the key. The number of tokens in | |
// each matched key must match the number of tokens in the query. | |
// | |
// E.g. if sep is ' ', then the query "fo ba" will match the keys "foo bar" and "foo baz", but not | |
// "foo bar baz". As a special case, a query of "" will match the key "", and nothing else, for any | |
// value of sep. | |
func (root *node) walkPartialTokens(query string, sep byte, fn walkPartialTokensFunc) error { | |
queryParts := strings.FieldsFunc(query, func(r rune) bool { | |
return r == rune(sep) | |
}) | |
// special case: if the query is empty, we match the key "". | |
if len(queryParts) == 0 { | |
if root.hasValue { | |
return fn("", root.value) | |
} | |
return nil | |
} | |
var walkNode func(prefix string, n *node, tokPrefix string, tokPrefixes []string) error | |
var walkEdge func(prefix string, e *edge, offset int, tokPrefix string, tokPrefixes []string) error | |
var walkUntilSep func(prefix string, e *edge, offset int, tokPrefixes []string) error | |
walkNode = func(prefix string, n *node, tokPrefix string, tokPrefixes []string) error { | |
// walkNode is always called with len(tokPrefix) > 0 | |
i := sort.Search(len(n.edgeIndex), func(i int) bool { | |
return n.edgeIndex[i] >= tokPrefix[0] | |
}) | |
if i == len(n.edgeIndex) || n.edgeIndex[i] != tokPrefix[0] { | |
// no edge found | |
return nil | |
} | |
edge := n.edges[i] | |
return walkEdge(prefix, edge, 0, tokPrefix, tokPrefixes) | |
} | |
walkEdge = func(prefix string, e *edge, offset int, partialToken string, partialTokens []string) error { | |
rest := e.label[offset:] | |
prefixLen := radixTreeCommonPrefixLen(rest, partialToken) | |
if prefixLen < len(partialToken) && prefixLen < len(rest) { | |
// neither the edge nor partialToken is a prefix of the other. no match. | |
return nil | |
} else if prefixLen < len(partialToken) { | |
// partialToken continues past the end of the edge (i.e. rest is a prefix of partialToken). | |
// Keep searching at the next node. partialToken[prefixLen:] is guaranteed to be non-empty. | |
return walkNode(prefix+rest, e.node, partialToken[prefixLen:], partialTokens) | |
} else if prefixLen < len(rest) { | |
// partialToken ends inside the edge (i.e. partialToken is a prefix of rest). | |
// Start searching for separator on this edge. | |
return walkUntilSep(prefix+rest[:prefixLen], e, offset+prefixLen, partialTokens) | |
} else { | |
// partialToken == rest | |
// Start searching for separator starting at the next node. | |
node := e.node | |
if node.hasValue && len(partialTokens) == 0 { | |
err := fn(prefix+rest, node.value) | |
if err != nil { | |
return err | |
} | |
} | |
for _, e := range node.edges { | |
err := walkUntilSep(prefix+rest, e, 0, partialTokens) | |
if err != nil { | |
return err | |
} | |
} | |
return nil | |
} | |
} | |
walkUntilSep = func(prefix string, e *edge, offset int, partialTokens []string) error { | |
suffix := e.label[offset:] | |
i := strings.Index(suffix, string(sep)) | |
if i == -1 { | |
// no separator | |
if len(partialTokens) == 0 { | |
// no more partial tokens, so we've found a match | |
if e.node.hasValue { | |
err := fn(prefix+suffix, e.node.value) | |
if err != nil { | |
return err | |
} | |
} | |
} | |
for _, e := range e.node.edges { | |
err := walkUntilSep(prefix+suffix, e, 0, partialTokens) | |
if err != nil { | |
return err | |
} | |
} | |
return nil | |
} else if len(partialTokens) == 0 { | |
// we found a separator on this edge, but have no more partial tokens, so stop here | |
return nil | |
} else if i == len(suffix)-1 { | |
return walkNode(prefix+suffix, e.node, partialTokens[0], partialTokens[1:]) | |
} else { | |
return walkEdge(prefix+suffix[:i+1], e, offset+i+1, partialTokens[0], partialTokens[1:]) | |
} | |
} | |
return walkNode("", root, queryParts[0], queryParts[1:]) | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"reflect" | |
"testing" | |
) | |
func TestStoreAndLoad(t *testing.T) { | |
n := &node{} | |
n.store("foobar", 1) | |
n.store("foo", 2) | |
n.store("bar", 3) | |
value, ok := n.load("foobar") | |
if !ok { | |
t.Fatal("expected foobar to be found") | |
} else if value != 1 { | |
t.Fatalf("expected foobar to be 1, got %d", value) | |
} | |
value, ok = n.load("foo") | |
if !ok { | |
t.Fatal("expected foo to be found") | |
} else if value != 2 { | |
t.Fatalf("expected foo to be 2, got %d", value) | |
} | |
value, ok = n.load("bar") | |
if !ok { | |
t.Fatal("expected bar to be found") | |
} else if value != 3 { | |
t.Fatalf("expected bar to be 3, got %d", value) | |
} | |
} | |
func TestLoadNotFound(t *testing.T) { | |
n := &node{} | |
n.store("foobar", 1) | |
n.store("foo", 2) | |
_, ok := n.load("fo") | |
if ok { | |
t.Fatal("expected fo to not be found") | |
} | |
_, ok = n.load("foob") | |
if ok { | |
t.Fatal("expected foob to not be found") | |
} | |
_, ok = n.load("bar") | |
if ok { | |
t.Fatal("expected bar to not be found") | |
} | |
} | |
func TestStoreAndLoadEmptyKey(t *testing.T) { | |
n := &node{} | |
n.store("", 1) | |
value, ok := n.load("") | |
if !ok { | |
t.Fatal("expected empty key to be found") | |
} else if value != 1 { | |
t.Fatalf("expected empty key to be 1, got %d", value) | |
} | |
} | |
func TestNonExistantEmptyKeyLoad(t *testing.T) { | |
n := &node{} | |
_, ok := n.load("") | |
if ok { | |
t.Fatal("expected empty key to not be found") | |
} | |
n.store("foo", 1) | |
_, ok = n.load("") | |
if ok { | |
t.Fatal("expected empty key to not be found") | |
} | |
} | |
func TestOverwrite(t *testing.T) { | |
n := &node{} | |
n.store("foo", 1) | |
value, ok := n.load("foo") | |
if !ok { | |
t.Fatal("expected foo to be found") | |
} else if value != 1 { | |
t.Fatalf("expected foo to be 1, got %d", value) | |
} | |
n.store("foo", 2) | |
value, ok = n.load("foo") | |
if !ok { | |
t.Fatal("expected foo to be found") | |
} else if value != 2 { | |
t.Fatalf("expected foo to be 2, got %d", value) | |
} | |
} | |
func TestWalkPartialTokensExactMatch(t *testing.T) { | |
n := &node{} | |
n.store("show version", 1) | |
n.store("show version detail", 2) | |
n.store("show name", 3) | |
n.store("show number", 4) | |
var keys []string | |
var values []int | |
n.walkPartialTokens("show version", ' ', func(prefix string, value any) error { | |
keys = append(keys, prefix) | |
values = append(values, value.(int)) | |
return nil | |
}) | |
expectedKeys := []string{"show version"} | |
expectedValues := []int{1} | |
if !reflect.DeepEqual(keys, expectedKeys) { | |
t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys) | |
} | |
if !reflect.DeepEqual(values, expectedValues) { | |
t.Fatalf("expected values %#v, got %#v", expectedValues, values) | |
} | |
} | |
func TestWalkPartialTokensPrefixMatch(t *testing.T) { | |
n := &node{} | |
n.store("show version", 1) | |
n.store("show version detail", 2) | |
n.store("show name", 3) | |
n.store("show number", 4) | |
var keys []string | |
var values []int | |
n.walkPartialTokens("sh ver", ' ', func(prefix string, value any) error { | |
keys = append(keys, prefix) | |
values = append(values, value.(int)) | |
return nil | |
}) | |
expectedKeys := []string{"show version"} | |
expectedValues := []int{1} | |
if !reflect.DeepEqual(keys, expectedKeys) { | |
t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys) | |
} | |
if !reflect.DeepEqual(values, expectedValues) { | |
t.Fatalf("expected values %#v, got %#v", expectedValues, values) | |
} | |
} | |
func TestWalkPartialTokensMultipleMatches(t *testing.T) { | |
n := &node{} | |
n.store("show version", 1) | |
n.store("show version detail", 2) | |
n.store("show name", 3) | |
n.store("show number", 4) | |
var keys []string | |
var values []int | |
n.walkPartialTokens("sh n", ' ', func(prefix string, value any) error { | |
keys = append(keys, prefix) | |
values = append(values, value.(int)) | |
return nil | |
}) | |
expectedKeys := []string{"show name", "show number"} | |
expectedValues := []int{3, 4} | |
if !reflect.DeepEqual(keys, expectedKeys) { | |
t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys) | |
} | |
if !reflect.DeepEqual(values, expectedValues) { | |
t.Fatalf("expected values %#v, got %#v", expectedValues, values) | |
} | |
keys = nil | |
values = nil | |
n.walkPartialTokens("sh na", ' ', func(prefix string, value any) error { | |
keys = append(keys, prefix) | |
values = append(values, value.(int)) | |
return nil | |
}) | |
expectedKeys = []string{"show name"} | |
expectedValues = []int{3} | |
if !reflect.DeepEqual(keys, expectedKeys) { | |
t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys) | |
} | |
if !reflect.DeepEqual(values, expectedValues) { | |
t.Fatalf("expected values %#v, got %#v", expectedValues, values) | |
} | |
keys = nil | |
values = nil | |
n.walkPartialTokens("sh nu", ' ', func(prefix string, value any) error { | |
keys = append(keys, prefix) | |
values = append(values, value.(int)) | |
return nil | |
}) | |
expectedKeys = []string{"show number"} | |
expectedValues = []int{4} | |
if !reflect.DeepEqual(keys, expectedKeys) { | |
t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys) | |
} | |
if !reflect.DeepEqual(values, expectedValues) { | |
t.Fatalf("expected values %#v, got %#v", expectedValues, values) | |
} | |
} | |
func TestWalkPartialTokensEdgeDoesntMatch(t *testing.T) { | |
n := &node{} | |
n.store("show version", 1) | |
n.store("show version detail", 2) | |
n.store("show name", 3) | |
n.store("show number", 4) | |
var keys []string | |
var values []int | |
n.walkPartialTokens("shaw", ' ', func(prefix string, value any) error { | |
keys = append(keys, prefix) | |
values = append(values, value.(int)) | |
return nil | |
}) | |
expectedKeys := []string(nil) | |
expectedValues := []int(nil) | |
if !reflect.DeepEqual(keys, expectedKeys) { | |
t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys) | |
} | |
if !reflect.DeepEqual(values, expectedValues) { | |
t.Fatalf("expected values %#v, got %#v", expectedValues, values) | |
} | |
} | |
func TestWalkPartialTokensTooFewTokens(t *testing.T) { | |
n := &node{} | |
n.store("show version", 1) | |
n.store("show version detail", 2) | |
n.store("show name", 3) | |
n.store("show number", 4) | |
var keys []string | |
var values []int | |
n.walkPartialTokens("sh", ' ', func(prefix string, value any) error { | |
keys = append(keys, prefix) | |
values = append(values, value.(int)) | |
return nil | |
}) | |
expectedKeys := []string(nil) | |
expectedValues := []int(nil) | |
if !reflect.DeepEqual(keys, expectedKeys) { | |
t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys) | |
} | |
if !reflect.DeepEqual(values, expectedValues) { | |
t.Fatalf("expected values %#v, got %#v", expectedValues, values) | |
} | |
} | |
func TestWalkPartialTokensTooMany(t *testing.T) { | |
n := &node{} | |
n.store("show version", 1) | |
n.store("show version detail", 2) | |
n.store("show name", 3) | |
n.store("show number", 4) | |
var keys []string | |
var values []int | |
n.walkPartialTokens("sh na foo", ' ', func(prefix string, value any) error { | |
keys = append(keys, prefix) | |
values = append(values, value.(int)) | |
return nil | |
}) | |
expectedKeys := []string(nil) | |
expectedValues := []int(nil) | |
if !reflect.DeepEqual(keys, expectedKeys) { | |
t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys) | |
} | |
if !reflect.DeepEqual(values, expectedValues) { | |
t.Fatalf("expected values %#v, got %#v", expectedValues, values) | |
} | |
} | |
func TestWalkPartialTokensEmptyQuery(t *testing.T) { | |
n := &node{} | |
n.store("", 1) | |
var keys []string | |
var values []int | |
n.walkPartialTokens("", ' ', func(prefix string, value any) error { | |
keys = append(keys, prefix) | |
values = append(values, value.(int)) | |
return nil | |
}) | |
expectedKeys := []string{""} | |
expectedValues := []int{1} | |
if !reflect.DeepEqual(keys, expectedKeys) { | |
t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys) | |
} | |
if !reflect.DeepEqual(values, expectedValues) { | |
t.Fatalf("expected values %#v, got %#v", expectedValues, values) | |
} | |
} | |
func TestWalkPartialTokensEmptyQueryNoMatch(t *testing.T) { | |
n := &node{} | |
n.store("show version", 1) | |
var keys []string | |
var values []int | |
n.walkPartialTokens("", ' ', func(prefix string, value any) error { | |
keys = append(keys, prefix) | |
values = append(values, value.(int)) | |
return nil | |
}) | |
expectedKeys := []string(nil) | |
expectedValues := []int(nil) | |
if !reflect.DeepEqual(keys, expectedKeys) { | |
t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys) | |
} | |
if !reflect.DeepEqual(values, expectedValues) { | |
t.Fatalf("expected values %#v, got %#v", expectedValues, values) | |
} | |
} | |
func TestWalkPartialTokensEmptyTree(t *testing.T) { | |
n := &node{} | |
var keys []string | |
var values []int | |
n.walkPartialTokens("sh ver", ' ', func(prefix string, value any) error { | |
keys = append(keys, prefix) | |
values = append(values, value.(int)) | |
return nil | |
}) | |
expectedKeys := []string(nil) | |
expectedValues := []int(nil) | |
if !reflect.DeepEqual(keys, expectedKeys) { | |
t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys) | |
} | |
if !reflect.DeepEqual(values, expectedValues) { | |
t.Fatalf("expected values %#v, got %#v", expectedValues, values) | |
} | |
} | |
func TestWalkPartialTokensSeparatorWithinEdge(t *testing.T) { | |
n := &node{} | |
n.store("foo bar baz", 1) | |
n.store("foo bar", 2) | |
var keys []string | |
var values []int | |
n.walkPartialTokens("foo bar", ' ', func(prefix string, value any) error { | |
keys = append(keys, prefix) | |
values = append(values, value.(int)) | |
return nil | |
}) | |
expectedKeys := []string{"foo bar"} | |
expectedValues := []int{2} | |
if !reflect.DeepEqual(keys, expectedKeys) { | |
t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys) | |
} | |
if !reflect.DeepEqual(values, expectedValues) { | |
t.Fatalf("expected values %#v, got %#v", expectedValues, values) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment