Skip to content

Instantly share code, notes, and snippets.

@tcodes0
Created August 23, 2023 16:09
Show Gist options
  • Save tcodes0/4d7b25940625b575827c37049e06df37 to your computer and use it in GitHub Desktop.
Save tcodes0/4d7b25940625b575827c37049e06df37 to your computer and use it in GitHub Desktop.
// Transaction registers an expectation for a "Transaction" call on a
// testify-style mock store and arranges for the callback handed to that call
// to be executed against mockTXStore. It is a generic, reflection-based
// wrapper around this common testing boilerplate:
//
//	messages.On("Transaction", ctx, mock.AnythingOfType("func(storage.MessageStorer) error")).
//		Run(func(args mock.Arguments) {
//			err := args.Get(1).(func(tx storage.MessageStorer) error)(tx)
//			assert.NoError(err)
//		}).
//		Return(nil)
//
// Usage:
//
//	messages := &storage.MockMessageStorer{}
//	tx := &storage.MockMessageStorer{}
//	err = Transaction[storage.MessageStorer](ctx, messages, tx)
//	assert.NoError(err)
//
// mockStore must be a non-nil pointer exposing an On method whose result
// offers Run and Return (as *mock.Call does). mockTXStore is the value passed
// to the transaction callback when the mocked method is eventually invoked.
//
// The returned error only reports problems found while setting up the
// expectation. The transaction callback itself runs later, when the code
// under test calls the mocked Transaction method; since no testing handle is
// available at that point, a failing callback panics with the wrapped error
// instead of being silently dropped.
func Transaction[T any](ctx context.Context, mockStore T, mockTXStore T) error {
	// reflect.TypeOf yields nil for a nil interface value, so guard before
	// asking for the Kind.
	rt := reflect.TypeOf(mockStore)
	if rt == nil || rt.Kind() != reflect.Pointer {
		// store should be *mock.Mock which contain mutexes that need to be passed by reference
		return errors.New("store is not a pointer")
	}

	// MethodByName returns the zero Value when the method does not exist.
	// Value.IsZero panics on that zero Value, so IsValid is the right test.
	on := reflect.ValueOf(mockStore).MethodByName("On")
	if !on.IsValid() {
		return errors.New("expected On method")
	}

	// The String() form of the callback type is what mock.AnythingOfType
	// compares against, e.g. "func(storage.MessageStorer) error".
	txFnType := reflect.TypeOf((func(T) error)(nil)).String()

	res := on.Call([]reflect.Value{
		reflect.ValueOf("Transaction"),
		reflect.ValueOf(ctx),
		reflect.ValueOf(mock.AnythingOfType(txFnType)),
	})
	if len(res) < 1 {
		return errors.New("unexpected On return")
	}

	run := res[0].MethodByName("Run")
	if !run.IsValid() {
		return errors.New("expected Run method")
	}

	// Run only stores this function; it executes when the mocked method is
	// called, which is after Transaction has already returned.
	fn := func(args mock.Arguments) {
		cb, ok := args.Get(1).(func(T) error)
		if !ok {
			panic(fmt.Sprintf("mocked Transaction: argument 1 is %T, want %s", args.Get(1), txFnType))
		}
		if err := cb(mockTXStore); err != nil {
			panic(errs.Wrap(err, "error running transaction"))
		}
	}

	res = run.Call([]reflect.Value{reflect.ValueOf(fn)})
	if len(res) < 1 {
		return errors.New("unexpected Run return")
	}

	ret := res[0].MethodByName("Return")
	if !ret.IsValid() {
		return errors.New("expected Return method")
	}

	// Equivalent of Return(nil): the zero value of the error interface type
	// is stored as a nil interface value, not as a pointer to an error.
	nilErr := reflect.Zero(reflect.TypeOf((*error)(nil)).Elem())
	res = ret.Call([]reflect.Value{nilErr})
	if len(res) < 1 {
		return errors.New("unexpected Return return")
	}

	return nil
}
@tcodes0
Copy link
Author

tcodes0 commented Aug 24, 2023

new(error) -> new(ErrCode)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment