Skip to content

Instantly share code, notes, and snippets.

@SergeAx
Created August 13, 2025 23:04
Show Gist options
  • Save SergeAx/fa7d17101a7d591164dec5df42f4a561 to your computer and use it in GitHub Desktop.
Save SergeAx/fa7d17101a7d591164dec5df42f4a561 to your computer and use it in GitHub Desktop.
Wrapping a dig (go.uber.org/dig) container and testing it
package app
import (
"fmt"
"log"
"reflect"
"runtime"
"go.uber.org/dig"
)
// Container wraps a dig.Container. The dig container is embedded, so its
// methods (Provide, Invoke, ...) are callable directly on a *Container.
// It adds helpers such as mustProvide that abort on registration errors.
type Container struct {
	*dig.Container
	// testMode enables the extra return-type check in mustProvide. It also
	// makes NewAppContainer create the dig container without
	// dig.DeferAcyclicVerification.
	testMode bool
}
// mustProvide registers constructor with the underlying dig container. If
// registration fails, it terminates the process via FailOnError.
//
// In test mode it also requires the constructor's first result to be a
// pointer or an interface. A struct result is accepted only when it is a
// dig.Out result object. dig unpacks such objects into their fields instead
// of providing the struct value itself.
//
// Parameters:
//   - constructor: the function handed to dig.Container.Provide.
//   - opts: provide options forwarded unchanged to Provide.
func (c *Container) mustProvide(constructor any, opts ...dig.ProvideOption) {
	// Resolve the name once, up front. reflect.Value.Pointer panics on nil
	// and on most non-func kinds, so guard on Kind.
	name := "<non-function>"
	if v := reflect.ValueOf(constructor); v.Kind() == reflect.Func {
		name = runtime.FuncForPC(v.Pointer()).Name()
	}

	if err := c.Provide(constructor, opts...); err != nil {
		FailOnError(err, "dig.Container.Provide(%s) error", name)
	}
	if !c.testMode {
		return
	}

	// Provide succeeded, so constructor is a func with at least one
	// non-error result; Out(0) is safe.
	result := reflect.TypeOf(constructor).Out(0)
	switch result.Kind() {
	case reflect.Ptr, reflect.Interface:
		return
	case reflect.Struct:
		if dig.IsOut(result) {
			return
		}
	}
	FailOnError(fmt.Errorf("non-pointer return type %v", result), "dig.Container.Provide(%s) error", name)
}
// FailOnError does nothing when err is nil. Otherwise it logs msg, the
// underlying error, and a termination notice, then exits via log.Fatal.
//
// Parameters:
//   - err: the error to check. When dig.RootCause(err) is non-nil, that
//     value is logged instead of err. dig documents it as the original error
//     returned by a constructor or invoked function.
//   - msg: a printf-style format string.
//   - params: the arguments for msg.
func FailOnError(err error, msg string, params ...any) {
	if err == nil {
		return
	}
	if rootCause := dig.RootCause(err); rootCause != nil {
		err = rootCause
	}
	// The standard library logger has no Errorf; Printf writes to the same
	// destination that log.Fatal uses.
	log.Printf(msg, params...)
	log.Printf("Underlying error: %v", err)
	log.Fatal("Process terminated")
}
// NewAppContainer creates the application's dependency-injection container.
//
// When testMode is false, the dig container is created with
// dig.DeferAcyclicVerification. Per dig's documentation, this moves
// dependency-graph validation from after every Provide to the first Invoke.
// When testMode is true, no dig options are passed. mustProvide also
// performs an additional return-type check on each constructor in that mode.
func NewAppContainer(testMode bool) *Container {
	var digOpts []dig.Option
	if !testMode {
		digOpts = append(digOpts, dig.DeferAcyclicVerification())
	}

	c := &Container{
		Container: dig.New(digOpts...),
		testMode:  testMode,
	}

	// Application constructors are registered here, e.g.:
	//c.mustProvide(...)
	//c.mustProvide(...)
	//c.mustProvide(...)
	return c
}
package app
import (
"os"
"reflect"
"runtime"
"testing"
"go.uber.org/dig"
)
// TestContainer checks that a test-mode container can satisfy every listed
// probe. Each probe is a no-op function whose parameters are the types under
// test. testConstructor passes it to Container.Invoke, which must resolve
// those parameters from the registered constructors.
//
// NOTE(review): this file imports "os" and "go.uber.org/dig", but the code
// visible here uses neither. Go rejects unused imports, so confirm and
// remove them.
func TestContainer(t *testing.T) {
	container := NewAppContainer(true)
	probes := []any{
		func(_ *App) {},
		// func(_ *Foo) {},
		// func(_ *Bar) {},
		// func(_ *Baz) {},
		// ...
	}
	for _, probe := range probes {
		testConstructor(t, container, probe)
	}
}
// testConstructor passes constructor (usually a no-op probe such as
// func(*App) {}) to c.Invoke. If Invoke fails, it records a non-fatal test
// error.
//
// The failure message contains the function's runtime name and its type.
// Anonymous probes only have generated names like "app.TestContainer.func1",
// so the type (e.g. "func(*app.App)") is what identifies the failing
// dependency.
func testConstructor(t *testing.T, c *Container, constructor any) {
	t.Helper() // attribute failures to the caller's line, not this helper

	err := c.Invoke(constructor)
	if err == nil {
		return
	}

	// reflect.Value.Pointer panics on nil and most non-func values, so only
	// resolve a name for real functions.
	name := "<non-function>"
	if v := reflect.ValueOf(constructor); v.Kind() == reflect.Func {
		name = runtime.FuncForPC(v.Pointer()).Name()
	}
	t.Errorf("%s %T: %v", name, constructor, err)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment