Skip to content

Instantly share code, notes, and snippets.

@wchargin
Created July 30, 2025 22:04
Show Gist options
  • Save wchargin/c89e67ce996b0fe8b13920b34e3ae5dc to your computer and use it in GitHub Desktop.
Save wchargin/c89e67ce996b0fe8b13920b34e3ae5dc to your computer and use it in GitHub Desktop.
ergonomic, type-precise struct-of-arrays transform in Go
// Package soa implements a Structure-of-Arrays transform.
//
// This makes it convenient to convert a slice of values into a collection of
// correlated slices of individual fields, each representing one "column". For
// example, if you have a slice of N users, you can use [Collect] to extract a
// slice of N user IDs, a slice of N user names, and a slice of N user emails.
// Those slices can be conveniently passed as array parameters to SQL queries.
//
// See: https://en.wikipedia.org/wiki/AoS_and_SoA
//
// See: https://github.com/jackc/pgx/discussions/2359
package soa
// Collect converts a slice of items into a slice of columns. Provide a
// callback that returns a list of fields of interest for a given item.
// Collect will return a slice of columns, each of length len(items), where
// columns[i][j] is the ith field of items[j]. Fields can be constructed by
// calling `V`.
//
// The callback is invoked exactly once per item, in slice order. Across
// invocations of the callback, it should always return the same number of
// fields, and the field at any given index should always have the same
// static type parameter. Collect will panic if this is not the case.
//
// If the field at index i had type parameter E, then result[i] will have
// runtime type []E. For example, if `project(item)` returns a slice whose
// first element is `V(item.SomeString)`, then result[0] will be a []string
// that has the values of the SomeString field for each item.
//
// If items is empty, then columns will be nil and ok will be false, because
// there is no way to invoke the callback to learn what the columns should be.
func Collect[T any](items []T, project func(item T) Fields) (columns []any, ok bool) {
	if len(items) == 0 {
		return nil, false
	}
	wr := &writer{firstPass: true, itemsLen: len(items)}

	// The first item fixes the number of columns and, through emit, the
	// element type of each column slice.
	fields := project(items[0])
	wr.sinks = make([]any, len(fields))
	for _, f := range fields {
		f.emit(wr)
	}
	wr.firstPass = false
	wr.itemIndex++

	for _, item := range items[1:] {
		wr.sinkIndex = 0
		fields = project(item)
		if len(fields) != len(wr.sinks) {
			panic(reasonFieldCount)
		}
		// Emit the very fields whose count was just validated; calling the
		// callback again here could yield a different (unvalidated) list.
		for _, f := range fields {
			f.emit(wr)
		}
		wr.itemIndex++
	}
	return wr.sinks, true
}
// field is the value that V returns behind the Field interface. It pairs a
// single value with its static type parameter E; emit uses E as the element
// type of the column slice it writes into.
type field[E any] struct {
// value is the wrapped field value supplied to V.
value E
}
// V wraps value as a Field, capturing its static type parameter E along with
// the value itself. It is meant to be called from the callback that is passed
// to Collect.
func V[E any](value E) Field {
	wrapped := field[E]{value: value}
	return wrapped
}
// writer carries the state shared between Collect and the emit method of each
// field while columns are being filled in.
type writer struct {
// sinks holds one entry per column; emit stores a []E in each entry during
// the first pass and type-asserts it back to []E afterwards.
sinks []any
// sinkIndex is the column that the next emit call writes to. Collect resets
// it to 0 for each item after the first, and emit increments it.
sinkIndex int
// firstPass is true while the first item is processed, which is when emit
// allocates the column slices.
firstPass bool
// itemsLen is the number of input items and therefore the length of every
// column slice.
itemsLen int
// itemIndex is the index of the item currently being processed, i.e. the
// position within each column that emit writes to.
itemIndex int
}
// emit stores f's value into the column selected by wr.sinkIndex, at the row
// selected by wr.itemIndex, and then advances wr.sinkIndex.
//
// During the first pass the column is allocated as a []E of length
// wr.itemsLen. On later passes the existing column must already be a []E;
// otherwise emit panics with reasonFieldTypes.
func (f field[E]) emit(wr *writer) {
	col := wr.sinkIndex
	var column []E
	switch {
	case wr.firstPass:
		column = make([]E, wr.itemsLen)
		wr.sinks[col] = column
	default:
		existing, matches := wr.sinks[col].([]E)
		if !matches {
			panic(reasonFieldTypes)
		}
		column = existing
	}
	column[wr.itemIndex] = f.value
	wr.sinkIndex = col + 1
}
// A Field is returned by V. See Collect for more details.
//
// The interface's only method is unexported, so Field values can only be
// created through this package (namely via V).
type Field interface {
// emit writes the field's value into the writer's current column and
// advances the writer to the next column.
emit(wr *writer)
}
// Fields is the list of Field values that the callback passed to Collect
// returns for a single item; the field at index i feeds column i.
type Fields []Field
// panicReason is the type of the values that this package passes to panic
// when the callback given to Collect violates its contract.
type panicReason string

// Both constants are explicitly typed. A constant declared with an expression
// but no type would be an untyped string constant, so its panic value would
// have dynamic type string instead of panicReason.
const (
	// reasonFieldCount is raised when the callback returns a different
	// number of fields than it did for the first item.
	reasonFieldCount panicReason = "soa: inconsistent number of fields"
	// reasonFieldTypes is raised when the field at some index has a
	// different static type parameter than it did for the first item.
	reasonFieldTypes panicReason = "soa: inconsistent type parameter for field"
)
package soa
import (
"slices"
"testing"
)
// user is a small sample record used as the input item type in these tests.
// Its three fields become the three columns that processUser projects.
type user struct {
id int
name string
email string
}
// processUser is the projection callback used by the user-based tests. It
// yields the id, name, and email of u, in that column order.
func processUser(u *user) Fields {
	projected := make(Fields, 0, 3)
	projected = append(projected, V(u.id), V(u.name), V(u.email))
	return projected
}
// TestEmptyInput verifies that Collect reports failure for input without any
// items, whether the slice is nil or merely empty: with nothing to project,
// Collect cannot learn how many columns there are or what their types are.
func TestEmptyInput(t *testing.T) {
	type testCase struct {
		label string
		input []*user
	}
	cases := []testCase{
		{label: "nil", input: nil},
		{label: "allocated", input: make([]*user, 0, 4)},
	}
	for i := range cases {
		tc := cases[i]
		t.Run(tc.label, func(t *testing.T) {
			cols, ok := Collect(tc.input, processUser)
			if !ok && len(cols) == 0 {
				return
			}
			t.Errorf("got %v, ok=%v; want empty, ok=false", cols, ok)
		})
	}
}
// If there are items in the input but the callback does not emit any columns,
// that's fine; this is not an error.
func TestNoFields(t *testing.T) {
users := makeUsers()
result, ok := Collect(users, func(u *user) Fields { return nil })
if len(result) > 0 || !ok {
t.Fatalf("got %v, ok=%v; want empty, ok=true", result, ok)
}
}
// makeUsers returns a fresh slice of four sample users. Every field value is
// distinct across users so that a column holding values in the wrong order,
// or from the wrong field, is detectable by the tests.
func makeUsers() []*user {
	return []*user{
		{id: 12, name: "Alice", email: "alice@example.com"},
		{id: 34, name: "Bob", email: "bob@example.com"},
		{id: 56, name: "Camille", email: "camille@example.com"},
		{id: 78, name: "Divya", email: "divya@example.com"},
	}
}
// checkUsersSoA reports a test error unless result consists of exactly three
// columns holding, in order, the ids ([]int), names ([]string), and emails
// ([]string) of users.
func checkUsersSoA(t *testing.T, users []*user, result []any) {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()
	if len(result) != 3 {
		t.Errorf("result: want len=3, got: %v", result)
		return
	}
	// Run the SoA transform manually to compute expected values.
	wantIDs := make([]int, len(users))
	wantNames := make([]string, len(users))
	wantEmails := make([]string, len(users))
	for i, u := range users {
		wantIDs[i] = u.id
		wantNames[i] = u.name
		wantEmails[i] = u.email
	}
	checkSlicesEqual(t, "IDs", result[0], wantIDs)
	checkSlicesEqual(t, "names", result[1], wantNames)
	checkSlicesEqual(t, "emails", result[2], wantEmails)
}
// checkSlicesEqual checks that gotAny is a []T whose elements equal those in
// want. It reports a test error, labelled with prefix, if gotAny has any
// other dynamic type or if the elements differ.
func checkSlicesEqual[T comparable](t *testing.T, prefix string, gotAny any, want []T) {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()
	got, ok := gotAny.([]T)
	if !ok {
		t.Errorf("%s: got %T, want %T: got %v, want %v", prefix, gotAny, want, gotAny, want)
		return
	}
	if !slices.Equal(got, want) {
		t.Errorf("%s: got %v, want %v", prefix, got, want)
	}
}
// TestHappyPath runs Collect over several users and checks that the three
// resulting columns match a manually computed transform.
func TestHappyPath(t *testing.T) {
	input := makeUsers()
	columns, succeeded := Collect(input, processUser)
	if !succeeded {
		t.Errorf("got !ok, want ok")
	}
	checkUsersSoA(t, input, columns)
}
// TestSingleItem runs Collect over a one-element input, so only the first
// pass over the data is exercised.
func TestSingleItem(t *testing.T) {
	all := makeUsers()
	single := all[:1]
	columns, succeeded := Collect(single, processUser)
	if !succeeded {
		t.Errorf("got !ok, want ok")
	}
	checkUsersSoA(t, single, columns)
}
// TestFromScalar makes sure that the types are set up such that you can run
// Collect even on values of non-pointer type. Each int64 input is projected
// to its first four powers, giving four []int64 columns.
func TestFromScalar(t *testing.T) {
	ints := []int64{1, 2, 3}
	result, ok := Collect(ints, func(n int64) Fields {
		return Fields{V(n), V(n * n), V(n * n * n), V(n * n * n * n)}
	})
	if !ok {
		t.Errorf("got !ok, want ok")
	}
	// Fatalf stops the test here, so the indexing below is safe.
	if len(result) != 4 {
		t.Fatalf("result: want len=4, got: %v", result)
	}
	checkSlicesEqual(t, "originals", result[0], []int64{1, 2, 3})
	checkSlicesEqual(t, "squares", result[1], []int64{1, 4, 9})
	checkSlicesEqual(t, "cubes", result[2], []int64{1, 8, 27})
	checkSlicesEqual(t, "hypercubes", result[3], []int64{1, 16, 81})
}
// getPanicReason invokes f and reports how it ended: if f panicked, the
// result is the value that was passed to `panic`; if f returned normally, the
// result is nil.
func getPanicReason(f func()) (reason any) {
	defer func() {
		// recover yields nil when f did not panic, leaving reason as nil.
		if recovered := recover(); recovered != nil {
			reason = recovered
		}
	}()
	f()
	return nil
}
func TestPanicIfFieldsIncrease(t *testing.T) {
reason := getPanicReason(func() {
first := true
Collect([]int{1, 2}, func(_ int) Fields {
if first {
first = false
return Fields{V("a")}
} else {
return Fields{V("a"), V("b")}
}
})
})
if reason != reasonFieldCount {
t.Errorf("panic reason: got %v, want %v", reason, reasonFieldCount)
}
}
func TestPanicIfFieldsDecrease(t *testing.T) {
reason := getPanicReason(func() {
first := true
Collect([]int{1, 2}, func(_ int) Fields {
if first {
first = false
return Fields{V("a"), V("b")}
} else {
return Fields{V("a")}
}
})
})
if reason != reasonFieldCount {
t.Errorf("panic reason: got %v, want %v", reason, reasonFieldCount)
}
}
func TestPanicIfFieldsChangeStaticType(t *testing.T) {
reason := getPanicReason(func() {
first := true
Collect([]int{1, 2}, func(_ int) Fields {
if first {
first = false
return Fields{V(any("a"))}
} else {
return Fields{V("b")}
}
})
})
if reason != reasonFieldTypes {
t.Errorf("panic reason: got %v, want %v", reason, reasonFieldTypes)
}
}
// box is a small interface with two implementations of different underlying
// kinds, used to test columns whose static element type is an interface.
type box interface {
	Unbox() any
}

// intBox is a box backed by an int.
type intBox int

// Unbox returns the boxed value as a plain int.
func (b intBox) Unbox() any {
	plain := int(b)
	return plain
}

// stringBox is a box backed by a string.
type stringBox string

// Unbox returns the boxed value as a plain string.
func (b stringBox) Unbox() any {
	plain := string(b)
	return plain
}
// The concrete type of a field is allowed to differ across invocations, as
// long as the static type is the same. Here every field has static type box,
// while the dynamic type alternates between stringBox and intBox.
func TestRuntimeTypeMayDiffer(t *testing.T) {
	type s struct {
		n      int
		s      string
		useInt bool
	}
	items := []s{
		{n: 1, s: "a", useInt: false},
		{n: 2, s: "b", useInt: true},
		{n: 3, s: "c", useInt: true},
		{n: 4, s: "d", useInt: false},
	}
	result, ok := Collect(items, func(item s) Fields {
		var f box
		if item.useInt {
			f = intBox(item.n)
		} else {
			f = stringBox(item.s)
		}
		return Fields{V(f)}
	})
	if !ok {
		t.Errorf("got !ok, want ok")
	}
	// Stop here on an unexpected column count so that indexing result[0]
	// below cannot panic and mask the real failure.
	if len(result) != 1 {
		t.Fatalf("result: want len=1, got: %v", result)
	}
	want := []box{stringBox("a"), intBox(2), intBox(3), stringBox("d")}
	checkSlicesEqual(t, "values", result[0], want)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment