Skip to content

Instantly share code, notes, and snippets.

@arashout
Last active February 27, 2022 14:26
Show Gist options
  • Save arashout/3f3bad0bf3d70dc70a8e4b6fec568313 to your computer and use it in GitHub Desktop.
Save arashout/3f3bad0bf3d70dc70a8e4b6fec568313 to your computer and use it in GitHub Desktop.
Inject Context Code
// changedFn in its starting form: it only prints and makes no call that
// requires a context.Context (see the TODO below for the planned change).
func changedFn() {
	fmt.Println("Nothing to do here")
}
// TODO: Change this function to make a downstream call which needs a context.Context.
// func changedFn(ctx context.Context) {
// fmt.Println("Do some important work...")
// // Now also make a DB call
// makeDBCall(ctx, "Some important data!")
// }
// NOTE:
// This Context must flow down from the very start of the program!
// Which means the functions that call this function must also have a context.Context parameter now
package main
import (
"bytes"
"fmt"
"github.com/dave/dst"
"github.com/dave/dst/decorator"
"github.com/dave/dst/dstutil"
)
...
// FormatNode converts the AST representation back into its textual format
// (i.e. actual Go source code).
//
// file is taken by value so existing call sites (FormatNode(*file)) keep
// working; the printer is given a pointer to that local copy.
//
// It panics if the printer reports an error. Returning partial or empty
// output instead would be indistinguishable from real source. A fixed-point
// loop that compares successive outputs could then stop early on two
// identical-but-broken results.
func FormatNode(file dst.File) string {
	var buf bytes.Buffer
	if err := decorator.Fprint(&buf, &file); err != nil {
		panic(fmt.Errorf("formatting node: %w", err))
	}
	return buf.String()
}
func main() {
	file, err := decorator.Parse(srcCodeString)
	must(err)

	// injectCtx is a deliberately naive pre-order visitor. Every function
	// declaration gets "ctx context.Context" as its new first parameter, and
	// every call expression gets "ctx" as its new first argument. That
	// includes calls into other packages such as fmt.Println, whether or not
	// a context is actually needed there.
	injectCtx := func(cur *dstutil.Cursor) bool {
		if decl, isDecl := cur.Node().(*dst.FuncDecl); isDecl {
			param := newCtxParam()
			decl.Type.Params.List = append([]*dst.Field{&param}, decl.Type.Params.List...)
		} else if call, isCall := cur.Node().(*dst.CallExpr); isCall {
			arg := newCtxArg()
			call.Args = append([]dst.Expr{&arg}, call.Args...)
		}
		// Returning false would only skip this node's children; true keeps
		// the walk descending into them.
		return true
	}

	// Apply walks the tree rooted at file, calling injectCtx before each
	// node's children are visited, and edits the tree in place. Its returned
	// root is ignored because the root itself is never replaced.
	_ = dstutil.Apply(file, injectCtx, nil)
	fmt.Println(FormatNode(*file))
}
package test
import (
"fmt"
"context"
)
// changedFn after its signature change: it now forwards ctx to
// makeDownstreamRequest (defined elsewhere — presumably the call that needs a
// context). Every caller therefore has to supply a ctx as well.
// Added new context.Context parameter for downstream call
func changedFn(ctx context.Context) {
	fmt.Println("Do some important work...")
	// Now also make a downstream call
	makeDownstreamRequest(ctx, "Some important data!")
}
// needsctx1 calls changedFn directly, so it is the first caller that must
// start accepting and forwarding a context. The TODO comments show the
// intended end state; until they are applied the call below does not match
// changedFn's new signature.
// TODO: func needsctx1(ctx context.Context, n int)
func needsctx1(n int) {
	if true {
		// TODO: changedFn(ctx)
		changedFn()
	}
}
// needsctx2 calls needsctx1 and so is the next link in the chain that will
// have to accept a context. Like its siblings, it is shown in its "before"
// state: the call uses needsctx1's current signature (n int), and the TODO
// records the target form.
// TODO: func needsctx2(ctx context.Context) bool
func needsctx2() bool {
	for index := 0; index < 3; index++ {
		// TODO: needsctx1(ctx, 1)
		needsctx1(1)
	}
	return true
}
// needsctx3 calls both needsctx2 and changedFn, so it too must eventually
// accept and forward a context. It is shown in its "before" state; the TODO
// comments record the target form of the signature and each call.
// TODO: func needsctx3(ctx context.Context)
func needsctx3() {
	// TODO: if needsctx2(ctx) {
	if needsctx2() {
		// TODO: changedFn(ctx)
		changedFn()
	}
}
// SS is an example type whose method reaches changedFn via needsctx1.
type SS struct{}

// save calls needsctx1, so it also has to accept and forward a context once
// the chain is updated (see the TODOs).
// TODO: func (rec *SS) save(ctx context.Context, s string, n int)
func (rec *SS) save(s string, n int) {
	// TODO: needsctx1(ctx, 2)
	needsctx1(2)
}
applyFunc := func(c *dstutil.Cursor) bool {
node := c.Node()
switch n := node.(type) {
// Add an (additional) "ctx context.Context" parameter to EVERY function definition
case (*dst.FuncDecl):
ctxP := newCtxParam()
// We "prepend" the "ctx" parameter so it is the first parameter in the function call
n.Type.Params.List = append([]*dst.Field{&ctxP}, n.Type.Params.List...)
// Add an (additional) "ctx" argument to EVERY function call
case (*dst.CallExpr):
ctxA := newCtxArg()
// We "prepend" the "ctx" argument so it is the first argument in the function call
n.Args = append([]dst.Expr{&ctxA}, n.Args...)
}
return true
}
package test
import (
"fmt"
)
type SS struct{}

// save in its target form: accepts ctx and forwards it to needsctx1.
func (rec *SS) save(ctx context.Context, s string, n int) {
	needsctx1(ctx, 2)
}
// needsctx3 in its target form: forwards ctx to both needsctx2 and changedFn.
func needsctx3(ctx context.Context) {
	if needsctx2(ctx) {
		changedFn(ctx)
	}
}
// needsctx2 in its target form: forwards ctx to needsctx1 on each iteration.
func needsctx2(ctx context.Context) bool {
	for index := 0; index < 3; index++ {
		needsctx1(ctx, 1)
	}
	return true
}
// needsctx1 in its target form: takes ctx first and passes it to changedFn.
func needsctx1(ctx context.Context, n int) {
	if true {
		changedFn(ctx)
	}
}
// changedFn uses its ctx by passing it as an argument to fmt.Printf, so its
// body contains a call with a "ctx" argument.
func changedFn(ctx context.Context) {
	fmt.Printf("Print the context for some reason! I'ts important: %+v", ctx)
}
// hasContextParam reports whether fd declares a parameter whose type is a
// package-qualified identifier named "Context" (such as context.Context).
//
// Each entry of Params.List is one field: "a, b, c string" is a single field
// with three names and the type "string". Only fields whose type is a
// selector expression (pkg.Name) are considered. Only the selector name is
// compared, so the package qualifier itself is not checked.
func hasContextParam(fd *dst.FuncDecl) bool {
	for _, field := range fd.Type.Params.List {
		if sel, isSel := field.Type.(*dst.SelectorExpr); isSel && sel.Sel.Name == "Context" {
			return true
		}
	}
	return false
}
// newCtxParam builds a parameter field equivalent to "ctx context.Context".
// Fresh nodes are allocated on every call.
func newCtxParam() dst.Field {
	ctxType := &dst.SelectorExpr{
		X:   &dst.Ident{Name: "context"},
		Sel: &dst.Ident{Name: "Context"},
	}
	ctxName := &dst.Ident{Name: "ctx"}
	return dst.Field{
		Names: []*dst.Ident{ctxName},
		Type:  ctxType,
	}
}
// newCtxArg builds the identifier expression "ctx", used as a call argument.
func newCtxArg() dst.Ident {
	var arg dst.Ident
	arg.Name = "ctx"
	return arg
}
prev := startCode
i := 0
file, err := decorator.Parse(prev)
curN := dstutil.Apply(file, applyFunc, nil)
must(err)
for {
i++
curN = dstutil.Apply(curN, applyFunc, nil)
cur = FormatNode(*file)
if cur == prev {
break
}
prev = cur
}
// changedFn after the change: it now calls makeDBCall (defined elsewhere),
// passing ctx as the first argument, so changedFn itself must accept a ctx.
func changedFn(ctx context.Context) {
	fmt.Println("Do some important work...")
	// Now also make a downstream call
	makeDBCall(ctx, "Some important data!")
}
// changedFn before the change: it only prints and needs no context.
func changedFn() {
	fmt.Println("Nothing to do here")
}
package main
import (
"go/parser"
"go/token"
"github.com/dave/dst"
"github.com/dave/dst/decorator"
"github.com/dave/dst/dstutil"
)
// must panics with err when err is non-nil and does nothing otherwise. It is
// meant for places where an error cannot be handled gracefully and should
// simply abort the program.
func must(err error) {
	if err == nil {
		return
	}
	panic(err)
}
func main() {
	file, err := decorator.Parse(srcCodeString)
	must(err)

	// dumpFuncDecls is the pre-order callback handed to dstutil.Apply. A type
	// assertion on the current node picks out function declarations, and
	// each one is pretty-printed with dst.Print.
	dumpFuncDecls := func(cur *dstutil.Cursor) bool {
		if decl, isDecl := cur.Node().(*dst.FuncDecl); isDecl {
			dst.Print(decl)
		}
		// Returning false would only skip this node's children; returning
		// true keeps the walk going into them.
		return true
	}

	// Apply walks the tree rooted at file, calling the callback before each
	// node's children are visited. It returns the (possibly modified) root,
	// which is not needed here because nothing is rewritten.
	_ = dstutil.Apply(file, dumpFuncDecls, nil)
}
package test
import (
"fmt"
)
type SS struct{}

// Result of the iterative rewrite: the ctx argument to needsctx1 appeared
// first (iteration 3), which in turn made save need a ctx parameter
// (iteration 4).
// Added on 'ctx context.Context' parameter on iteration 4
func (rec *SS) save(ctx context.Context, s string, n int) {
	needsctx1(ctx, 2) // Added on ctx arg on iteration 3
}
// Result of the iterative rewrite: changedFn(ctx) (iteration 1) led to the
// ctx parameter (iteration 2). The needsctx2 argument came last (iteration 5),
// after needsctx2 itself gained its parameter.
// Added on 'ctx context.Context' parameter on iteration 2
func needsctx3(ctx context.Context) {
	if needsctx2(ctx) { // Added on ctx arg on iteration 5
		changedFn(ctx) // Added on ctx arg on iteration 1
	}
}
// Result of the iterative rewrite: the needsctx1 argument (iteration 3) led
// to the ctx parameter (iteration 4).
// Added on 'ctx context.Context' parameter on iteration 4
func needsctx2(ctx context.Context) bool {
	for index := 0; index < 3; index++ {
		needsctx1(ctx, 1) // Added on ctx arg on iteration 3
	}
	return true
}
// Result of the iterative rewrite: changedFn(ctx) (iteration 1) led to the
// ctx parameter (iteration 2).
// Added on 'ctx context.Context' parameter on iteration 2
func needsctx1(ctx context.Context, n int) {
	if true {
		changedFn(ctx) // Added on ctx arg on iteration 1
	}
}
// Starting point of the iterative rewrite: the body passes ctx to
// fmt.Printf, so the function gains its ctx parameter in the first pass
// (iteration 0).
// Added on 'ctx context.Context' parameter on iteration 0
func changedFn(ctx context.Context) {
	fmt.Printf("Print the context for some reason! I'ts important: %+v", ctx)
}
applyFunc := func(c *dstutil.Cursor) bool {
node := c.Node()
switch n := node.(type) {
// Look into function body to see if call like fn(ctx, ...) if so add "ctx context.Context" parameter
case (*dst.FuncDecl):
if doesFuncDeclRequireCtx(n) {
needsContextFuncs[n.Name.Name] = true
ctxP := newCtxParam()
n.Type.Params.List = append([]*dst.Field{&ctxP}, n.Type.Params.List...)
}
// From populated map based on previous case, check if a function NOW needs a ctx argument
case (*dst.CallExpr):
ident, ok := n.Fun.(*dst.Ident)
if ok && !hasCtxArg(n) && needsContextFuncs[ident.Name] {
ctxA := newCtxArg()
n.Args = append([]dst.Expr{&ctxA}, n.Args...)
}
}
return true
}
// needsContextFuncs records, keyed by bare function name, every function the
// rewrite has given a ctx parameter; calls to these names get a ctx argument.
var needsContextFuncs = map[string]bool{}
// doesFuncDeclRequireCtx reports whether fd should gain a
// "ctx context.Context" parameter.
//
// It returns false when fd already declares a parameter whose type is a
// package-qualified identifier named "Context" (e.g. context.Context), or
// when fd has no body. Body-less declarations are external (e.g.
// assembly-backed) functions, and dst.Inspect would dereference the nil body.
// Otherwise it returns true if any call in the body passes an identifier
// named "ctx" as an argument (see hasCtxArg).
//
// Note: nested function literals are walked too, so a closure that passes its
// own ctx parameter also marks the enclosing function.
func doesFuncDeclRequireCtx(fd *dst.FuncDecl) bool {
	// 1. A context is already accepted as a parameter: nothing to add.
	for _, p := range fd.Type.Params.List {
		if se, ok := p.Type.(*dst.SelectorExpr); ok && se.Sel.Name == "Context" {
			return false
		}
	}
	// 2. Nothing to inspect for a declaration without a body.
	if fd.Body == nil {
		return false
	}
	// 3. Search the body for a call that passes "ctx" as an argument.
	needsCtx := false
	dst.Inspect(fd.Body, func(node dst.Node) bool {
		if needsCtx {
			return false // already found one; prune the rest of the walk
		}
		if ce, ok := node.(*dst.CallExpr); ok && hasCtxArg(ce) {
			needsCtx = true
		}
		return !needsCtx
	})
	return needsCtx
}
// hasCtxArg reports whether any argument of the call is the bare identifier
// "ctx". Other expressions that evaluate to a context (e.g.
// context.Background()) are not recognised.
func hasCtxArg(ce *dst.CallExpr) bool {
	for _, a := range ce.Args {
		if id, isIdent := a.(*dst.Ident); isIdent && id.Name == "ctx" {
			return true
		}
	}
	return false
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment