Skip to content

Instantly share code, notes, and snippets.

@rem1niscence
Created March 6, 2025 16:48
Show Gist options
  • Save rem1niscence/3e60ae0af6a0e80be4c99981687a22b3 to your computer and use it in GitHub Desktop.
Save rem1niscence/3e60ae0af6a0e80be4c99981687a22b3 to your computer and use it in GitHub Desktop.
// newTestNode builds a *node fixture for the SMT tests.
//
// k is the node key given as its least-significant bits (one int per bit).
// value becomes the node's Value. leftChildKey and rightChildKey are stored
// verbatim as the node's child keys; the fixtures pass nil for both on leaves.
func newTestNode(k []int, value, leftChildKey, rightChildKey []byte) *node {
	n := new(node)
	n.Key = &key{leastSigBits: k}
	n.Node = lib.Node{
		Value:         value,
		LeftChildKey:  leftChildKey,
		RightChildKey: rightChildKey,
	}
	return n
}
// TestSet checks that SMT.Set on an existing leaf rewrites the store as expected.
//
// Each case seeds a test SMT with preset, calls Set(targetKey, targetValue), then
// walks the backing store with an iterator and compares every stored node, in
// iteration order, with expected. Keys and child keys are compared for every
// node; the value is compared only for the root, which is enough to verify the
// hash-up logic without spelling out every intermediate hash. The number of
// stored nodes must also match len(expected.Nodes) exactly.
func TestSet(t *testing.T) {
	tests := []struct {
		name        string
		detail      string // human-readable before/after picture of the tree
		keyBitSize  int
		preset      *NodeList // nodes used to seed the test SMT (passed to NewTestSMT)
		expected    *NodeList // nodes expected in the store, in iteration order
		rootKey     []byte    // NOTE(review): not read by the test body
		targetKey   []byte
		targetValue []byte
	}{
		{
			name: "update and target at 010",
			detail: `BEFORE: root
/ \
0 1
/ \ / \
000 010 101 111
AFTER: root
/ \
0 1
/ \ / \
000 *010* 101 111
`,
			keyBitSize:  3,
			rootKey:     []byte{0b10010000}, // arbitrary
			targetKey:   []byte{1},          // hashes to [010]
			targetValue: []byte("some_value"),
			preset: &NodeList{
				Nodes: []*node{
					// root (100) -> children 0, 1
					newTestNode([]int{1, 0, 0}, nil, []byte{0b0, 0}, []byte{0b1, 0}),
					// 0 -> children 000, 010
					newTestNode([]int{0}, nil, []byte{0b0, 2}, []byte{0b10, 1}),
					// 1 -> children 111, 101
					newTestNode([]int{1}, nil, []byte{0b111, 0}, []byte{0b101, 0}),
					// 000 (leaf)
					newTestNode([]int{0, 0, 0}, nil, nil, nil),
					// 010 (leaf, the Set target)
					newTestNode([]int{0, 1, 0}, nil, nil, nil),
					// 111 (leaf)
					newTestNode([]int{1, 1, 1}, nil, nil, nil),
					// 101 (leaf)
					newTestNode([]int{1, 0, 1}, nil, nil, nil),
				},
			},
			expected: &NodeList{
				Nodes: []*node{
					// 0 -> children 000, 010
					newTestNode([]int{0}, nil, []byte{0b0, 2}, []byte{0b10, 1}),
					// 000 (leaf)
					newTestNode([]int{0, 0, 0}, nil, nil, nil),
					// 1 -> children 111, 101
					newTestNode([]int{1}, nil, []byte{0b111, 0}, []byte{0b101, 0}),
					// 010 (updated leaf)
					newTestNode([]int{0, 1, 0}, []byte("some_value"), nil, nil),
					// 100 root; the only node whose value is asserted
					newTestNode(
						[]int{1, 0, 0},
						func() []byte {
							// NOTE: the tree values on the right side are nulled, so the inputs for the right side are incomplete
							// grandchildren
							input000, input010 := []byte{0b0, 2}, append([]byte{0b10, 1}, crypto.Hash([]byte("some_value"))...)
							// children
							input0 := append([]byte{0b0, 0}, crypto.Hash(append(input000, input010...))...)
							input1 := []byte{0b1, 0} // right side is nulled: only the key bytes of 1 contribute
							// root value
							return crypto.Hash(append(input0, input1...))
						}(),
						[]byte{0b0, 0},
						[]byte{0b1, 0},
					),
					// 101 (leaf)
					newTestNode([]int{1, 0, 1}, nil, nil, nil),
					// 111 (leaf)
					newTestNode([]int{1, 1, 1}, nil, nil, nil),
				},
			},
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			smt, memStore := NewTestSMT(t, test.preset, nil, test.keyBitSize)
			defer memStore.Close()
			// execute the code under test
			require.NoError(t, smt.Set(test.targetKey, test.targetValue))
			// walk the whole store and compare each node with its expected counterpart
			it, err := memStore.Iterator(nil)
			require.NoError(t, err)
			defer it.Close()
			i := 0
			for ; it.Valid(); it.Next() {
				// guard the index so an unexpected extra node fails cleanly instead of panicking
				require.Less(t, i, len(test.expected.Nodes), "store holds more nodes than expected")
				want := test.expected.Nodes[i]
				got := newNode()
				// convert the stored value to a node
				require.NoError(t, lib.Unmarshal(it.Value(), &got.Node))
				// convert the stored key to a node key
				got.Key.fromBytes(it.Key())
				msg := fmt.Sprintf("Iteration: %d on node %v", i, got.Key.leastSigBits)
				require.Equal(t, want.Key.bytes(), got.Key.bytes(), msg)
				require.Equal(t, want.LeftChildKey, got.LeftChildKey, msg)
				require.Equal(t, want.RightChildKey, got.RightChildKey, msg)
				// only the root value is checked: it verifies the hashing-up logic
				// without having to fill in and check every intermediate value
				if bytes.Equal(got.Key.bytes(), smt.root.Key.bytes()) {
					require.Equal(t, want.Value, got.Value, msg)
				}
				i++
			}
			// a store with fewer nodes than expected must not pass silently
			require.Equal(t, len(test.expected.Nodes), i, "store holds fewer nodes than expected")
		})
	}
}
// TestSetExplicitNodes checks that SMT.Set on an existing leaf rewrites the
// store as expected, with every fixture node spelled out as a struct literal.
//
// It is named distinctly from the other TestSet in this file so the package
// compiles (Go forbids redeclaring a function); `go test -run TestSet` still
// selects it.
//
// Each case seeds a test SMT with preset, calls Set(targetKey, targetValue), then
// walks the backing store with an iterator and compares every stored node, in
// iteration order, with expected. Keys and child keys are compared for every
// node; the value is compared only for the root. The number of stored nodes
// must also match len(expected.Nodes) exactly.
func TestSetExplicitNodes(t *testing.T) {
	tests := []struct {
		name        string
		detail      string // human-readable before/after picture of the tree
		keyBitSize  int
		preset      *NodeList // nodes used to seed the test SMT (passed to NewTestSMT)
		expected    *NodeList // nodes expected in the store, in iteration order
		rootKey     []byte    // NOTE(review): not read by the test body
		targetKey   []byte
		targetValue []byte
	}{
		{
			name: "update and target at 010",
			detail: `BEFORE: root
/ \
0 1
/ \ / \
000 010 101 111
AFTER: root
/ \
0 1
/ \ / \
000 *010* 101 111
`,
			keyBitSize:  3,
			rootKey:     []byte{0b10010000}, // arbitrary
			targetKey:   []byte{1},          // hashes to [010]
			targetValue: []byte("some_value"),
			preset: &NodeList{
				Nodes: []*node{
					{ // root
						Key: &key{leastSigBits: []int{1, 0, 0}}, // arbitrary
						Node: lib.Node{
							LeftChildKey:  []byte{0b0, 0}, // 0
							RightChildKey: []byte{0b1, 0}, // 1
						},
					},
					{ // 0
						Key: &key{leastSigBits: []int{0}},
						Node: lib.Node{
							LeftChildKey:  []byte{0b0, 2},  // 000
							RightChildKey: []byte{0b10, 1}, // 010
						},
					},
					{ // 1
						Key: &key{leastSigBits: []int{1}},
						Node: lib.Node{
							LeftChildKey:  []byte{0b111, 0}, // 111
							RightChildKey: []byte{0b101, 0}, // 101
						},
					},
					{ // 000
						Key:  &key{leastSigBits: []int{0, 0, 0}},
						Node: lib.Node{}, // leaf
					},
					{ // 010 (the Set target)
						Key:  &key{leastSigBits: []int{0, 1, 0}},
						Node: lib.Node{}, // leaf
					},
					{ // 111
						Key:  &key{leastSigBits: []int{1, 1, 1}},
						Node: lib.Node{}, // leaf
					},
					{ // 101
						Key:  &key{leastSigBits: []int{1, 0, 1}},
						Node: lib.Node{}, // leaf
					},
				},
			},
			expected: &NodeList{
				Nodes: []*node{
					{ // 0
						Key: &key{leastSigBits: []int{0}},
						Node: lib.Node{
							LeftChildKey:  []byte{0b0, 2},  // 000
							RightChildKey: []byte{0b10, 1}, // 010
						},
					},
					{ // 000
						Key:  &key{leastSigBits: []int{0, 0, 0}},
						Node: lib.Node{}, // leaf
					},
					{ // 1
						Key: &key{leastSigBits: []int{1}},
						Node: lib.Node{
							LeftChildKey:  []byte{0b111, 0}, // 111
							RightChildKey: []byte{0b101, 0}, // 101
						},
					},
					{ // 010 (updated)
						Key:  &key{leastSigBits: []int{0, 1, 0}},
						Node: lib.Node{Value: []byte("some_value")}, // leaf
					},
					{ // 100 root; the only node whose value is asserted
						Key: &key{leastSigBits: []int{1, 0, 0}}, // arbitrary
						Node: lib.Node{
							Value: func() []byte {
								// NOTE: the tree values on the right side are nulled, so the inputs for the right side are incomplete
								// grandchildren
								input000, input010 := []byte{0b0, 2}, append([]byte{0b10, 1}, crypto.Hash([]byte("some_value"))...)
								// children
								input0 := append([]byte{0b0, 0}, crypto.Hash(append(input000, input010...))...)
								// right side is nulled: only the key bytes of 1 contribute
								// (was a single-argument append, which go vet flags)
								input1 := []byte{0b1, 0}
								// root value
								return crypto.Hash(append(input0, input1...))
							}(),
							LeftChildKey:  []byte{0b0, 0}, // 0
							RightChildKey: []byte{0b1, 0}, // 1
						},
					},
					{ // 101
						Key:  &key{leastSigBits: []int{1, 0, 1}},
						Node: lib.Node{}, // leaf
					},
					{ // 111
						Key:  &key{leastSigBits: []int{1, 1, 1}},
						Node: lib.Node{}, // leaf
					},
				},
			},
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			smt, memStore := NewTestSMT(t, test.preset, nil, test.keyBitSize)
			defer memStore.Close()
			// execute the code under test
			require.NoError(t, smt.Set(test.targetKey, test.targetValue))
			// walk the whole store and compare each node with its expected counterpart
			it, err := memStore.Iterator(nil)
			require.NoError(t, err)
			defer it.Close()
			i := 0
			for ; it.Valid(); it.Next() {
				// guard the index so an unexpected extra node fails cleanly instead of panicking
				require.Less(t, i, len(test.expected.Nodes), "store holds more nodes than expected")
				want := test.expected.Nodes[i]
				got := newNode()
				// convert the stored value to a node
				require.NoError(t, lib.Unmarshal(it.Value(), &got.Node))
				// convert the stored key to a node key
				got.Key.fromBytes(it.Key())
				msg := fmt.Sprintf("Iteration: %d on node %v", i, got.Key.leastSigBits)
				require.Equal(t, want.Key.bytes(), got.Key.bytes(), msg)
				require.Equal(t, want.LeftChildKey, got.LeftChildKey, msg)
				require.Equal(t, want.RightChildKey, got.RightChildKey, msg)
				// only the root value is checked: it verifies the hashing-up logic
				// without having to fill in and check every intermediate value
				if bytes.Equal(got.Key.bytes(), smt.root.Key.bytes()) {
					require.Equal(t, want.Value, got.Value, msg)
				}
				i++
			}
			// a store with fewer nodes than expected must not pass silently
			require.Equal(t, len(test.expected.Nodes), i, "store holds fewer nodes than expected")
		})
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment