Created
July 30, 2025 22:04
-
-
Save wchargin/c89e67ce996b0fe8b13920b34e3ae5dc to your computer and use it in GitHub Desktop.
ergonomic, type-precise struct-of-arrays transform in Go
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Package soa implements a Structure-of-Arrays transform. | |
// | |
// This makes it convenient to convert a slice of values into a collection of | |
// correlated slices of individual fields, each representing one "column". For | |
// example, if you have a slice of N users, you can use [Collect] to extract a | |
// slice of N user IDs, a slice of N user names, and a slice of N user emails. | |
// Those slices can be conveniently passed as array parameters to SQL queries. | |
// | |
// See: https://en.wikipedia.org/wiki/AoS_and_SoA | |
// | |
// See: https://github.com/jackc/pgx/discussions/2359 | |
package soa | |
// Collect converts a slice of items into a slice of columns. Provide a | |
// callback that returns a list of fields of interest for a given item. | |
// Collect will return a slice of columns, each of length len(items), where | |
// columns[i][j] is the ith field of items[j]. Fields can be constructed by | |
// calling `V`. | |
// | |
// Across invocations of the callback, it should always return the same number | |
// of fields, and the field at any given index should always have the same | |
// static type parameter. Collect will panic if this is not the case. | |
// | |
// If the field at index i had type parameter E, then result[i] will have | |
// runtime type []E. For example, if `project(item)` returns a slice whose | |
// first element is `V(item.SomeString)`, then result[0] will be a []string | |
// that has the values of the SomeString field for each item. | |
// | |
// If items is empty, then columns will be nil and ok will be false, because | |
// there is no way invoke the callback to learn what the columns should be. | |
func Collect[T any](items []T, project func(item T) Fields) (columns []any, ok bool) { | |
if len(items) == 0 { | |
return nil, false | |
} | |
wr := &writer{firstPass: true, itemsLen: len(items)} | |
fields := project(items[0]) | |
wr.sinks = make([]any, len(fields)) | |
for _, f := range fields { | |
f.emit(wr) | |
} | |
wr.firstPass = false | |
wr.itemIndex++ | |
for _, item := range items[1:] { | |
wr.sinkIndex = 0 | |
fields = project(item) | |
if len(fields) != len(wr.sinks) { | |
panic(reasonFieldCount) | |
} | |
for _, f := range project(item) { | |
f.emit(wr) | |
} | |
wr.itemIndex++ | |
} | |
return wr.sinks, true | |
} | |
// field is the concrete Field implementation produced by V. Its type
// parameter E records the static type of the wrapped value, which in turn
// determines the element type of the column that the value is written into.
type field[E any] struct {
	// value is the single datum contributed by one item to one column.
	value E
}
// V bundles a field with its static type parameter. V should be called by the | |
// callback passed to Collect. | |
func V[E any](value E) Field { | |
return field[E]{value: value} | |
} | |
// writer carries the state shared between Collect and the emit method of each
// field while the columns are being filled in.
type writer struct {
	// sinks holds one entry per column; each entry is a slice whose element
	// type is the type parameter of the field emitted at that position.
	sinks []any

	// itemsLen is the number of input items, and so the length of each column.
	itemsLen int
	// firstPass is true while the first item is being processed, which is when
	// the column slices get allocated.
	firstPass bool

	// itemIndex is the row that emitted values are written to.
	itemIndex int
	// sinkIndex is the column that the next emitted value is written to.
	sinkIndex int
}
func (f field[E]) emit(wr *writer) { | |
var sink []E | |
if wr.firstPass { | |
sink = make([]E, wr.itemsLen) | |
wr.sinks[wr.sinkIndex] = sink | |
} else { | |
var ok bool | |
sink, ok = wr.sinks[wr.sinkIndex].([]E) | |
if !ok { | |
panic(reasonFieldTypes) | |
} | |
} | |
sink[wr.itemIndex] = f.value | |
wr.sinkIndex++ | |
} | |
// A Field is returned by V. See Collect for more details. | |
type Field interface { | |
emit(wr *writer) | |
} | |
type Fields []Field | |
// panicReason is the type of the values this package passes to panic, so that
// they can be told apart from arbitrary string panics.
type panicReason string

// Every reason names its type explicitly. A constant with an initializer but
// no type would be an untyped string constant and would panic with a plain
// string rather than a panicReason.
const (
	// reasonFieldCount is the reason used when the callback returns a
	// different number of fields than it did for the first item.
	reasonFieldCount panicReason = "soa: inconsistent number of fields"
	// reasonFieldTypes is the reason used when a field's type parameter does
	// not match the element type of the column at the same index.
	reasonFieldTypes panicReason = "soa: inconsistent type parameter for field"
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package soa | |
import ( | |
"slices" | |
"testing" | |
) | |
// user is the sample record type that the tests transform into columns.
type user struct {
	id    int    // becomes an element of the []int column
	name  string // becomes an element of the first []string column
	email string // becomes an element of the second []string column
}
func processUser(u *user) Fields { | |
return Fields{V(u.id), V(u.name), V(u.email)} | |
} | |
// If there are no items in the input, Collect should fail, because it cannot | |
// know how many columns to return and of what types. | |
func TestEmptyInput(t *testing.T) { | |
for _, test := range []struct { | |
name string | |
users []*user | |
}{ | |
{name: "nil", users: nil}, | |
{name: "allocated", users: make([]*user, 0, 4)}, | |
} { | |
t.Run(test.name, func(t *testing.T) { | |
result, ok := Collect(test.users, processUser) | |
if len(result) != 0 || ok { | |
t.Errorf("got %v, ok=%v; want empty, ok=false", result, ok) | |
} | |
}) | |
} | |
} | |
// If there are items in the input but the callback does not emit any columns, | |
// that's fine; this is not an error. | |
func TestNoFields(t *testing.T) { | |
users := makeUsers() | |
result, ok := Collect(users, func(u *user) Fields { return nil }) | |
if len(result) > 0 || !ok { | |
t.Fatalf("got %v, ok=%v; want empty, ok=true", result, ok) | |
} | |
} | |
func makeUsers() []*user { | |
return []*user{ | |
{id: 12, name: "Alice", email: "[email protected]"}, | |
{id: 34, name: "Bob", email: "[email protected]"}, | |
{id: 56, name: "Camille", email: "[email protected]"}, | |
{id: 78, name: "Divya", email: "[email protected]"}, | |
} | |
} | |
func checkUsersSoA(t *testing.T, users []*user, result []any) { | |
if len(result) != 3 { | |
t.Errorf("result: want len=3, got: %v", result) | |
return | |
} | |
// Run the SoA transform manually to compute expected values. | |
wantIDs := make([]int, len(users)) | |
wantNames := make([]string, len(users)) | |
wantEmails := make([]string, len(users)) | |
for i, user := range users { | |
wantIDs[i] = user.id | |
wantNames[i] = user.name | |
wantEmails[i] = user.email | |
} | |
checkSlicesEqual(t, "IDs", result[0], wantIDs) | |
checkSlicesEqual(t, "names", result[1], wantNames) | |
checkSlicesEqual(t, "emails", result[2], wantEmails) | |
} | |
// checkSlicesEqual checks that gotAny is a []T whose elements equal those in
// want. A wrong dynamic type or differing elements are reported on t with
// Errorf, each message prefixed by prefix.
func checkSlicesEqual[T comparable](t *testing.T, prefix string, gotAny any, want []T) {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()
	got, ok := gotAny.([]T)
	if !ok {
		t.Errorf("%s: got %T, want %T: got %v, want %v", prefix, gotAny, want, gotAny, want)
		return
	}
	if !slices.Equal(got, want) {
		t.Errorf("%s: got %v, want %v", prefix, got, want)
	}
}
func TestHappyPath(t *testing.T) { | |
users := makeUsers() | |
result, ok := Collect(users, processUser) | |
if !ok { | |
t.Errorf("got !ok, want ok") | |
} | |
checkUsersSoA(t, users, result) | |
} | |
func TestSingleItem(t *testing.T) { | |
users := makeUsers()[:1] | |
result, ok := Collect(users, processUser) | |
if !ok { | |
t.Errorf("got !ok, want ok") | |
} | |
checkUsersSoA(t, users, result) | |
} | |
// TestFromScalar makes sure that the types are set up such that you can run | |
// Collect even on values of non-pointer type. | |
func TestFromScalar(t *testing.T) { | |
ints := []int64{1, 2, 3} | |
result, ok := Collect(ints, func(n int64) Fields { | |
return Fields{V(n), V(n * n), V(n * n * n), V(n * n * n * n)} | |
}) | |
if !ok { | |
t.Errorf("got !ok, want ok") | |
} | |
if len(result) != 4 { | |
t.Fatalf("result: want len=4, got: %v", result) | |
return | |
} | |
checkSlicesEqual(t, "originals", result[0], []int64{1, 2, 3}) | |
checkSlicesEqual(t, "squares", result[1], []int64{1, 4, 9}) | |
checkSlicesEqual(t, "cubes", result[2], []int64{1, 8, 27}) | |
checkSlicesEqual(t, "hypercubes", result[3], []int64{1, 16, 81}) | |
} | |
// getPanicReason calls f and reports how it ended: if f panicked, the result
// is the value recovered from that panic; if f returned normally, the result
// is nil.
func getPanicReason(f func()) (reason any) {
	// The result is named so that the deferred function can overwrite it
	// after a panic has unwound past the call to f.
	defer func() {
		if recovered := recover(); recovered != nil {
			reason = recovered
		}
	}()
	f()
	return nil
}
func TestPanicIfFieldsIncrease(t *testing.T) { | |
reason := getPanicReason(func() { | |
first := true | |
Collect([]int{1, 2}, func(_ int) Fields { | |
if first { | |
first = false | |
return Fields{V("a")} | |
} else { | |
return Fields{V("a"), V("b")} | |
} | |
}) | |
}) | |
if reason != reasonFieldCount { | |
t.Errorf("panic reason: got %v, want %v", reason, reasonFieldCount) | |
} | |
} | |
func TestPanicIfFieldsDecrease(t *testing.T) { | |
reason := getPanicReason(func() { | |
first := true | |
Collect([]int{1, 2}, func(_ int) Fields { | |
if first { | |
first = false | |
return Fields{V("a"), V("b")} | |
} else { | |
return Fields{V("a")} | |
} | |
}) | |
}) | |
if reason != reasonFieldCount { | |
t.Errorf("panic reason: got %v, want %v", reason, reasonFieldCount) | |
} | |
} | |
func TestPanicIfFieldsChangeStaticType(t *testing.T) { | |
reason := getPanicReason(func() { | |
first := true | |
Collect([]int{1, 2}, func(_ int) Fields { | |
if first { | |
first = false | |
return Fields{V(any("a"))} | |
} else { | |
return Fields{V("b")} | |
} | |
}) | |
}) | |
if reason != reasonFieldTypes { | |
t.Errorf("panic reason: got %v, want %v", reason, reasonFieldTypes) | |
} | |
} | |
// box is a small interface with two implementations, used by the tests to
// build a column whose static element type is an interface.
type box interface {
	// Unbox returns the wrapped value converted to its underlying basic type.
	Unbox() any
}

type (
	// intBox is a box around an int.
	intBox int
	// stringBox is a box around a string.
	stringBox string
)

// Unbox returns the wrapped int.
func (i intBox) Unbox() any { return int(i) }

// Unbox returns the wrapped string.
func (s stringBox) Unbox() any { return string(s) }
// The concrete type of a field is allowed to differ across invocations, as | |
// long as the static type is the same. | |
func TestRuntimeTypeMayDiffer(t *testing.T) { | |
type s struct { | |
n int | |
s string | |
useInt bool | |
} | |
items := []s{ | |
{n: 1, s: "a", useInt: false}, | |
{n: 2, s: "b", useInt: true}, | |
{n: 3, s: "c", useInt: true}, | |
{n: 4, s: "d", useInt: false}, | |
} | |
result, ok := Collect(items, func(item s) Fields { | |
var f box | |
if item.useInt { | |
f = intBox(item.n) | |
} else { | |
f = stringBox(item.s) | |
} | |
return Fields{V(f)} | |
}) | |
if !ok { | |
t.Errorf("got !ok, want ok") | |
} | |
want := []box{stringBox("a"), intBox(2), intBox(3), stringBox("d")} | |
checkSlicesEqual(t, "values", result[0], want) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment