Skip to content

Instantly share code, notes, and snippets.

@Cyberax
Created July 26, 2021 04:52
Show Gist options
  • Save Cyberax/eb42d249d022c55ce9dc6572309200ce to your computer and use it in GitHub Desktop.
Save Cyberax/eb42d249d022c55ce9dc6572309200ce to your computer and use it in GitHub Desktop.
Reflective AwsMocks
package utils
import (
"context"
"github.com/aws/aws-sdk-go-v2/aws"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
"github.com/aws/smithy-go/middleware"
"reflect"
)
// AwsMockHandler answers AWS SDK v2 operations with user-supplied Go handlers
// instead of sending them to AWS. Register handlers with AddHandler and build
// service clients from the aws.Config returned by AwsConfig.
//
// The zero value is ready to use.
type AwsMockHandler struct {
	// handlers are values whose methods are searched for a conforming
	// signature; they are consulted before functors.
	handlers []reflect.Value
	// functors are plain functions registered directly as handlers.
	functors []reflect.Value
}

// NewAwsMockHandler creates an empty AWS mocker. Call AwsConfig on the result
// to obtain an instrumented aws.Config that can be used to create AWS service
// clients.
//
// You can add as many individual request handlers as you need (see AddHandler),
// as long as they correspond to the func(context.Context, <arg>)(<res>, error)
// format, e.g.:
//
//	func(context.Context, *ec2.TerminateInstancesInput)(*ec2.TerminateInstancesOutput, error)
//
// You can also use a struct as the handler. In this case the AwsMockHandler
// searches its methods for one with a conforming signature.
func NewAwsMockHandler() *AwsMockHandler {
	return &AwsMockHandler{}
}
// retargetingHandler is a deserialize-step middleware that short-circuits the
// operation. Instead of passing control to the next handler, it hands the
// original (pre-serialization) operation input to the parent AwsMockHandler.
// Whatever the matching handler returns becomes the operation result.
type retargetingHandler struct {
	parent *AwsMockHandler
}

// ID identifies this middleware within its stack step.
func (f *retargetingHandler) ID() string {
	return "ShortCircuitRequest"
}

// initialRequestKey is the context key under which saveRequestMiddleware stores
// the operation input. It is used as a value, not a pointer. The Go spec does
// not guarantee that pointers to distinct zero-size variables compare equal, so
// a freshly allocated &initialRequestKey{} is not a reliable lookup key.
type initialRequestKey struct{}

// HandleDeserialize dispatches the saved operation input to the parent's
// handlers and returns their result (or error) as the operation output.
// next is deliberately not called, so nothing further down the stack runs.
func (f *retargetingHandler) HandleDeserialize(ctx context.Context, in middleware.DeserializeInput,
	next middleware.DeserializeHandler) (out middleware.DeserializeOutput, metadata middleware.Metadata, err error) {
	req := ctx.Value(initialRequestKey{})
	out.Result, err = f.parent.invokeMethod(ctx, req)
	return out, metadata, err
}

// saveRequestMiddleware is an initialize-step middleware that stores the
// operation input in the context. This lets retargetingHandler retrieve it
// later, after the later stack steps have run.
type saveRequestMiddleware struct{}

// ID identifies this middleware within its stack step.
func (s saveRequestMiddleware) ID() string {
	return "OriginalRequestSaver"
}

// HandleInitialize stores in.Parameters under initialRequestKey and continues
// down the stack unchanged.
func (s saveRequestMiddleware) HandleInitialize(ctx context.Context, in middleware.InitializeInput,
	next middleware.InitializeHandler) (out middleware.InitializeOutput, metadata middleware.Metadata, err error) {
	return next.HandleInitialize(context.WithValue(ctx, initialRequestKey{}, in.Parameters), in)
}
// AwsConfig returns an aws.Config whose per-operation middleware stack is
// rewired. Every call made by a client built from it is answered by the
// handlers registered on a, instead of proceeding to the network.
//
// The initialize and serialize steps are left in place. The build, finalize
// and deserialize steps are cleared. A middleware is added at the front of the
// initialize step to capture the original operation input, and the deserialize
// step is replaced by a single middleware that dispatches that input to the
// registered handlers. The region is set to the placeholder "us-mars-1".
//
// Errors from registering the middleware are returned from the API option, so
// they surface as operation errors instead of being silently dropped.
func (a *AwsMockHandler) AwsConfig() aws.Config {
	cfg := aws.NewConfig()
	cfg.Region = "us-mars-1"
	cfg.APIOptions = []func(*middleware.Stack) error{
		func(stack *middleware.Stack) error {
			// Save the initial, non-serialized request before anything else runs.
			if err := stack.Initialize.Add(&saveRequestMiddleware{}, middleware.Before); err != nil {
				return err
			}
			// The serialization middleware is deliberately left intact, in the
			// vain hope that AWS re-adds validation to serialization.
			// Everything after it is cleared.
			stack.Build.Clear()
			stack.Finalize.Clear()
			stack.Deserialize.Clear()
			// Replace the last step with our special middleware that dispatches
			// the request to our handlers.
			return stack.Deserialize.Add(&retargetingHandler{parent: a}, middleware.Before)
		},
	}
	return *cfg
}
// AddHandler registers handlerObject as a target for mocked AWS operations.
//
// handlerObject is either:
//   - a function of the form func(context.Context, <arg>) (<res>, error), or
//   - a value whose method set is searched for methods of that form at call
//     time. Pass a pointer if the methods have pointer receivers.
//
// It panics if handlerObject is nil, if a function does not have the form
// above, or if a non-function value has no methods. Checking the full function
// signature here surfaces mistakes immediately. Otherwise a mis-typed function
// would be accepted and then never matched.
func (a *AwsMockHandler) AddHandler(handlerObject interface{}) {
	PanicIfF(handlerObject == nil, "the handler must not be nil")
	handler := reflect.ValueOf(handlerObject)
	tp := handler.Type()

	if handler.Kind() != reflect.Func {
		PanicIfF(tp.NumMethod() == 0, "the handler must have invokable methods")
		a.handlers = append(a.handlers, handler)
		return
	}

	ctxType := reflect.TypeOf((*context.Context)(nil)).Elem()
	errType := reflect.TypeOf((*error)(nil)).Elem()
	// The arity checks come first so In(0)/Out(1) are only evaluated when they exist.
	badSignature := tp.NumIn() != 2 || tp.NumOut() != 2 ||
		!ctxType.AssignableTo(tp.In(0)) ||
		!tp.Out(1).AssignableTo(errType)
	PanicIfF(badSignature,
		"handler must have signature of func(context.Context, <arg>)(<res>, error)")
	a.functors = append(a.functors, handler)
}
// invokeMethod passes params to the first registered handler whose signature
// accepts it and returns that handler's result and error.
//
// Methods of value handlers are tried first, in registration order. Plain
// function handlers are tried after them. If nothing matches, it panics with
// the operation name taken from ctx.
func (a *AwsMockHandler) invokeMethod(ctx context.Context,
	params interface{}) (interface{}, error) {
	for _, target := range a.handlers {
		for n, count := 0, target.NumMethod(); n < count; n++ {
			if ok, result, callErr := tryInvoke(ctx, params, target.Method(n)); ok {
				return result, callErr
			}
		}
	}

	for _, fn := range a.functors {
		if ok, result, callErr := tryInvoke(ctx, params, fn); ok {
			return result, callErr
		}
	}

	opName := awsmiddleware.GetOperationName(ctx)
	panic("could not find a handler for operation: " + opName)
}
// tryInvoke calls method with (ctx, params) if its signature is
// func(context.Context, <arg>) (<res>, error) and params is assignable to <arg>.
//
// It reports whether the method matched. On a match it returns either the
// handler's result with a nil error, or a nil result with the handler's
// non-nil error. A nil params never matches.
func tryInvoke(ctx context.Context, params interface{}, method reflect.Value) (
	bool, interface{}, error) {
	// A nil input carries no type to match against.
	if params == nil {
		return false, nil, nil
	}
	paramType := reflect.TypeOf(params)
	errorType := reflect.TypeOf((*error)(nil)).Elem()
	contextType := reflect.TypeOf((*context.Context)(nil)).Elem()

	methodDesc := method.Type()
	// AssignableTo (not ConvertibleTo) mirrors what reflect.Value.Call accepts.
	// ConvertibleTo also admits distinct types with identical underlying types
	// (e.g. two input structs with the same fields), which would make Call panic.
	if methodDesc.NumIn() != 2 || methodDesc.NumOut() != 2 ||
		!contextType.AssignableTo(methodDesc.In(0)) ||
		!paramType.AssignableTo(methodDesc.In(1)) ||
		!methodDesc.Out(1).AssignableTo(errorType) {
		return false, nil, nil
	}

	// It's our target!
	res := method.Call([]reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(params)})

	// IsNil is only defined for nilable kinds. A non-nilable error type (e.g.
	// a struct implementing error) always represents a real error.
	errVal := res[1]
	switch errVal.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
		reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
		if errVal.IsNil() {
			return true, res[0].Interface(), nil
		}
	}
	return true, nil, errVal.Interface().(error)
}
package utils
import (
"context"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/smithy-go"
"github.com/stretchr/testify/assert"
"testing"
)
// tester is a struct handler used by the tests. Only TerminateInstances has a
// signature the mock can dispatch to. Each AlmostRunDescribe* method is a
// deliberate near miss that the mock must skip.
type tester struct{}

// TerminateInstances always fails with a smithy "param required" error.
func (t *tester) TerminateInstances(_ context.Context,
	_ *ec2.TerminateInstancesInput) (*ec2.TerminateInstancesOutput, error) {
	return nil, smithy.NewErrParamRequired("something")
}

// AlmostRunDescribe1 has a string second result instead of an error.
func (t *tester) AlmostRunDescribe1(_ context.Context,
	_ *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, string) {
	return nil, ""
}

// AlmostRunDescribe2 has no leading context parameter.
func (t *tester) AlmostRunDescribe2(_ *ec2.DescribeInstancesInput, _ string) (
	*ec2.DescribeInstancesOutput, error) {
	return nil, nil
}

// AlmostRunDescribe3 accepts a different input type.
func (t *tester) AlmostRunDescribe3(_ context.Context,
	_ *ec2.DescribeAccountAttributesInput) (*ec2.DescribeInstancesOutput, error) {
	return nil, nil
}

// AlmostRunDescribe4 takes a single parameter.
func (t *tester) AlmostRunDescribe4(_ *ec2.DescribeInstancesInput) (
	*ec2.DescribeInstancesOutput, error) {
	return nil, nil
}

// AlmostRunDescribe5 returns only an error.
func (t *tester) AlmostRunDescribe5(_ context.Context,
	_ *ec2.DescribeInstancesInput) error {
	return nil
}
// TestMockNotFound checks that an operation with no matching handler panics,
// and that the panic message names the operation.
func TestMockNotFound(t *testing.T) {
	am := AwsMockHandler{}
	am.AddHandler(&tester{})
	// assert.Panics treats trailing arguments as the failure message, not as
	// the expected panic value. PanicsWithValue actually checks the text.
	assert.PanicsWithValue(t, "could not find a handler for operation: DeleteKeyPair", func() {
		ec := ec2.NewFromConfig(am.AwsConfig())
		_, _ = ec.DeleteKeyPair(context.Background(), &ec2.DeleteKeyPairInput{
			KeyName: aws.String("something"),
		})
	})
}
// TestAwsMock checks dispatch to plain function handlers and to methods of a
// struct handler, including propagation of a handler's error to the caller.
func TestAwsMock(t *testing.T) {
	am := NewAwsMockHandler()
	am.AddHandler(&tester{})
	am.AddHandler(func(ctx context.Context, arg *ec2.DescribeInstancesInput) (
		*ec2.DescribeInstancesOutput, error) {
		return &ec2.DescribeInstancesOutput{NextToken: arg.NextToken}, nil
	})
	// This function also accepts *ec2.TerminateInstancesInput. Struct handlers
	// are tried before functions, so tester.TerminateInstances must win.
	am.AddHandler(func(ctx context.Context, arg *ec2.TerminateInstancesInput) (
		*ec2.DescribeInstancesOutput, error) {
		return nil, nil
	})

	ec := ec2.NewFromConfig(am.AwsConfig())
	response, e := ec.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{
		NextToken: aws.String("hello, token"),
	})
	if assert.NoError(t, e) && assert.NotNil(t, response) {
		assert.Equal(t, "hello, token", aws.ToString(response.NextToken))
	}

	// Check the tester methods. InstanceIds is populated so the input is
	// well-formed. With an empty input, SDK-side validation may reject the call
	// before any handler runs, and an error check alone would pass for the
	// wrong reason.
	_, err := ec.TerminateInstances(context.Background(), &ec2.TerminateInstancesInput{
		InstanceIds: []string{"i-123"},
	})
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "something")
	}
}
// ExampleNewAwsMockHandler shows how to serve an EC2 TerminateInstances call
// from a local handler function.
func ExampleNewAwsMockHandler() {
	mock := NewAwsMockHandler()

	terminate := func(ctx context.Context, arg *ec2.TerminateInstancesInput) (
		*ec2.TerminateInstancesOutput, error) {
		if first := arg.InstanceIds[0]; first != "i-123" {
			panic("BadInstanceId")
		}
		return &ec2.TerminateInstancesOutput{}, nil
	}
	mock.AddHandler(terminate)

	client := ec2.NewFromConfig(mock.AwsConfig())
	input := &ec2.TerminateInstancesInput{
		InstanceIds: []string{"i-123"},
	}
	_, _ = client.TerminateInstances(context.Background(), input)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment