Last active
February 27, 2022 14:26
-
-
Save arashout/3f3bad0bf3d70dc70a8e4b6fec568313 to your computer and use it in GitHub Desktop.
Inject Context Code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// changedFn is the starting version of the example function: it takes no
// parameters and only prints a fixed message.
func changedFn() { | |
fmt.Println("Nothing to do here") | |
} | |
// TODO: Change this function to make a downstream call which needs a context.Context. | |
// func changedFn(ctx context.Context) { | |
// fmt.Println("Do some important work...") | |
// // Now also make a DB call | |
// makeDBCall(ctx, "Some important data!") | |
// } | |
// NOTE: | |
// This Context must flow down from the very start of the program! | |
// Which means the functions that call this function must also have a context.Context parameter now |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bytes" | |
"fmt" | |
"github.com/dave/dst" | |
"github.com/dave/dst/decorator" | |
"github.com/dave/dst/dstutil" | |
) | |
... | |
// FormatNode converts the ast representation into it's textual format (Basically actual Go code) | |
func FormatNode(file dst.File) string { | |
var buf bytes.Buffer | |
decorator.Fprint(&buf, &file) | |
return buf.String() | |
} | |
func main() { | |
file, err := decorator.Parse(srcCodeString) | |
must(err) | |
applyFunc := func(c *dstutil.Cursor) bool { | |
node := c.Node() | |
switch n := node.(type) { | |
// Add an (additional) "ctx context.Context" parameter to EVERY function definition | |
case (*dst.FuncDecl): | |
ctxP := newCtxParam() | |
// We "prepend" the "ctx" parameter so it is the first parameter in the function call | |
n.Type.Params.List = append([]*dst.Field{&ctxP}, n.Type.Params.List...) | |
// Add an (additional) "ctx" argument to EVERY function call | |
case (*dst.CallExpr): | |
ctxA := newCtxArg() | |
// We "prepend" the "ctx" argument so it is the first argument in the function call | |
n.Args = append([]dst.Expr{&ctxA}, n.Args...) | |
} | |
return true | |
} | |
// We traverse the Go AST via the Apply function | |
// If the node is "nil" or the return value is "false" traversal stops | |
// Lastly, it's possible to edit the AST while doing the traversal and return the result | |
_ = dstutil.Apply(file, applyFunc, nil) | |
fmt.Println(FormatNode(*file)) | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package test | |
import ( | |
"fmt" | |
"context" | |
) | |
// changedFn prints a progress message and then passes ctx, together with a
// fixed string, to makeDownstreamRequest (not defined in this snippet).
// Added new context.Context parameter for downstream call | |
func changedFn(ctx context.Context) { | |
fmt.Println("Do some important work...") | |
// Now also make a downstream call | |
makeDownstreamRequest(ctx, "Some important data!") | |
} | |
// needsctx1 calls changedFn with no arguments inside an always-true branch;
// the parameter n is not used. The TODO lines record the ctx parameter and
// ctx argument that this example expects to be added.
// TODO: func needsctx1(ctx context.Context, n int) | |
func needsctx1(n int) { | |
if true { | |
// TODO: changedFn(ctx) | |
changedFn() | |
} | |
} | |
// needsctx2 calls needsctx1(ctx, 1) three times and then returns true.
// ctx is not declared in this version of the function; the TODO line shows
// the signature that would provide it.
// TODO: func needsctx2(ctx context.Context) bool | |
func needsctx2() bool { | |
for index := 0; index < 3; index++ { | |
needsctx1(ctx, 1) | |
} | |
return true | |
} | |
// needsctx3 calls changedFn(ctx) only when needsctx2(ctx) returns true.
// ctx is not declared in this version of the function; the TODO line shows
// the signature that would provide it.
// TODO: func needsctx3(ctx context.Context) | |
func needsctx3() { | |
if needsctx2(ctx) { | |
changedFn(ctx) | |
} | |
} | |
type SS struct{} | |
// save calls needsctx1 with the literal 2; the receiver and the parameters
// s and n are not used. The TODO lines record the ctx parameter and ctx
// argument that this example expects to be added.
// TODO: func (rec *SS) save(ctx context.Context, s string, n int) | |
func (rec *SS) save(s string, n int) { | |
// TODO: needsctx1(ctx, 2) | |
needsctx1(2) | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
applyFunc := func(c *dstutil.Cursor) bool { | |
node := c.Node() | |
switch n := node.(type) { | |
// Add an (additional) "ctx context.Context" parameter to EVERY function definition | |
case (*dst.FuncDecl): | |
ctxP := newCtxParam() | |
// We "prepend" the "ctx" parameter so it is the first parameter in the function call | |
n.Type.Params.List = append([]*dst.Field{&ctxP}, n.Type.Params.List...) | |
// Add an (additional) "ctx" argument to EVERY function call | |
case (*dst.CallExpr): | |
ctxA := newCtxArg() | |
// We "prepend" the "ctx" argument so it is the first argument in the function call | |
n.Args = append([]dst.Expr{&ctxA}, n.Args...) | |
} | |
return true | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package test | |
import ( | |
"fmt" | |
) | |
type SS struct{} | |
// save forwards ctx to needsctx1 together with the literal 2; the receiver
// and the parameters s and n are not used.
func (rec *SS) save(ctx context.Context, s string, n int) { | |
needsctx1(ctx, 2) | |
} | |
// needsctx3 calls changedFn(ctx) only when needsctx2(ctx) returns true.
func needsctx3(ctx context.Context) { | |
if needsctx2(ctx) { | |
changedFn(ctx) | |
} | |
} | |
// needsctx2 calls needsctx1(ctx, 1) three times and then returns true.
func needsctx2(ctx context.Context) bool { | |
for index := 0; index < 3; index++ { | |
needsctx1(ctx, 1) | |
} | |
return true | |
} | |
// needsctx1 passes ctx to changedFn inside an always-true branch; the
// parameter n is not used.
func needsctx1(ctx context.Context, n int) { | |
if true { | |
changedFn(ctx) | |
} | |
} | |
// changedFn prints a fixed message followed by ctx formatted with %+v.
func changedFn(ctx context.Context) { | |
fmt.Printf("Print the context for some reason! I'ts important: %+v", ctx) | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
func hasContextParam(fd *dst.FuncDecl) bool { | |
// 1. Check if a context is already passed as parameter, if so return early | |
for _, p := range fd.Type.Params.List { | |
// If it's not a *SelectorExpr, then skip it (e.g. "a" and "b" in func(a, b, c string)) | |
se, ok := p.Type.(*dst.SelectorExpr) | |
if !ok { | |
continue | |
} | |
if se.Sel.Name == "Context" { | |
return true | |
} | |
} | |
return false | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
func newCtxParam() dst.Field { | |
return dst.Field{ | |
Names: []*dst.Ident{&dst.Ident{Name: "ctx"}}, | |
Type: &dst.SelectorExpr{ | |
X: &dst.Ident{Name: "context"}, | |
Sel: &dst.Ident{Name: "Context"}, | |
}, | |
} | |
} | |
func newCtxArg() dst.Ident { | |
return dst.Ident{Name: "ctx"} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Re-apply applyFunc until the formatted source stops changing between two
// consecutive passes (a fixed point). i counts the passes made by the loop.
prev := startCode
i := 0
file, err := decorator.Parse(prev)
// Check the parse error before the tree is used; file must not be handed to
// Apply when parsing failed.
must(err)
curN := dstutil.Apply(file, applyFunc, nil)
for {
	i++
	curN = dstutil.Apply(curN, applyFunc, nil)
	// applyFunc edits nodes in place, so formatting file reflects this pass.
	cur := FormatNode(*file)
	if cur == prev {
		break
	}
	prev = cur
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// changedFn prints a progress message and then passes ctx, together with a
// fixed string, to makeDBCall (not defined in this snippet).
func changedFn(ctx context.Context) { | |
fmt.Println("Do some important work...") | |
// Now also make a downstream call | |
makeDBCall(ctx, "Some important data!") | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// changedFn is the starting version of the example function: it takes no
// parameters and only prints a fixed message.
func changedFn() { | |
fmt.Println("Nothing to do here") | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"go/parser" | |
"go/token" | |
"github.com/dave/dst" | |
"github.com/dave/dst/decorator" | |
"github.com/dave/dst/dstutil" | |
) | |
// must is an error-checking shortcut for places where an error does not need
// graceful handling: it panics with err when err is non-nil and otherwise
// does nothing.
func must(err error) {
	if err == nil {
		return
	}
	panic(err)
}
func main() { | |
file, err := decorator.Parse(srcCodeString) | |
must(err) | |
// Notice that we have to define our own function examining/editting a node during AST traversal | |
applyFunc := func(c *dstutil.Cursor) bool { | |
node := c.Node() | |
// Use a switch-case construct based on the node "type" | |
// This is a very useful of navigating the AST | |
switch n := node.(type) { | |
case (*dst.FuncDecl): | |
// Pretty print the Node AST | |
dst.Print(n) | |
} | |
return true | |
} | |
// We traverse the Go AST via the Apply function | |
// If the node is "nil" or the return value is "false" traversal stops | |
// Lastly, it's possible to edit the AST while doing the traversal and return the result | |
_ = dstutil.Apply(file, applyFunc, nil) | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package test | |
import ( | |
"fmt" | |
) | |
type SS struct{} | |
// save forwards ctx to needsctx1 together with the literal 2; the receiver
// and the parameters s and n are not used.
// Added on 'ctx context.Context' parameter on iteration 4 | |
func (rec *SS) save(ctx context.Context, s string, n int) { | |
needsctx1(ctx, 2) // Added on ctx arg on iteration 3 | |
} | |
// needsctx3 calls changedFn(ctx) only when needsctx2(ctx) returns true.
// Added on 'ctx context.Context' parameter on iteration 2 | |
func needsctx3(ctx context.Context) { | |
if needsctx2(ctx) { // Added on ctx arg on iteration 5 | |
changedFn(ctx) // Added on ctx arg on iteration 1 | |
} | |
} | |
// needsctx2 calls needsctx1(ctx, 1) three times and then returns true.
// Added on 'ctx context.Context' parameter on iteration 4 | |
func needsctx2(ctx context.Context) bool { | |
for index := 0; index < 3; index++ { | |
needsctx1(ctx, 1) // Added on ctx arg on iteration 3 | |
} | |
return true | |
} | |
// needsctx1 passes ctx to changedFn inside an always-true branch; the
// parameter n is not used.
// Added on 'ctx context.Context' parameter on iteration 2 | |
func needsctx1(ctx context.Context, n int) { | |
if true { | |
changedFn(ctx) // Added on ctx arg on iteration 1 | |
} | |
} | |
// changedFn prints a fixed message followed by ctx formatted with %+v.
// Added on 'ctx context.Context' parameter on iteration 0 | |
func changedFn(ctx context.Context) { | |
fmt.Printf("Print the context for some reason! I'ts important: %+v", ctx) | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// applyFunc is a traversal callback that injects a context only where it is
// needed. A function declaration for which doesFuncDeclRequireCtx returns
// true is recorded by name in needsContextFuncs and gains a leading
// "ctx context.Context" parameter. A call whose callee is a plain identifier
// recorded in needsContextFuncs, and which has no "ctx" argument yet, gains a
// leading "ctx" argument. It always returns true.
applyFunc := func(cursor *dstutil.Cursor) bool {
	switch node := cursor.Node().(type) {
	case *dst.FuncDecl:
		if !doesFuncDeclRequireCtx(node) {
			break
		}
		needsContextFuncs[node.Name.Name] = true
		param := newCtxParam()
		node.Type.Params.List = append([]*dst.Field{&param}, node.Type.Params.List...)
	case *dst.CallExpr:
		// Only calls written as a bare identifier are looked up in the map.
		callee, isIdent := node.Fun.(*dst.Ident)
		if !isIdent || hasCtxArg(node) || !needsContextFuncs[callee.Name] {
			break
		}
		arg := newCtxArg()
		node.Args = append([]dst.Expr{&arg}, node.Args...)
	}
	return true
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
var needsContextFuncs = make(map[string]bool) | |
// looks into function body to see if call like fn(ctx, ...) if so add "ctx context.Context" parameter | |
func doesFuncDeclRequireCtx(fd *dst.FuncDecl) bool { | |
// 1. Check if a context is already passed as parameter, if so return early | |
for _, p := range fd.Type.Params.List { | |
se, ok := p.Type.(*dst.SelectorExpr) | |
if !ok { | |
continue | |
} | |
if se.Sel.Name == "Context" { | |
return false | |
} | |
} | |
// 2. If it doesn't check if the function body has a CallExpr with "ctx" argument | |
needsCtx := false | |
dst.Inspect(fd.Body, func(node dst.Node) bool { | |
switch n := node.(type) { | |
case (*dst.CallExpr): | |
if hasCtxArg(n) { | |
needsCtx = true | |
return false | |
} | |
} | |
return true | |
}) | |
return needsCtx | |
} | |
func hasCtxArg(ce *dst.CallExpr) bool { | |
for _, arg := range ce.Args { | |
switch v := arg.(type) { | |
case (*dst.Ident): | |
if v.Name == "ctx" { | |
return true | |
} | |
} | |
} | |
return false | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment