Skip to content

Instantly share code, notes, and snippets.

@tecnologer
Created January 25, 2024 16:26
Show Gist options
  • Save tecnologer/d07cf649dd61e06fdf1392f47e9b5636 to your computer and use it in GitHub Desktop.
Save tecnologer/d07cf649dd61e06fdf1392f47e9b5636 to your computer and use it in GitHub Desktop.
Parses Go functions into JSON Schema definitions for use with OpenAI API function calling
package main
import (
"flag"
"fmt"
"go/ast"
"go/parser"
"go/token"
"os"
"regexp"
"strings"
)
// Command-line flags read by main.
var (
// functionsFlag is the comma-separated list of function names to generate.
functionsFlag = flag.String("functions", "", "List of functions to generate, comma separated")
// filesFlag is the comma-separated list of Go source files to parse.
filesFlag = flag.String("files", "", "List of files to parse, comma separated")
// outputPath is where the generated JSON is written; it defaults to functions.json.
outputPath = flag.String("output", "functions.json", "Output path for the generated functions JSON file")
)
// main parses the command-line flags, collects the requested function
// definitions from the listed Go files and writes them as JSON to the
// output path. Missing -files or -functions values are reported through
// nlogger.Fatal.
func main() {
	flag.Parse()

	if *filesFlag == "" {
		nlogger.Fatal("files are required")
	}
	if *functionsFlag == "" {
		nlogger.Fatal("functions are required")
	}

	if *outputPath == "" {
		*outputPath = "./functions.json"
	}
	if !strings.HasSuffix(*outputPath, ".json") {
		*outputPath += ".json"
	}

	// Tolerate "a.go, b.go" style input: stray spaces around the commas would
	// otherwise become part of the file or function name.
	goFiles := strings.Split(*filesFlag, ",")
	for i := range goFiles {
		goFiles[i] = strings.TrimSpace(goFiles[i])
	}
	functionNames := strings.Split(*functionsFlag, ",")
	for i := range functionNames {
		functionNames[i] = strings.TrimSpace(functionNames[i])
	}

	functions, err := inspectFiles(goFiles, functionNames)
	if err != nil {
		nlogger.Fatal(err)
	}

	// A failed write must not look like success, so the result is checked
	// instead of being discarded.
	if err := file.WriteJSONFile(functions.Definitions, *outputPath); err != nil {
		nlogger.Fatal(err)
	}
}
// inspectFiles runs buildFunction over every file in goFiles and merges the
// per-file results into a single collection. The first error aborts the
// whole run and is returned unchanged.
func inspectFiles(goFiles []string, functionNames []string) (*factory.Functions, error) {
	merged := factory.NewFunctions()

	for idx := range goFiles {
		path := goFiles[idx]
		nlogger.Infof("Inspecting file: %s", path)

		parsed, err := buildFunction(path, functionNames...)
		if err != nil {
			return nil, err
		}

		merged.Merge(parsed)
	}

	return merged, nil
}
// buildFunction parses the Go source file at filename (including comments)
// and returns the definitions produced by processFunctions for each
// function declaration it contains.
//
// It returns an error when the file does not exist, cannot be read, or does
// not parse.
func buildFunction(filename string, functionNames ...string) (*factory.Functions, error) {
	if !file.ExistsFile(filename) {
		return nil, fmt.Errorf("file %s does not exist", filename)
	}

	src, err := os.ReadFile(filename)
	if err != nil {
		return nil, err
	}

	fset := token.NewFileSet()

	// Comments are required: the definitions are derived from the doc comments.
	goFile, err := parser.ParseFile(fset, filename, src, parser.ParseComments)
	if err != nil {
		return nil, err
	}

	functions := factory.NewFunctions()

	// Function declarations only occur at the top level of a file, so scanning
	// Decls is enough; there is no need to walk into every function body.
	for _, decl := range goFile.Decls {
		fn, ok := decl.(*ast.FuncDecl)
		if !ok {
			continue
		}

		functions.Merge(processFunctions(fn, functionNames))
	}

	return functions, nil
}
// processFunctions builds a definition for fn once for every entry of
// functionNames that equals the declaration's name. Each candidate name is
// logged, whether or not it matches. The name, description and parameters
// are taken from the doc comment, and the parameter types from the
// declaration's signature.
func processFunctions(fn *ast.FuncDecl, functionNames []string) *factory.Functions {
	result := factory.NewFunctions()

	for _, wanted := range functionNames {
		nlogger.Infof("Processing function: %s", wanted)

		if fn.Name.Name == wanted {
			defs := extractFromGoDoc(extractGoDoc(fn.Doc.Text()))
			retrieveArgTypes(&defs, fn)
			params := buildFunctionParameters(defs)

			definition := factory.NewFunction(
				factory.WithFunctionName(defs.Name),
				factory.WithFunctionDescription(defs.Description),
				factory.WithFunctionParameters(params...),
			)
			result.Add(definition)
		}
	}

	return result
}
// retrieveArgTypes fills in Type and Order for every argument in defs.Args
// that has a same-named parameter in fn's signature. Order is the zero-based
// position of the parameter in the signature. Parameters without a
// documented argument are skipped.
func retrieveArgTypes(defs *GoDocFunc, fn *ast.FuncDecl) {
	if defs == nil || fn == nil || fn.Type == nil || fn.Type.Params == nil {
		return
	}

	order := 0
	for _, param := range fn.Type.Params.List {
		// Unnamed parameters (e.g. "func(int, string)") have no name to look
		// up, but they still occupy a position.
		if len(param.Names) == 0 {
			order++
			continue
		}

		// A single field can declare several names ("a, b int"); each one is a
		// separate parameter sharing the field's type.
		for _, name := range param.Names {
			if arg := defs.Args[name.Name]; arg != nil {
				arg.Type = getArgType(arg, param.Type)
				arg.Order = order
			}
			order++
		}
	}
}
// getArgType returns the type name to record for a parameter whose declared
// type is paramType.
//
//   - A plain identifier yields its own name.
//   - A pointer to a struct declared in the same file yields "struct"; the
//     struct's named fields are stored in param.Args.
//   - A pointer to an identifier the parser could not resolve in this file
//     (for example a builtin such as int) yields the identifier's name.
//   - Everything else yields "string".
func getArgType(param *GoDocArg, paramType ast.Expr) string {
	switch t := paramType.(type) {
	case *ast.Ident:
		return t.Name
	case *ast.StarExpr:
		ident, ok := t.X.(*ast.Ident)
		if !ok {
			// e.g. *pkg.Type or **T: nothing to inspect here.
			return "string"
		}
		if ident.Obj == nil {
			// Not declared in this file, so there is no declaration to inspect;
			// fall back to the name, as for a non-pointer identifier.
			return ident.Name
		}

		typeSpec, ok := ident.Obj.Decl.(*ast.TypeSpec)
		if !ok {
			return "string"
		}

		// if it's not a struct type, it's not an object
		structType, ok := typeSpec.Type.(*ast.StructType)
		if !ok {
			return "string"
		}

		param.Args = structFieldArgs(structType)

		return "struct"
	}

	return "string"
}

// structFieldArgs describes every named field of structType, keyed by its
// JSON tag name, or by the Go field name when there is no such name.
func structFieldArgs(structType *ast.StructType) map[string]*GoDocArg {
	args := make(map[string]*GoDocArg)
	if structType.Fields == nil {
		return args
	}

	order := 0
	for _, field := range structType.Fields.List {
		// Embedded fields carry no name and are skipped.
		if len(field.Names) == 0 {
			continue
		}

		fieldName := nameFromJSONTag(field.Tag)
		if fieldName == "" {
			fieldName = field.Names[0].Name
		}

		args[fieldName] = &GoDocArg{
			Type:        fieldTypeName(field.Type),
			Description: fieldDescription(field),
			Required:    validatorTagIsRequired(field.Tag),
			Order:       order,
		}
		order++
	}

	return args
}

// fieldTypeName returns the identifier name of a field type, looking through
// one level of pointer, and "string" for any other type expression.
func fieldTypeName(expr ast.Expr) string {
	if star, ok := expr.(*ast.StarExpr); ok {
		expr = star.X
	}
	if ident, ok := expr.(*ast.Ident); ok {
		return ident.Name
	}

	return "string"
}

// fieldDescription returns the first line of the field's trailing comment,
// or the doc comment above the field when there is no trailing comment.
func fieldDescription(field *ast.Field) string {
	if field.Comment != nil && len(field.Comment.List) > 0 {
		return strings.TrimSpace(strings.TrimPrefix(field.Comment.List[0].Text, "//"))
	}

	// CommentGroup.Text is safe to call on a nil group and returns "".
	return strings.TrimSpace(field.Doc.Text())
}
// extractGoDoc normalises a doc comment: a leading "//" marker is removed
// from any line that has one and the remainder is trimmed, empty lines are
// dropped, and every kept line is terminated with a newline. Lines without
// the marker are kept verbatim.
func extractGoDoc(doc string) string {
	var out strings.Builder

	for _, raw := range strings.Split(doc, "\n") {
		cleaned := raw
		if strings.HasPrefix(cleaned, "//") {
			cleaned = strings.TrimSpace(cleaned[len("//"):])
		}

		if cleaned == "" {
			continue
		}

		out.WriteString(cleaned)
		out.WriteString("\n")
	}

	return out.String()
}
// Patterns used by extractFromGoDoc to pull the function summary and the
// argument list out of a normalised doc comment.
var (
// funcMatch captures a leading word (func_name) and the text that follows it
// (func_desc), up to the first line that begins with "-".
funcMatch = regexp.MustCompile(`(?ms)^(?P<func_name>\w+)\s+(?P<func_desc>[\s\S]+?)\s*\n\s*-`)
// argMatch captures one argument line of the form
// "- name: description[, required][, enum: a, b, c]".
// The description stops at the first comma or newline.
argMatch = regexp.MustCompile(
`(?m)^\s*-\s+(?P<arg_name>\w+):\s+(?P<arg_desc>[^,\n]+)(?:,\s*(?P<required>required))?(?:,\s*enum:\s+(?P<enums>[^\n]+))?`,
)
)
// GoDocFunc is the function description extracted from a doc comment.
type GoDocFunc struct {
// Name is the function name captured from the doc comment.
Name string
// Description is the free text that follows the name.
Description string
// Args holds the documented arguments, keyed by argument name.
Args map[string]*GoDocArg
}
// GoDocArg describes a single documented argument, or a struct field of one.
type GoDocArg struct {
// Description is the argument's description text.
Description string
// Required reports whether the argument was marked as required.
Required bool
// Type is the type name; retrieveArgTypes fills it in from the signature.
Type string
// Enums lists the allowed values, if any were documented.
Enums []string
// Order is the position assigned to the argument.
Order int
// Args holds the nested fields when the argument is a pointer to a struct.
Args map[string]*GoDocArg
}
// extractFromGoDoc parses a normalised doc comment into a GoDocFunc.
//
// The first word is the function name and the text after it is the
// description. Every "- name: description[, required][, enum: a, b]" line
// becomes an entry in Args. Newlines inside descriptions are replaced with
// spaces. When the text has no argument list, the name and description are
// still taken from its words; empty text yields an empty GoDocFunc with a
// non-nil Args map.
func extractFromGoDoc(text string) GoDocFunc {
	var functionName, functionDescription string

	if functionMatches := funcMatch.FindStringSubmatch(text); functionMatches != nil {
		functionName = functionMatches[1]
		functionDescription = strings.ReplaceAll(functionMatches[2], "\n", " ")
	} else if words := strings.Fields(text); len(words) > 0 {
		// funcMatch needs a following "-" line, so a doc comment without an
		// argument list does not match. Indexing the nil result used to panic;
		// fall back to "first word is the name, the rest is the description".
		functionName = words[0]
		functionDescription = strings.Join(words[1:], " ")
	}

	nlogger.Infof("Function Name: %s\n", functionName)
	nlogger.Infof("Function Description: %s\n", functionDescription)

	fn := GoDocFunc{
		Name:        functionName,
		Description: functionDescription,
		Args:        make(map[string]*GoDocArg),
	}

	for _, m := range argMatch.FindAllStringSubmatch(text, -1) {
		argName := m[1]
		argDescription := m[2]
		// `required` will be an empty string if not present
		argRequired := strings.TrimSpace(m[3]) == "required"
		argEnums := m[4]

		nlogger.Infof("Arg Name: %s\nDescription: %s\nRequired: %v\nEnums: %s\n",
			argName, argDescription, argRequired, argEnums)

		var enums []string
		if argEnums != "" {
			enums = strings.Split(argEnums, ",")
			for i, e := range enums {
				enums[i] = strings.TrimSpace(e)
			}
		}

		fn.Args[argName] = &GoDocArg{
			Description: strings.ReplaceAll(argDescription, "\n", " "),
			Required:    argRequired,
			Enums:       enums,
		}
	}

	return fn
}
// buildFunctionParameters turns the documented arguments in defs.Args into
// factory parameters, recursing into nested Args.
//
// The slice follows map iteration order, which Go leaves unspecified; the
// intended position of each parameter is carried in its Order field.
func buildFunctionParameters(defs GoDocFunc) []*factory.Parameter {
	params := make([]*factory.Parameter, 0, len(defs.Args))

	for name, def := range defs.Args {
		options := []factory.ParameterFactoryOption{
			factory.WithParameterDescription(def.Description),
			factory.ParameterTypeString(def.Type),
		}
		if def.Required {
			options = append(options, factory.ParameterRequired())
		}
		if len(def.Enums) > 0 {
			options = append(options, factory.WithParameterEnum(def.Enums...))
		}

		param := factory.NewParameter(name, options...)
		param.Order = def.Order

		// Struct-typed arguments carry their fields as nested parameters.
		if len(def.Args) > 0 {
			param.Parameters = buildFunctionParameters(GoDocFunc{Args: def.Args})
		}

		params = append(params, param)
	}

	return params
}
// validatorTagIsRequired reports whether the struct tag's `validate:"..."`
// value contains the rule "required" as one of its comma-separated rules.
//
// Only an exact "required" rule counts; conditional rules such as
// "required_if=..." do not make the field required.
// NOTE(review): the rule syntax is assumed to follow go-playground/validator
// conventions; confirm against the validator this project actually uses.
func validatorTagIsRequired(tag *ast.BasicLit) bool {
	if tag == nil || tag.Value == "" {
		return false
	}

	const key = `validate:"`

	rest := strings.Trim(tag.Value, "`") // Remove backticks
	for {
		idx := strings.Index(rest, key)
		if idx < 0 {
			return false
		}

		// The key must start the tag or follow whitespace, so that a different
		// key merely ending in "validate" is not mistaken for it.
		isKey := idx == 0 || rest[idx-1] == ' ' || rest[idx-1] == '\t'
		rest = rest[idx+len(key):]
		if !isKey {
			continue
		}

		// Take everything up to the closing quote, so that rule values
		// containing spaces stay intact.
		end := strings.Index(rest, `"`)
		if end < 0 {
			end = len(rest)
		}

		for _, rule := range strings.Split(rest[:end], ",") {
			if strings.TrimSpace(rule) == "required" {
				return true
			}
		}

		return false
	}
}
// nameFromJSONTag returns the field name declared in the struct tag's
// `json:"..."` value, without any options such as ",omitempty" or ",string".
//
// It returns "" when tag is nil or empty, when the tag has no json key, or
// when the json value declares no name. A `json:"-"` tag yields "-"; this
// function does not interpret it.
func nameFromJSONTag(tag *ast.BasicLit) string {
	if tag == nil || tag.Value == "" {
		return ""
	}

	for _, part := range strings.Fields(strings.Trim(tag.Value, "`")) {
		if !strings.HasPrefix(part, "json:") {
			continue
		}

		value := strings.TrimPrefix(part, "json:")
		value = strings.Trim(value, "\"")

		// Everything after the first comma is an option, not part of the name.
		if idx := strings.Index(value, ","); idx >= 0 {
			value = value[:idx]
		}

		return strings.TrimSpace(value)
	}

	// No json key: report "no name" rather than leaking another key's text.
	return ""
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment