Created
May 20, 2025 14:29
-
-
Save caseylmanus/47e0245ed9de90e43a410a432ea43b9a to your computer and use it in GitHub Desktop.
rewrite grpc imports
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/printer"
	"go/token"
	"go/types"
	"os"
	"os/exec"
	"path"
	"path/filepath"
	"strconv"
	"strings"
	"unicode"
)
// Import path roots used by the migration. Replace the placeholder values with
// your project's real paths.
const (
	// oldImportPathRoot is the root of the import paths being migrated away from.
	oldImportPathRoot = "your/old/import/path" // Replace with your actual old import path root
	// newProtoImportPathRoot replaces oldImportPathRoot in migrated protobuf import paths.
	newProtoImportPathRoot = "your/new/proto/import/path" // Replace with your actual new protobuf import path root
	// newGRPCImportPathRoot replaces oldImportPathRoot in migrated gRPC import paths.
	newGRPCImportPathRoot = "your/new/grpc/import/path" // Replace with your actual new gRPC import path root
)
// packageInfo summarizes a parsed and type-checked package.
type packageInfo struct {
	// pkg is the type-checked package.
	pkg *types.Package
	// types indexes the package's exported type names by identifier.
	types map[string]*types.TypeName
}
// parsePackage parses a package and returns its type information. | |
func parsePackage(path string, fset *token.FileSet) (*packageInfo, error) { | |
pkgs, err := parser.ParseDir(fset, path, nil, 0) | |
if err != nil { | |
return nil, err | |
} | |
// We assume there's only one package in the directory. | |
var pkg *ast.Package | |
for _, p := range pkgs { | |
pkg = p | |
break | |
} | |
if pkg == nil { | |
return nil, nil // No Go files found. | |
} | |
// Create a slice of ast.Files for type checking. | |
files := make([]*ast.File, 0, len(pkg.Files)) | |
for _, file := range pkg.Files { | |
files = append(files, file) | |
} | |
// Type check the package to get type information. | |
conf := types.Config{} | |
info := &types.Info{ | |
Defs: make(map[*ast.Ident]types.Object), | |
Uses: make(map[*ast.Ident]types.Object), | |
} | |
typesPkg, err := conf.Check(path, fset, files, info) | |
if err != nil { | |
return nil, err | |
} | |
// Extract the exported types. | |
typesMap := make(map[string]*types.TypeName) | |
for _, scopeName := range typesPkg.Scope().Names() { | |
obj := typesPkg.Scope().Lookup(scopeName) | |
if typeName, ok := obj.(*types.TypeName); ok && obj.Exported() { | |
typesMap[scopeName] = typeName | |
} | |
} | |
return &packageInfo{ | |
types: typesMap, | |
pkg: typesPkg, | |
}, nil | |
} | |
// formatGRPCImportPath maps an import path rooted at oldImportPathRoot to the
// import path of the corresponding generated gRPC package.
//
// TODO: supply the project-specific mapping. A typical implementation swaps
// the root and appends a sub-package. Import paths are always slash-separated,
// so build them with path.Join rather than filepath.Join:
//
//	grpcPath = strings.Replace(originalPath, oldImportPathRoot, newGRPCImportPathRoot, 1)
//	grpcPath = path.Join(grpcPath, "grpc")
//
// Until that is done it returns the empty string.
func formatGRPCImportPath(originalPath string) (grpcPath string) {
	return grpcPath
}
// formatProtoImportPath maps an import path rooted at oldImportPathRoot to the
// import path of the corresponding generated protobuf package.
//
// TODO: supply the project-specific mapping, for example
//
//	protoPath = strings.Replace(originalPath, oldImportPathRoot, newProtoImportPathRoot, 1)
//
// Until that is done it returns the empty string.
func formatProtoImportPath(originalPath string) (protoPath string) {
	return protoPath
}
// grpcAliasName returns the import alias for the gRPC package generated from
// originalPath. The alias is the last path element, reduced to a valid Go
// identifier, followed by "grpc" (e.g. "example.com/my-api" -> "myapigrpc").
func grpcAliasName(originalPath string) string {
	return aliasBase(originalPath) + "grpc"
}

// protoAliasName returns the import alias for the protobuf package generated
// from originalPath. It is built like grpcAliasName but ends in "proto".
func protoAliasName(originalPath string) string {
	return aliasBase(originalPath) + "proto"
}

// aliasBase turns the last element of an import path into an identifier prefix.
// Characters that cannot appear in a Go identifier (such as '-' or '.') are
// dropped. A leading digit is guarded with '_', so the prefix plus a
// letter suffix is always a valid identifier.
func aliasBase(importPath string) string {
	// path, not filepath: import paths are slash-separated on every OS.
	base := strings.Map(func(r rune) rune {
		if r == '_' || unicode.IsLetter(r) || unicode.IsDigit(r) {
			return r
		}
		return -1
	}, path.Base(importPath))
	if strings.IndexFunc(base, unicode.IsDigit) == 0 {
		base = "_" + base
	}
	return base
}
// fetchDependencies runs `go get` for the gRPC import path and then for the
// protobuf import path, skipping any path that is empty. The command's output
// is streamed to this process's stdout/stderr. The first failing `go get`
// stops the sequence and its error is returned, wrapped with the import path.
func fetchDependencies(grpcImportPath, protoImportPath string) error {
	for _, importPath := range [...]string{grpcImportPath, protoImportPath} {
		if importPath == "" {
			continue
		}
		goGet := exec.Command("go", "get", importPath)
		goGet.Stdout, goGet.Stderr = os.Stdout, os.Stderr
		if runErr := goGet.Run(); runErr != nil {
			return fmt.Errorf("failed to go get %s: %w", importPath, runErr)
		}
		fmt.Printf("Successfully fetched %s\n", importPath)
	}
	return nil
}
// updateImportsInFile uses the AST to rewrite imports and code. | |
func updateImportsInFile(filePath string, protoInfo, grpcInfo *packageInfo) error { | |
// 1. Parse the file using the Go AST. | |
fset := token.NewFileSet() | |
node, err := parser.ParseFile(fset, filePath, nil, parser.ParseComments) | |
if err != nil { | |
return err | |
} | |
// 2. Type check the file to resolve identifiers. | |
conf := types.Config{} | |
info := &types.Info{ | |
Defs: make(map[*ast.Ident]types.Object), | |
Uses: make(map[*ast.Ident]types.Object), | |
} | |
_, err = conf.Check("", fset, []*ast.File{node}, info) | |
if err != nil { | |
return err | |
} | |
// 3. Track whether any changes were made and which imports are needed. | |
hasChanged := false | |
needsGRPCImport := false | |
needsProtoImport := false | |
originalImportPath := "" | |
grpcTypesMap := make(map[string]bool) | |
protoTypesMap := make(map[string]bool) | |
if grpcInfo != nil { | |
for typeName := range grpcInfo.types { | |
grpcTypesMap[typeName] = true | |
} | |
} | |
if protoInfo != nil { | |
for typeName := range protoInfo.types { | |
protoTypesMap[typeName] = true | |
} | |
} | |
// 4. Visit the AST nodes to find import declarations and identifiers. | |
ast.Inspect(node, func(n ast.Node) bool { | |
switch v := n.(type) { | |
case *ast.ImportSpec: | |
// 5. Remove the old import. We'll add the correct ones later. | |
if v.Path != nil && strings.HasPrefix(strings.Trim(v.Path.Value, `"`), oldImportPathRoot) { | |
hasChanged = true | |
originalImportPath = strings.Trim(v.Path.Value, `"`) // Capture the original import path | |
return false | |
} | |
case *ast.Ident: | |
// 6. Check usages to determine needed imports. | |
if obj := info.Uses[v]; obj != nil { | |
if typ, ok := obj.Type().(*types.Named); ok { | |
typeName := typ.Obj().Name() | |
// Determine which package the type is now in. | |
if protoInfo != nil && protoTypesMap[typeName] { | |
needsProtoImport = true | |
} else if grpcInfo != nil && grpcTypesMap[typeName] { | |
needsGRPCImport = true | |
} | |
} | |
} | |
} | |
return true | |
}) | |
// 7. Add the necessary imports. | |
var grpcImportSpec *ast.ImportSpec | |
var protoImportSpec *ast.ImportSpec | |
if hasChanged || needsGRPCImport || needsProtoImport { | |
// Remove all imports first. | |
node.Imports = nil | |
if needsProtoImport { | |
newProtoImportPath := formatProtoImportPath(originalImportPath) // Use the formatter. | |
protoImportSpec = &ast.ImportSpec{ | |
Name: &ast.Ident{Name: protoAliasName(originalImportPath)}, | |
Path: &ast.BasicLit{ | |
Kind: token.STRING, | |
Value: `"` + newProtoImportPath + `"`, | |
}, | |
} | |
node.Imports = append(node.Imports, &ast.GenDecl{ | |
Tok: token.IMPORT, | |
Specs: []ast.Spec{protoImportSpec}, | |
}) | |
} | |
if needsGRPCImport { | |
newGRPCImportPath := formatGRPCImportPath(originalImportPath) // Use the formatter | |
grpcImportSpec = &ast.ImportSpec{ | |
Name: &ast.Ident{Name: grpcAliasName(originalImportPath)}, | |
Path: &ast.BasicLit{ | |
Kind: token.STRING, | |
Value: `"` + newGRPCImportPath + `"`, | |
}, | |
} | |
node.Imports = append(node.Imports, &ast.GenDecl{ | |
Tok: token.IMPORT, | |
Specs: []ast.Spec{grpcImportSpec}, | |
}) | |
} | |
// 8. Update the references in the code. | |
ast.Inspect(node, func(n ast.Node) bool { | |
if sel, ok := n.(*ast.SelectorExpr); ok { | |
if xIdent, ok := sel.X.(*ast.Ident); ok { | |
// Check if the selector's X is a package identifier | |
if xIdent.Obj != nil && xIdent.Obj.Kind == ast.Pkg { | |
if strings.HasPrefix(xIdent.Name, filepath.Base(oldImportPathRoot)) { | |
if protoImportSpec != nil && xIdent.Name == protoAliasName(originalImportPath) { | |
xIdent.Name = protoAliasName(originalImportPath) | |
hasChanged = true | |
} else if grpcImportSpec != nil && xIdent.Name == grpcAliasName(originalImportPath) { | |
xIdent.Name = grpcAliasName(originalImportPath) | |
hasChanged = true | |
} | |
} | |
} | |
} | |
} | |
return true | |
}) | |
// 9. Print the modified AST back to the file. | |
if hasChanged { | |
file, err := os.Create(filePath) // Overwrite the original file. | |
if err != nil { | |
return err | |
} | |
defer file.Close() | |
err = printer.Fprint(file, fset, node) | |
if err != nil { | |
return err | |
} | |
fmt.Printf("Updated imports and code in %s\n", filePath) | |
} else { | |
fmt.Printf("No changes needed in %s\n", filePath) | |
} | |
} else { | |
fmt.Printf("No changes needed in %s\n", filePath) | |
} | |
return nil | |
} | |
// addImport adds an unnamed import of importPath to node. It does nothing if
// the file already imports that path under any name.
//
// The new spec is appended to the file's first import declaration. That
// declaration is given parentheses if it now holds several specs. A file
// without imports gets a new import declaration placed before all other
// declarations. node.Imports is kept consistent with node.Decls.
func addImport(node *ast.File, importPath string) {
	for _, existing := range node.Imports {
		if p, err := strconv.Unquote(existing.Path.Value); err == nil && p == importPath {
			return
		}
	}

	spec := &ast.ImportSpec{
		Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(importPath)},
	}
	node.Imports = append(node.Imports, spec)

	for _, decl := range node.Decls {
		gen, ok := decl.(*ast.GenDecl)
		if !ok || gen.Tok != token.IMPORT {
			continue
		}
		gen.Specs = append(gen.Specs, spec)
		if len(gen.Specs) > 1 && !gen.Lparen.IsValid() {
			gen.Lparen = gen.Specs[0].Pos()
		}
		return
	}
	node.Decls = append([]ast.Decl{&ast.GenDecl{Tok: token.IMPORT, Specs: []ast.Spec{spec}}}, node.Decls...)
}
// processDirectory recursively finds Go files and updates them. | |
func processDirectory(dirPath string) error { | |
fset := token.NewFileSet() | |
// Determine the new import paths. | |
grpcImportPath := formatGRPCImportPath(oldImportPathRoot) | |
protoImportPath := formatProtoImportPath(oldImportPathRoot) | |
// Fetch the dependencies. | |
if err := fetchDependencies(grpcImportPath, protoImportPath); err != nil { | |
return err | |
} | |
// Parse the new proto and gRPC packages. | |
protoInfo, err := parsePackage(filepath.Join(dirPath, strings.Replace(newProtoImportPathRoot, oldImportPathRoot, "", 1)), fset) | |
if err != nil { | |
return err | |
} | |
grpcInfo, err := parsePackage(filepath.Join(dirPath, strings.Replace(newGRPCImportPathRoot, oldImportPathRoot, "", 1)), fset) | |
if err != nil { | |
return err | |
} | |
return filepath.Walk(dirPath, func(path string, info os.FileInfo, err error) error { | |
if err != nil { | |
return err | |
} | |
if !info.IsDir() && strings.HasSuffix(path, ".go") { | |
err := updateImportsInFile(path, protoInfo, grpcInfo) | |
if err != nil { | |
return err | |
} | |
} | |
return nil | |
}) | |
} | |
func main() { | |
if len(os.Args) != 2 { | |
fmt.Println("Usage: go run main.go <directory_path>") | |
os.Exit(1) | |
} | |
directoryPath := os.Args[1] | |
fmt.Printf("Processing directory: %s\n", directoryPath) | |
err := processDirectory(directoryPath) | |
if err != nil { | |
fmt.Printf("Error: %v\n", err) | |
os.Exit(1) | |
} | |
fmt.Println("Import update process completed.") | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment