Created
September 20, 2023 19:49
-
-
Save chrisseto/cd5f94c7e70cbbccd9df05788e4b1cb8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bytes" | |
"fmt" | |
"go/ast" | |
"go/token" | |
"go/types" | |
"sort" | |
"strings" | |
"golang.org/x/tools/go/packages" | |
) | |
// MethodByName returns the method named name that is declared directly on t,
// or nil if t declares no such method. Methods promoted from embedded fields
// are not part of t's declared method list and are therefore never returned.
func MethodByName(t *types.Named, name string) *types.Func {
	var match *types.Func
	for idx, count := 0, t.NumMethods(); idx < count && match == nil; idx++ {
		if candidate := t.Method(idx); candidate.Name() == name {
			match = candidate
		}
	}
	return match
}
// Filter returns a new slice containing, in their original order, the
// elements of items for which every predicate in filters returns true.
// Predicates are evaluated in order and evaluation stops at the first one
// that rejects an element. With no filters every element is kept. items is
// never modified.
func Filter[T any](items []T, filters ...func(T) bool) []T {
	kept := make([]T, 0, len(items))
candidates:
	for _, item := range items {
		for _, accept := range filters {
			if !accept(item) {
				continue candidates
			}
		}
		kept = append(kept, item)
	}
	return kept
}
// Objects returns every object declared directly in scope, in the sorted
// order of scope.Names().
func Objects(scope *types.Scope) []types.Object {
	names := scope.Names()
	objs := make([]types.Object, 0, len(names))
	for _, name := range names {
		objs = append(objs, scope.Lookup(name))
	}
	return objs
}
// IsExported reports whether obj's name starts with an upper-case letter,
// i.e. whether it is visible outside its package.
func IsExported(obj types.Object) bool {
	return token.IsExported(obj.Name())
}
// IsStruct reports whether the underlying type of obj's type is a struct,
// as for a named type declared `type T struct{ ... }`. A pointer to a struct
// does not count.
func IsStruct(obj types.Object) bool {
	switch obj.Type().Underlying().(type) {
	case *types.Struct:
		return true
	default:
		return false
	}
}
// Implements returns a predicate that reports whether an object's type, or a
// pointer to that type, satisfies iface. Checking the pointer type as well
// admits types whose methods are declared with pointer receivers.
func Implements(iface *types.Interface) func(types.Object) bool {
	return func(obj types.Object) bool {
		typ := obj.Type()
		if types.Implements(typ, iface) {
			return true
		}
		return types.Implements(types.NewPointer(typ), iface)
	}
}
// Contains reports whether pos lies within n's source range.
//
// The range is half-open, [n.Pos(), n.End()): per the ast.Node contract,
// End is the position of the first character immediately after the node, so
// a pos equal to End belongs to whatever follows n, not to n itself.
func Contains(n ast.Node, pos token.Pos) bool {
	return n.Pos() <= pos && pos < n.End()
}
// ToAst returns the ast.Node or specified type that is most closely associated | |
// with the position pos. | |
func ToAst[T ast.Node](pkg *packages.Package, pos token.Pos) T { | |
var node T | |
found := false | |
for _, f := range pkg.Syntax { | |
ast.Inspect(f, func(n ast.Node) bool { | |
if found || n == nil || !Contains(n, pos) { | |
return false | |
} | |
if cast, ok := n.(T); ok { | |
node = cast | |
found = true | |
return false | |
} | |
return true | |
}) | |
if found { | |
break | |
} | |
} | |
return node | |
} | |
// IsDDL does its best to determine if an implementor of Statement returns the | |
// DDL constant from its StatementReturnType method. | |
func IsDDL(pkg *packages.Package) func(types.Object) bool { | |
return func(obj types.Object) bool { | |
method := MethodByName(obj.Type().(*types.Named), "StatementReturnType") | |
fn := ToAst[*ast.FuncDecl](pkg, method.Pos()) | |
isDDL := false | |
ast.Inspect(fn.Body, func(n ast.Node) bool { | |
if ident, ok := n.(*ast.Ident); ok && ident.Name == "DDL" { | |
isDDL = true | |
return false | |
} | |
return true | |
}) | |
return isDDL | |
} | |
} | |
func main() { | |
pkgs, err := packages.Load(&packages.Config{ | |
Mode: packages.NeedFiles | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax, | |
}, "./pkg/sql/sem/tree") | |
if err != nil { | |
panic(err) | |
} | |
pkg := pkgs[0] | |
statementIface := pkg.Types.Scope().Lookup("Statement").Type().(*types.Named).Underlying().(*types.Interface) | |
nodeFormatterIface := pkg.Types.Scope().Lookup("NodeFormatter").Type().(*types.Named).Underlying().(*types.Interface) | |
// Find DDL Statements by finding all public structs that implement Statement | |
// which return DDL from StatementReturnType. | |
objs := Objects(pkg.Types.Scope()) | |
publicStructs := Filter(objs, IsExported, IsStruct) | |
statements := Filter(publicStructs, Implements(statementIface)) | |
ddl := Filter(statements, IsDDL(pkg)) | |
// Find all ddl subcommands, like AlterTableAddColumn, by finding all public | |
// structs that implement NodeFormatter who's name starts with a DDL | |
// statement's name. | |
// This is pretty loose and might find false positives or skip subcommands. | |
nodeFormatters := Filter(publicStructs, Implements(nodeFormatterIface)) | |
ddlSubCommands := Filter(nodeFormatters, func(obj types.Object) bool { | |
for _, stmt := range ddl { | |
if obj.Name() == stmt.Name() { | |
return false | |
} | |
} | |
for _, stmt := range ddl { | |
if strings.HasPrefix(obj.Name(), stmt.Name()) { | |
return true | |
} | |
} | |
return false | |
}) | |
// Knit together the final list. We have to do some manual post processing | |
// any how, so don't worry about duplicates. | |
ddl = append(ddl, ddlSubCommands...) | |
sort.Slice(ddl, func(i, j int) bool { | |
return ddl[i].Name() < ddl[j].Name() | |
}) | |
for _, obj := range ddl { | |
name := []byte(obj.Name()) | |
name[0] = bytes.ToLower(name[:1])[0] | |
fmt.Printf("%s\n", name) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment