Created
January 25, 2024 16:26
-
-
Save tecnologer/d07cf649dd61e06fdf1392f47e9b5636 to your computer and use it in GitHub Desktop.
Parses Go functions into JSON Schema for use with OpenAI API function calling
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import (
	"flag"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"os"
	"reflect"
	"regexp"
	"sort"
	"strconv"
	"strings"
)
// Command-line flags.
var (
	// filesFlag lists the Go source files to parse, comma separated.
	filesFlag = flag.String("files", "", "List of files to parse, comma separated")
	// functionsFlag lists the function names to export, comma separated.
	functionsFlag = flag.String("functions", "", "List of functions to generate, comma separated")
	// outputPath is the destination of the generated JSON document.
	outputPath = flag.String("output", "functions.json", "Output path for the generated functions JSON file")
)
func main() { | |
flag.Parse() | |
if *filesFlag == "" { | |
nlogger.Fatal("files are required") | |
} | |
if *functionsFlag == "" { | |
nlogger.Fatal("functions are required") | |
} | |
if *outputPath == "" { | |
*outputPath = "./functions.json" | |
} | |
if !strings.HasSuffix(*outputPath, ".json") { | |
*outputPath += ".json" | |
} | |
goFiles := strings.Split(*filesFlag, ",") | |
functionNames := strings.Split(*functionsFlag, ",") | |
functions, err := inspectFiles(goFiles, functionNames) | |
if err != nil { | |
nlogger.Fatal(err) | |
} | |
_ = file.WriteJSONFile(functions.Definitions, *outputPath) | |
} | |
func inspectFiles(goFiles []string, functionNames []string) (*factory.Functions, error) { | |
functions := factory.NewFunctions() | |
for _, goFile := range goFiles { | |
nlogger.Infof("Inspecting file: %s", goFile) | |
fns, err := buildFunction(goFile, functionNames...) | |
if err != nil { | |
return nil, err | |
} | |
functions.Merge(fns) | |
} | |
return functions, nil | |
} | |
func buildFunction(filename string, functionNames ...string) (*factory.Functions, error) { | |
if !file.ExistsFile(filename) { | |
return nil, fmt.Errorf("file %s does not exist", filename) | |
} | |
src, err := os.ReadFile(filename) | |
if err != nil { | |
return nil, err | |
} | |
// Create a new FileSet. | |
fset := token.NewFileSet() | |
// Parse the .go file. | |
goFile, err := parser.ParseFile(fset, filename, src, parser.ParseComments) | |
if err != nil { | |
return nil, err | |
} | |
functions := factory.NewFunctions() | |
ast.Inspect(goFile, func(n ast.Node) bool { | |
fn, ok := n.(*ast.FuncDecl) | |
if !ok { | |
return true | |
} | |
functions.Merge(processFunctions(fn, functionNames)) | |
return true | |
}) | |
return functions, nil | |
} | |
func processFunctions(fn *ast.FuncDecl, functionNames []string) *factory.Functions { | |
functions := factory.NewFunctions() | |
for _, functionName := range functionNames { | |
nlogger.Infof("Processing function: %s", functionName) | |
if fn.Name.Name != functionName { | |
continue | |
} | |
godoc := extractGoDoc(fn.Doc.Text()) | |
defs := extractFromGoDoc(godoc) | |
retrieveArgTypes(&defs, fn) | |
params := buildFunctionParameters(defs) | |
functions.Add(factory.NewFunction( | |
factory.WithFunctionName(defs.Name), | |
factory.WithFunctionDescription(defs.Description), | |
factory.WithFunctionParameters(params...), | |
)) | |
} | |
return functions | |
} | |
func retrieveArgTypes(defs *GoDocFunc, fn *ast.FuncDecl) { | |
for i, param := range fn.Type.Params.List { | |
pName := param.Names[0].Name | |
if defs.Args[pName] == nil { | |
continue | |
} | |
argType := getArgType(defs.Args[pName], param.Type) | |
defs.Args[pName].Type = argType | |
defs.Args[pName].Order = i | |
} | |
} | |
func getArgType(param *GoDocArg, paramType ast.Expr) string { | |
if ident, ok := paramType.(*ast.Ident); ok { | |
return ident.Name | |
} else if star, ok := paramType.(*ast.StarExpr); ok { | |
ident := star.X.(*ast.Ident) | |
typeSpec := ident.Obj.Decl.(*ast.TypeSpec) | |
structType, ok := typeSpec.Type.(*ast.StructType) | |
// if it's not a struct type, it's not an object | |
if !ok { | |
return "string" | |
} | |
param.Args = make(map[string]*GoDocArg) | |
order := 0 | |
for _, field := range structType.Fields.List { | |
if len(field.Names) == 0 { | |
continue | |
} | |
description := strings.TrimPrefix(field.Comment.List[0].Text, "//") | |
description = strings.TrimSpace(description) | |
fieldName := nameFromJSONTag(field.Tag) | |
if fieldName == "" { | |
fieldName = field.Names[0].Name | |
} | |
param.Args[fieldName] = &GoDocArg{ | |
Type: field.Type.(*ast.Ident).Name, | |
Description: description, | |
Required: validatorTagIsRequired(field.Tag), | |
Order: order, | |
} | |
order++ | |
} | |
return "struct" | |
} | |
return "string" | |
} | |
// extractGoDoc normalizes doc comment text into its non-empty lines, each
// terminated by "\n". A line that still begins with a "//" marker has the
// marker and surrounding whitespace removed. Other lines are copied
// unchanged.
func extractGoDoc(doc string) string {
	lines := strings.Split(doc, "\n")
	kept := make([]string, 0, len(lines))
	for _, line := range lines {
		if strings.HasPrefix(line, "//") {
			line = strings.TrimSpace(line[2:])
		}
		if line != "" {
			kept = append(kept, line)
		}
	}
	if len(kept) == 0 {
		return ""
	}
	return strings.Join(kept, "\n") + "\n"
}
// funcMatch captures the leading "<Name> <description>" part of a doc
// comment. The description may span lines but must be followed by a line
// whose first non-space character is "-" (the argument list).
var funcMatch = regexp.MustCompile(`(?ms)^(?P<func_name>\w+)\s+(?P<func_desc>[\s\S]+?)\s*\n\s*-`)

// argMatch captures one argument line of the form
// "- name: description[, required][, enum: a, b]". The description cannot
// contain commas.
var argMatch = regexp.MustCompile(
	`(?m)^\s*-\s+(?P<arg_name>\w+):\s+(?P<arg_desc>[^,\n]+)(?:,\s*(?P<required>required))?(?:,\s*enum:\s+(?P<enums>[^\n]+))?`,
)
type GoDocFunc struct { | |
Name string | |
Description string | |
Args map[string]*GoDocArg | |
} | |
// GoDocArg describes one documented function argument, or one field of a
// struct-typed argument.
type GoDocArg struct {
	Description string               // human-readable description
	Required    bool                 // whether the value must be supplied
	Type        string               // schema type name, e.g. a Go identifier, "struct" or "string"
	Enums       []string             // allowed values from an ", enum:" suffix; nil when absent
	Order       int                  // position among sibling arguments or struct fields
	Args        map[string]*GoDocArg // nested fields of a struct-typed argument; nil otherwise
}
func extractFromGoDoc(text string) GoDocFunc { | |
// Assuming 'text' is the input string containing the function details | |
functionMatches := funcMatch.FindStringSubmatch(text) | |
functionName := functionMatches[1] | |
functionDescription := strings.ReplaceAll(functionMatches[2], "\n", " ") | |
nlogger.Infof("Function Name: %s\n", functionName) | |
nlogger.Infof("Function Description: %s\n", functionDescription) | |
fn := GoDocFunc{ | |
Name: functionName, | |
Description: functionDescription, | |
Args: make(map[string]*GoDocArg), | |
} | |
argMatches := argMatch.FindAllStringSubmatch(text, -1) | |
for _, m := range argMatches { | |
argName := m[1] | |
argDescription := m[2] | |
// `required` will be an empty string if not present | |
argRequired := strings.TrimSpace(m[3]) == "required" | |
argEnums := m[4] | |
nlogger.Infof("Arg Name: %s\nDescription: %s\nRequired: %v\nEnums: %s\n", | |
argName, argDescription, argRequired, argEnums) | |
var enums []string | |
if argEnums != "" { | |
enums = strings.Split(argEnums, ",") | |
for i, e := range enums { | |
enums[i] = strings.TrimSpace(e) | |
} | |
} | |
fn.Args[argName] = &GoDocArg{ | |
Description: strings.ReplaceAll(argDescription, "\n", " "), | |
Required: argRequired, | |
Enums: enums, | |
} | |
} | |
return fn | |
} | |
func buildFunctionParameters(defs GoDocFunc) []*factory.Parameter { | |
params := make([]*factory.Parameter, len(defs.Args)) | |
i := 0 | |
for argName, arg := range defs.Args { | |
opts := []factory.ParameterFactoryOption{ | |
factory.WithParameterDescription(arg.Description), | |
factory.ParameterTypeString(arg.Type), | |
} | |
if arg.Required { | |
opts = append(opts, factory.ParameterRequired()) | |
} | |
if len(arg.Enums) > 0 { | |
opts = append(opts, factory.WithParameterEnum(arg.Enums...)) | |
} | |
params[i] = factory.NewParameter( | |
argName, | |
opts..., | |
) | |
params[i].Order = arg.Order | |
if len(arg.Args) > 0 { | |
params[i].Parameters = buildFunctionParameters(GoDocFunc{ | |
Args: arg.Args, | |
}) | |
} | |
i++ | |
} | |
return params | |
} | |
// validatorTagIsRequired reports whether a struct field's tag has a `validate`
// key whose comma-separated rule list contains the plain "required" rule.
// Conditional rules such as required_if or required_with do not count,
// because they do not make the field unconditionally required.
//
// tag is the tag literal from the AST, so tag.Value still carries its
// backtick or double quotes. A nil, empty, or malformed tag yields false.
func validatorTagIsRequired(tag *ast.BasicLit) bool {
	if tag == nil || tag.Value == "" {
		return false
	}
	raw, err := strconv.Unquote(tag.Value)
	if err != nil {
		return false
	}
	for _, rule := range strings.Split(reflect.StructTag(raw).Get("validate"), ",") {
		if strings.TrimSpace(rule) == "required" {
			return true
		}
	}
	return false
}
// nameFromJSONTag returns the name a struct field's `json` tag key assigns,
// i.e. the part before the first comma ("id" for `json:"id,omitempty"`).
//
// It returns "" when the tag is nil or cannot be unquoted, when it has no
// `json` key, or when the key sets no name (`json:",omitempty"`); callers
// then fall back to the Go field name. A `json:"-"` tag yields "-".
func nameFromJSONTag(tag *ast.BasicLit) string {
	if tag == nil || tag.Value == "" {
		return ""
	}
	// tag.Value is the literal as written in source, quotes included.
	raw, err := strconv.Unquote(tag.Value)
	if err != nil {
		return ""
	}
	name, _, _ := strings.Cut(reflect.StructTag(raw).Get("json"), ",")
	return strings.TrimSpace(name)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment