Created November 16, 2020 19:09
-
-
Save tsandall/9934d2b4ed35802da75d684fd3232751 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package wasm | |
import ( | |
"bytes" | |
"context" | |
"encoding/json" | |
"flag" | |
"fmt" | |
"testing" | |
"time" | |
"github.com/open-policy-agent/opa/util" | |
"github.com/open-policy-agent/opa/ast" | |
"github.com/open-policy-agent/opa/internal/wasm/sdk/opa" | |
"github.com/open-policy-agent/opa/rego" | |
"github.com/open-policy-agent/opa/test/cases" | |
"github.com/open-policy-agent/opa/types" | |
) | |
var (
	// caseDir is the directory from which the end-to-end test cases are
	// loaded. Override it on the command line with -case-dir.
	caseDir = flag.String("case-dir", "../cases/testdata", "set directory to load test cases from")
)
func TestE2E(t *testing.T) { | |
addTestSleepBuiltin() | |
ctx := context.Background() | |
for _, tc := range cases.MustLoad(*caseDir).Sorted().Cases { | |
t.Run(tc.Filename, func(t *testing.T) { | |
if tc.WantErrorCode != nil || tc.WantError != nil { | |
t.SkipNow() | |
} | |
opts := []func(*rego.Rego){ | |
rego.Query(tc.Query), | |
} | |
for i := range tc.Modules { | |
opts = append(opts, rego.Module(fmt.Sprintf("module-%d.rego", i), tc.Modules[i])) | |
} | |
cr, err := rego.New(opts...).Compile(ctx) | |
if err != nil { | |
t.Fatal(err) | |
} | |
o := opa.New().WithPolicyBytes(cr.Bytes) | |
if tc.Data != nil { | |
o = o.WithDataJSON(tc.Data) | |
} | |
o, err = o.Init() | |
if err != nil { | |
t.Fatal(err) | |
} | |
var input *interface{} | |
if tc.InputTerm != nil { | |
var x interface{} = ast.MustParseTerm(*tc.InputTerm) | |
input = &x | |
} else if tc.Input != nil { | |
input = tc.Input | |
} | |
result, err := o.Eval(ctx, opa.EvalOpts{Input: input}) | |
assert(t, tc, result, err) | |
}) | |
} | |
} | |
func assert(t *testing.T, tc cases.TestCase, result *opa.Result, err error) { | |
t.Helper() | |
if tc.WantDefined != nil { | |
if err != nil { | |
t.Fatal("unexpected error:", err) | |
} else { | |
assertDefined(t, defined(*tc.WantDefined), result) | |
} | |
} else if tc.WantResult != nil { | |
if err != nil { | |
t.Fatal("unexpected error:", err) | |
} else { | |
assertResultSet(t, *tc.WantResult, tc.SortBindings, result) | |
} | |
} else if tc.WantErrorCode != nil || tc.WantError != nil { | |
if err == nil { | |
t.Fatal("expected error") | |
} | |
t.Log("err:", err) | |
// TODO: implement more specific error checking | |
} | |
} | |
// defined records whether a query produced any results. Its String form is
// used in test failure messages.
type defined bool

// String returns "undefined" for false and "defined" for true.
func (x defined) String() string {
	if !x {
		return "undefined"
	}
	return "defined"
}
func assertDefined(t *testing.T, want defined, result *opa.Result) { | |
t.Helper() | |
var rs []interface{} | |
if err := util.NewJSONDecoder(bytes.NewReader(result.Result)).Decode(&rs); err != nil { | |
panic(err) | |
} | |
got := defined(len(rs) > 0) | |
if got != want { | |
t.Fatalf("expected %v but got %v", want, got) | |
} | |
} | |
func assertResultSet(t *testing.T, want []map[string]interface{}, sortBindings bool, result *opa.Result) { | |
t.Helper() | |
a := toAST(want) | |
b := toAST(result.Result) | |
if sortBindings { | |
result := ast.NewArray() | |
a.Value.(*ast.Array).Sorted().Foreach(func(x *ast.Term) { | |
cpy, _ := x.Value.(ast.Object).Map(func(k, v *ast.Term) (*ast.Term, *ast.Term, error) { | |
return k, ast.NewTerm(v.Value.(*ast.Array).Sorted()), nil | |
}) | |
result.Append(ast.NewTerm(cpy)) | |
}) | |
a.Value = result | |
result = ast.NewArray() | |
b.Value.(ast.Set).Sorted().Foreach(func(x *ast.Term) { | |
cpy, _ := x.Value.(ast.Object).Map(func(k, v *ast.Term) (*ast.Term, *ast.Term, error) { | |
var sorted *ast.Array | |
switch v := v.Value.(type) { | |
case ast.Set: | |
sorted = v.Sorted() | |
case *ast.Array: | |
sorted = v.Sorted() | |
default: | |
panic("illegal value") | |
} | |
return k, ast.NewTerm(sorted), nil | |
}) | |
result.Append(ast.NewTerm(cpy)) | |
}) | |
b.Value = result | |
} | |
if !a.Equal(b) { | |
t.Fatalf("expected %v but got %v", a, b) | |
} | |
} | |
func toAST(a interface{}) *ast.Term { | |
if bs, ok := a.([]byte); ok { | |
return ast.MustParseTerm(string(bs)) | |
} | |
buf := bytes.NewBuffer(nil) | |
if err := json.NewEncoder(buf).Encode(a); err != nil { | |
panic(err) | |
} | |
return ast.MustParseTerm(buf.String()) | |
} | |
func addTestSleepBuiltin() { | |
rego.RegisterBuiltin1(®o.Function{ | |
Name: "test.sleep", | |
Decl: types.NewFunction(types.Args(types.S), types.NewNull()), | |
}, func(_ rego.BuiltinContext, op *ast.Term) (*ast.Term, error) { | |
d, _ := time.ParseDuration(string(op.Value.(ast.String))) | |
time.Sleep(d) | |
return ast.NullTerm(), nil | |
}) | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.