Skip to content

Instantly share code, notes, and snippets.

@caseylmanus
Created May 20, 2025 14:29
Show Gist options
  • Save caseylmanus/47e0245ed9de90e43a410a432ea43b9a to your computer and use it in GitHub Desktop.
Save caseylmanus/47e0245ed9de90e43a410a432ea43b9a to your computer and use it in GitHub Desktop.
rewrite grpc imports
package main
import (
	"fmt"
	"go/ast"
	"go/importer"
	"go/parser"
	"go/printer"
	"go/token"
	"go/types"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
)
// Import path roots used by the rewriter.
//
// Every value below is a placeholder that must be replaced before the tool is
// run against a real code base.
const (
	// oldImportPathRoot is the prefix that identifies imports to be rewritten.
	oldImportPathRoot = "your/old/import/path" // Replace with your actual old import path root

	// newProtoImportPathRoot is the root of the package(s) that now hold the
	// protobuf message types. processDirectory uses it to locate the package
	// on disk.
	newProtoImportPathRoot = "your/new/proto/import/path" // Replace with your actual protobuf import path root

	// newGRPCImportPathRoot is the root of the package(s) that now hold the
	// gRPC client/server code. processDirectory uses it to locate the package
	// on disk.
	newGRPCImportPathRoot = "your/new/grpc/import/path" // Replace with your actual gRPC import path root
)
// packageInfo is the summary parsePackage produces for one type-checked
// package.
type packageInfo struct {
	// pkg is the package object returned by the type checker.
	pkg *types.Package
	// types indexes the package's exported, package-level type declarations
	// by their name.
	types map[string]*types.TypeName
}
// parsePackage parses and type-checks the Go package in the directory path
// and returns its exported type declarations.
//
// Files ending in _test.go are ignored so that an external "_test" package in
// the same directory is never picked up. If the directory still contains more
// than one package, the one whose name sorts first is used, which keeps the
// result deterministic.
//
// Imports are resolved from source, so packages that import other packages
// can be checked. fset receives position information for every parsed file,
// including files of imported packages.
//
// It returns (nil, nil) when the directory contains no non-test Go files.
func parsePackage(path string, fset *token.FileSet) (*packageInfo, error) {
	skipTests := func(fi os.FileInfo) bool {
		return !strings.HasSuffix(fi.Name(), "_test.go")
	}
	pkgs, err := parser.ParseDir(fset, path, skipTests, 0)
	if err != nil {
		return nil, fmt.Errorf("parsing package in %s: %w", path, err)
	}

	// Map iteration order is random; choose by name so repeated runs agree.
	var pkg *ast.Package
	for _, candidate := range pkgs {
		if pkg == nil || candidate.Name < pkg.Name {
			pkg = candidate
		}
	}
	if pkg == nil {
		return nil, nil // No Go files found.
	}

	files := make([]*ast.File, 0, len(pkg.Files))
	for _, file := range pkg.Files {
		files = append(files, file)
	}

	// Without an Importer the checker rejects every import declaration.
	conf := types.Config{Importer: importer.ForCompiler(fset, "source", nil)}
	typesPkg, err := conf.Check(path, fset, files, nil)
	if err != nil {
		return nil, fmt.Errorf("type-checking package in %s: %w", path, err)
	}

	// Collect the exported package-level type names.
	exported := make(map[string]*types.TypeName)
	scope := typesPkg.Scope()
	for _, name := range scope.Names() {
		if typeName, ok := scope.Lookup(name).(*types.TypeName); ok && typeName.Exported() {
			exported[name] = typeName
		}
	}
	return &packageInfo{
		types: exported,
		pkg:   typesPkg,
	}, nil
}
// formatGRPCImportPath is the hook that maps an old import path to the import
// path of the package holding the gRPC code.
//
// It is an unimplemented placeholder: originalPath is ignored and the empty
// string is returned. Project-specific logic has to be supplied here, for
// example by swapping the oldImportPathRoot prefix of originalPath for the new
// gRPC root and appending a "grpc" path element.
func formatGRPCImportPath(originalPath string) string {
	const notImplemented = "" // Placeholder
	return notImplemented
}
// formatProtoImportPath is the hook that maps an old import path to the import
// path of the package holding the protobuf message types.
//
// It is an unimplemented placeholder: originalPath is ignored and the empty
// string is returned. Project-specific logic has to be supplied here, for
// example by swapping the oldImportPathRoot prefix of originalPath for the new
// protobuf root.
func formatProtoImportPath(originalPath string) string {
	const notImplemented = "" // Placeholder
	return notImplemented
}
// grpcAliasName derives the import alias used for the gRPC package that
// replaces originalPath: the last path element with a "grpc" suffix.
//
// The result is always usable as a Go identifier for ASCII input. ASCII
// characters that cannot appear in an identifier (such as '-' and '.') are
// dropped, and a leading digit is prefixed with '_'. Non-ASCII runes are
// passed through unchanged. An empty path yields just "grpc".
func grpcAliasName(originalPath string) string {
	base := strings.Map(func(r rune) rune {
		switch {
		case r >= 0x80, // leave non-ASCII runes alone
			r == '_',
			'a' <= r && r <= 'z',
			'A' <= r && r <= 'Z',
			'0' <= r && r <= '9':
			return r
		}
		return -1 // drop everything else
	}, filepath.Base(originalPath))

	// Identifiers must not start with a digit.
	if base != "" && '0' <= base[0] && base[0] <= '9' {
		base = "_" + base
	}
	return base + "grpc"
}
// protoAliasName derives the import alias used for the protobuf package that
// replaces originalPath: the last path element with a "proto" suffix.
//
// The result is always usable as a Go identifier for ASCII input. ASCII
// characters that cannot appear in an identifier (such as '-' and '.') are
// dropped, and a leading digit is prefixed with '_'. Non-ASCII runes are
// passed through unchanged. An empty path yields just "proto".
func protoAliasName(originalPath string) string {
	base := strings.Map(func(r rune) rune {
		switch {
		case r >= 0x80, // leave non-ASCII runes alone
			r == '_',
			'a' <= r && r <= 'z',
			'A' <= r && r <= 'Z',
			'0' <= r && r <= '9':
			return r
		}
		return -1 // drop everything else
	}, filepath.Base(originalPath))

	// Identifiers must not start with a digit.
	if base != "" && '0' <= base[0] && base[0] <= '9' {
		base = "_" + base
	}
	return base + "proto"
}
// fetchDependencies runs `go get` for the gRPC import path and then for the
// protobuf import path, skipping whichever is empty.
//
// The command's output is forwarded to this process's stdout and stderr. The
// first failing command aborts the sequence and its error is returned wrapped.
func fetchDependencies(grpcImportPath, protoImportPath string) error {
	for _, target := range [...]string{grpcImportPath, protoImportPath} {
		if target == "" {
			continue
		}
		goGet := exec.Command("go", "get", target)
		goGet.Stdout, goGet.Stderr = os.Stdout, os.Stderr
		if err := goGet.Run(); err != nil {
			return fmt.Errorf("failed to go get %s: %w", target, err)
		}
		fmt.Printf("Successfully fetched %s\n", target)
	}
	return nil
}
// updateImportsInFile rewrites one Go source file so that references to a
// package under oldImportPathRoot point at the new protobuf and gRPC packages.
//
// For every import whose path is oldImportPathRoot or lies below it, each
// qualified identifier `pkg.Name` is inspected:
//
//   - if protoInfo declares Name, the qualifier becomes protoAliasName(path);
//   - otherwise, if grpcInfo declares Name, it becomes grpcAliasName(path);
//   - otherwise the reference is left alone and the old import is kept.
//
// The old import spec is then replaced in place by the aliased import(s) that
// are actually needed. Blank and dot imports are left untouched because they
// have no qualifier to rewrite. Either info argument may be nil, in which case
// that package declares nothing.
//
// The rewrite is purely syntactic; the file is not type-checked. For imports
// without an explicit name the package name is assumed to equal the last
// element of the import path.
//
// The file is only written after the new source has been rendered
// successfully, so a failure never leaves a truncated file behind. An error is
// returned, and nothing is written, if a required new import path is empty.
func updateImportsInFile(filePath string, protoInfo, grpcInfo *packageInfo) error {
	fset := token.NewFileSet()
	node, err := parser.ParseFile(fset, filePath, nil, parser.ParseComments)
	if err != nil {
		return fmt.Errorf("parsing %s: %w", filePath, err)
	}

	// declares reports whether info's package exports name, either as a
	// recorded type or as any other exported package-level object.
	declares := func(info *packageInfo, name string) bool {
		if info == nil {
			return false
		}
		if info.types[name] != nil {
			return true
		}
		if info.pkg == nil {
			return false
		}
		obj := info.pkg.Scope().Lookup(name)
		return obj != nil && obj.Exported()
	}

	// rewrite tracks one old import and what its references turned into.
	type rewrite struct {
		importPath string
		protoAlias string
		protoPath  string
		grpcAlias  string
		grpcPath   string
		usesProto  bool
		usesGRPC   bool
		usesOld    bool
	}

	// 1. Find the imports that have to be replaced.
	bySpec := make(map[*ast.ImportSpec]*rewrite)
	byLocalName := make(map[string]*rewrite)
	for _, spec := range node.Imports {
		if spec.Path == nil {
			continue
		}
		importPath := strings.Trim(spec.Path.Value, "\"`")
		// Match on whole path elements so sibling paths that merely share
		// the textual prefix are not touched.
		if importPath != oldImportPathRoot && !strings.HasPrefix(importPath, oldImportPathRoot+"/") {
			continue
		}
		localName := filepath.Base(importPath)
		if spec.Name != nil {
			localName = spec.Name.Name
		}
		if localName == "_" || localName == "." {
			continue
		}
		rw := &rewrite{
			importPath: importPath,
			protoAlias: protoAliasName(importPath),
			protoPath:  formatProtoImportPath(importPath),
			grpcAlias:  grpcAliasName(importPath),
			grpcPath:   formatGRPCImportPath(importPath),
		}
		bySpec[spec] = rw
		byLocalName[localName] = rw
	}
	if len(byLocalName) == 0 {
		fmt.Printf("No changes needed in %s\n", filePath)
		return nil
	}

	// 2. Re-qualify every reference made through an old import.
	ast.Inspect(node, func(n ast.Node) bool {
		sel, ok := n.(*ast.SelectorExpr)
		if !ok {
			return true
		}
		qualifier, ok := sel.X.(*ast.Ident)
		// The parser resolves identifiers declared in this file; a package
		// name stays unresolved. A non-nil Obj therefore means a local
		// declaration is shadowing the import.
		if !ok || qualifier.Obj != nil {
			return true
		}
		rw := byLocalName[qualifier.Name]
		if rw == nil {
			return true
		}
		switch {
		case declares(protoInfo, sel.Sel.Name):
			qualifier.Name = rw.protoAlias
			rw.usesProto = true
		case declares(grpcInfo, sel.Sel.Name):
			qualifier.Name = rw.grpcAlias
			rw.usesGRPC = true
		default:
			rw.usesOld = true
		}
		return true
	})

	// newSpec builds an aliased import that borrows the source position of
	// the spec it replaces so the printer keeps the surrounding layout.
	newSpec := func(alias, importPath string, old *ast.ImportSpec) *ast.ImportSpec {
		return &ast.ImportSpec{
			Name: &ast.Ident{NamePos: old.Pos(), Name: alias},
			Path: &ast.BasicLit{
				ValuePos: old.Path.Pos(),
				Kind:     token.STRING,
				Value:    `"` + importPath + `"`,
			},
			EndPos: old.End(),
		}
	}

	// 3. Swap the import specs inside the declarations; the printer only
	// looks at node.Decls.
	changed := false
	for _, decl := range node.Decls {
		gen, ok := decl.(*ast.GenDecl)
		if !ok || gen.Tok != token.IMPORT {
			continue
		}
		specs := make([]ast.Spec, 0, len(gen.Specs)+1)
		for _, s := range gen.Specs {
			old, _ := s.(*ast.ImportSpec)
			rw := bySpec[old]
			if rw == nil || !(rw.usesProto || rw.usesGRPC) {
				specs = append(specs, s)
				continue
			}
			if rw.usesProto {
				if rw.protoPath == "" {
					return fmt.Errorf("%s: no protobuf import path configured for %q", filePath, rw.importPath)
				}
				specs = append(specs, newSpec(rw.protoAlias, rw.protoPath, old))
			}
			if rw.usesGRPC {
				if rw.grpcPath == "" {
					return fmt.Errorf("%s: no gRPC import path configured for %q", filePath, rw.importPath)
				}
				specs = append(specs, newSpec(rw.grpcAlias, rw.grpcPath, old))
			}
			if rw.usesOld {
				// Something still refers to the old package.
				specs = append(specs, s)
			}
			changed = true
		}
		gen.Specs = specs
	}
	if !changed {
		fmt.Printf("No changes needed in %s\n", filePath)
		return nil
	}

	// 4. Render to memory first, then replace the file in one call.
	var out strings.Builder
	cfg := printer.Config{Mode: printer.UseSpaces | printer.TabIndent, Tabwidth: 8}
	if err := cfg.Fprint(&out, fset, node); err != nil {
		return fmt.Errorf("printing %s: %w", filePath, err)
	}
	if err := os.WriteFile(filePath, []byte(out.String()), 0644); err != nil {
		return fmt.Errorf("writing %s: %w", filePath, err)
	}
	fmt.Printf("Updated imports and code in %s\n", filePath)
	return nil
}
// addImport makes sure node imports importPath.
//
// The new spec is appended to the file's first import declaration, or, when
// the file has none, placed in a new import declaration ahead of every other
// declaration (imports must precede all other declarations). node.Imports is
// kept in sync. Calling it for a path that is already imported with the same
// quoted spelling is a no-op.
func addImport(node *ast.File, importPath string) {
	quoted := `"` + importPath + `"`
	for _, existing := range node.Imports {
		if existing.Path != nil && existing.Path.Value == quoted {
			return
		}
	}

	importSpec := &ast.ImportSpec{
		Path: &ast.BasicLit{
			Kind:  token.STRING,
			Value: quoted,
		},
	}
	node.Imports = append(node.Imports, importSpec)

	// The printer renders node.Decls, so the spec has to live in a
	// declaration as well.
	for _, decl := range node.Decls {
		if gen, ok := decl.(*ast.GenDecl); ok && gen.Tok == token.IMPORT {
			gen.Specs = append(gen.Specs, importSpec)
			return
		}
	}
	importDecl := &ast.GenDecl{
		Tok:   token.IMPORT,
		Specs: []ast.Spec{importSpec},
	}
	node.Decls = append([]ast.Decl{importDecl}, node.Decls...)
}
// processDirectory drives the whole migration for the tree rooted at dirPath.
//
// It fetches the new gRPC and protobuf modules, loads type information for
// the two new packages from directories below dirPath, and then runs
// updateImportsInFile on every regular file whose name ends in ".go". The
// first error from any step stops the run and is returned as is.
//
// It relies on the package-level constants newProtoImportPathRoot and
// newGRPCImportPathRoot to derive the package directories.
func processDirectory(dirPath string) error {
	fset := token.NewFileSet()

	// Work out what to download, then download it.
	grpcTarget := formatGRPCImportPath(oldImportPathRoot)
	protoTarget := formatProtoImportPath(oldImportPathRoot)
	if err := fetchDependencies(grpcTarget, protoTarget); err != nil {
		return err
	}

	// Load the new packages; the directory is the new root with the old root
	// removed once, joined onto dirPath.
	protoDir := filepath.Join(dirPath, strings.Replace(newProtoImportPathRoot, oldImportPathRoot, "", 1))
	protoInfo, err := parsePackage(protoDir, fset)
	if err != nil {
		return err
	}
	grpcDir := filepath.Join(dirPath, strings.Replace(newGRPCImportPathRoot, oldImportPathRoot, "", 1))
	grpcInfo, err := parsePackage(grpcDir, fset)
	if err != nil {
		return err
	}

	visit := func(path string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if info.IsDir() || !strings.HasSuffix(path, ".go") {
			return nil
		}
		return updateImportsInFile(path, protoInfo, grpcInfo)
	}
	return filepath.Walk(dirPath, visit)
}
// main expects exactly one command-line argument, the directory to process.
// It exits with status 1 on a usage error or when processing fails.
func main() {
	if len(os.Args) != 2 {
		fmt.Print("Usage: go run main.go <directory_path>\n")
		os.Exit(1)
	}

	root := os.Args[1]
	fmt.Printf("Processing directory: %s\n", root)

	if err := processDirectory(root); err != nil {
		fmt.Printf("Error: %v\n", err)
		os.Exit(1)
	}
	fmt.Print("Import update process completed.\n")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment