// Example usage:
//
//	$ go build -i -o protoc-gen-grpc-gateway-ts .
//	$ protoc -I ./lib/proto ./lib/proto/*.proto \
//		--plugin=protoc-gen-grpc_gateway_ts=`which protoc-gen-grpc-gateway-ts` \
//		--grpc_gateway_ts_out=lib/ts/src
package main | |
import ( | |
"fmt" | |
"io" | |
"io/ioutil" | |
"log" | |
"os" | |
"path" | |
"sort" | |
"strings" | |
"github.com/golang/protobuf/proto" | |
"github.com/golang/protobuf/protoc-gen-go/descriptor" | |
gen "github.com/golang/protobuf/protoc-gen-go/generator" | |
plugin "github.com/golang/protobuf/protoc-gen-go/plugin" | |
) | |
type generator struct { | |
*gen.Generator | |
reader io.Reader | |
writer io.Writer | |
} | |
// messageField is a single property of an emitted TypeScript type: the
// property name and the TypeScript type expression rendered for it.
type messageField struct {
	name, t string
}
// namespace derives a dotted type namespace from the directory of the proto
// file name, e.g. "foo/bar/baz.proto" -> ".foo.bar". This assumes the proto
// package mirrors the directory layout (NOTE(review): confirm for all inputs).
//
// A file at the include root has no directory (path.Dir returns "."), so it gets
// the empty namespace. Previously this produced "..", giving qualified names
// like "...Msg" that could never match protoc's ".Msg" type names.
func namespace(name string) string {
	dir := path.Dir(name)
	if dir == "." {
		return ""
	}
	return fmt.Sprintf(".%s", strings.Replace(dir, "/", ".", -1))
}
// qualifiedName joins a namespace and a type name with a single dot,
// e.g. (".a.b", "C") -> ".a.b.C".
func qualifiedName(ns, name string) string {
	const sep = "."
	// All operands are strings, so fmt.Sprint inserts no spaces between them.
	return fmt.Sprint(ns, sep, name)
}
type serviceDefinition struct { | |
desc *descriptor.ServiceDescriptorProto | |
dependencies map[string]bool | |
} | |
type serviceMap map[string]*serviceDefinition | |
type messageDefinition struct { | |
desc *descriptor.DescriptorProto | |
dependencies map[string]bool | |
name string | |
fields []messageField | |
location string | |
} | |
type messageMap map[string]*messageDefinition | |
type enumDefinition struct { | |
desc *descriptor.EnumDescriptorProto | |
location string | |
} | |
type enumMap map[string]*enumDefinition | |
// fatal logs msg (and err, when non-nil) with the plugin prefix and exits the
// process with status 1.
//
// err may be nil: Generate reports "No input files." without an underlying
// error, and calling err.Error() on a nil error would panic instead of
// producing a clean diagnostic.
func fatal(err error, msg string) {
	if err == nil {
		log.Printf("protoc-gen-grpc-gateway-ts: error: %s", msg)
	} else {
		log.Printf("protoc-gen-grpc-gateway-ts: error: %s: %s", msg, err)
	}
	os.Exit(1)
}
func New() *generator { | |
return &generator{ | |
Generator: gen.New(), | |
reader: os.Stdin, | |
writer: os.Stdout, | |
} | |
} | |
// P prints the arguments to the generated output. It handles strings and int32s, plus | |
// handling indirections because they may be *string, etc. | |
func (g *generator) print(indent int, str ...interface{}) { | |
g.WriteString(strings.Repeat(" ", 2*indent)) | |
for _, v := range str { | |
switch s := v.(type) { | |
case string: | |
g.WriteString(s) | |
case *string: | |
g.WriteString(*s) | |
case bool: | |
fmt.Fprintf(g, "%t", s) | |
case *bool: | |
fmt.Fprintf(g, "%t", *s) | |
case int: | |
fmt.Fprintf(g, "%d", s) | |
case int32: | |
fmt.Fprintf(g, "%d", s) | |
case *int32: | |
fmt.Fprintf(g, "%d", *s) | |
case *int64: | |
fmt.Fprintf(g, "%d", *s) | |
case float64: | |
fmt.Fprintf(g, "%g", s) | |
case *float64: | |
fmt.Fprintf(g, "%g", *s) | |
default: | |
g.Fail(fmt.Sprintf("unknown type in printer: %T", v)) | |
} | |
} | |
g.WriteString("\n") | |
} | |
func (g *generator) processFile(fd *descriptor.FileDescriptorProto, definitions map[string]bool, mm messageMap, em enumMap) (*plugin.CodeGeneratorResponse_File, error) { | |
fn := fd.GetName() | |
ns := namespace(fn) | |
needsWrite := false | |
// Iterate message and enum dependencies to build a map for an import statement. | |
importMap := make(map[string]map[string]bool) | |
for _, message := range fd.MessageType { | |
qn := qualifiedName(ns, message.GetName()) | |
if _, ok := definitions[qn]; ok { | |
needsWrite = true | |
md := mm[qn] | |
for dep, _ := range md.dependencies { | |
messageImportDef := mm[dep] | |
if messageImportDef != nil && messageImportDef.location != fn { | |
// This is a message dependency | |
s, ok := importMap[messageImportDef.location] | |
if !ok { | |
s = make(map[string]bool) | |
} | |
s[dep] = true | |
importMap[messageImportDef.location] = s | |
continue | |
} | |
enumImportDef := em[dep] | |
if enumImportDef != nil && enumImportDef.location != fn { | |
// This is an enum dependency | |
s, ok := importMap[enumImportDef.location] | |
if !ok { | |
s = make(map[string]bool) | |
} | |
s[dep] = true | |
importMap[enumImportDef.location] = s | |
} | |
} | |
} | |
} | |
if !needsWrite { | |
// We didn't have | |
for _, enum := range fd.EnumType { | |
qn := qualifiedName(ns, enum.GetName()) | |
if _, ok := definitions[qn]; ok { | |
needsWrite = true | |
break | |
} | |
} | |
} | |
if !needsWrite { | |
// There's nothing for us to do for this file. | |
return nil, nil | |
} | |
i := 0 | |
// Write the import statements, if necessary. | |
multiImports := make([]string, 0) | |
singleImports := make([]string, 0) | |
for f, imports := range importMap { | |
loc := fmt.Sprintf("../%s", filename(f)) | |
loc = loc[:strings.LastIndex(loc, ".")] | |
ims := make([]string, 0) | |
for imp, _ := range imports { | |
ims = append(ims, strings.Trim(path.Ext(imp), ".")) | |
} | |
sort.Strings(ims) | |
statement := fmt.Sprintf("import { %s } from '%s'", strings.Join(ims, ", "), loc) | |
if len(ims) > 1 { | |
multiImports = append(multiImports, statement) | |
} else { | |
singleImports = append(singleImports, statement) | |
} | |
i++ | |
} | |
sort.Strings(multiImports) | |
sort.Strings(singleImports) | |
for _, is := range multiImports { | |
g.print(0, is) | |
} | |
for _, is := range singleImports { | |
g.print(0, is) | |
} | |
for _, enum := range fd.EnumType { | |
qn := qualifiedName(ns, enum.GetName()) | |
if _, ok := definitions[qn]; !ok { | |
continue | |
} | |
// This enum is used by at least one other message or service, so we need to write it out. | |
if i != 0 { | |
g.print(0, "") | |
} | |
g.print(0, "export enum ", enum.GetName(), " {") | |
for _, v := range enum.GetValue() { | |
g.print(1, v.GetName(), " = ", v.GetNumber(), ",") | |
} | |
g.print(0, "}") | |
i++ | |
} | |
for _, message := range fd.MessageType { | |
qn := qualifiedName(ns, message.GetName()) | |
if _, ok := definitions[qn]; !ok { | |
continue | |
} | |
// This message is used by at least one other message or service, so we need to write it out. | |
if i != 0 { | |
g.print(0, "") | |
} | |
md := mm[qn] | |
fieldCount := len(md.fields) | |
brackets := "{" | |
if fieldCount == 0 { | |
brackets = "{}" | |
} | |
g.print(0, "export type ", md.name, " = ", brackets) | |
for _, field := range md.fields { | |
g.print(1, field.name, ": ", field.t) | |
} | |
if fieldCount > 0 { | |
g.print(0, "}") | |
} | |
i++ | |
} | |
file := &plugin.CodeGeneratorResponse_File{ | |
Name: proto.String(filename(fn)), | |
Content: proto.String(g.String()), | |
} | |
g.Reset() | |
return file, nil | |
} | |
// parseField parses the supplied field to extract its type. If it is a map, enum, or | |
// message type, adds dependent messages and enums to the supplied dependencies map. | |
func parseField(message *descriptor.DescriptorProto, field *descriptor.FieldDescriptorProto, dependencies map[string]bool) string { | |
isMap := false | |
t := "" | |
switch field.GetType() { | |
case descriptor.FieldDescriptorProto_TYPE_INT32: | |
t = "number" | |
case descriptor.FieldDescriptorProto_TYPE_INT64: | |
t = "string" | |
case descriptor.FieldDescriptorProto_TYPE_UINT64: | |
t = "string" | |
case descriptor.FieldDescriptorProto_TYPE_DOUBLE: | |
t = "number" | |
case descriptor.FieldDescriptorProto_TYPE_FLOAT: | |
t = "number" | |
case descriptor.FieldDescriptorProto_TYPE_FIXED64: | |
t = "string" | |
case descriptor.FieldDescriptorProto_TYPE_FIXED32: | |
t = "number" | |
case descriptor.FieldDescriptorProto_TYPE_BOOL: | |
t = "boolean" | |
case descriptor.FieldDescriptorProto_TYPE_STRING: | |
t = "string" | |
case descriptor.FieldDescriptorProto_TYPE_BYTES: | |
t = "string" | |
case descriptor.FieldDescriptorProto_TYPE_MESSAGE: | |
messageType := strings.Trim(path.Ext(field.GetTypeName()), ".") | |
// Handle Maps | |
for _, nested := range message.GetNestedType() { | |
if nested.GetName() == messageType && nested.GetOptions().GetMapEntry() { | |
// This is a map, not a message. | |
isMap = true | |
keyType := "" | |
valueType := "" | |
for _, nestedField := range nested.GetField() { | |
if nestedField.GetName() == "value" { | |
valueType = parseField(message, nestedField, dependencies) | |
} else if nestedField.GetName() == "key" { | |
keyType = parseField(message, nestedField, dependencies) | |
} | |
} | |
messageType = fmt.Sprintf("{ [key: %s]: %s }", keyType, valueType) | |
break | |
} | |
} | |
if !isMap { | |
dependencies[field.GetTypeName()] = true | |
} | |
t = messageType | |
case descriptor.FieldDescriptorProto_TYPE_ENUM: | |
t = strings.Trim(path.Ext(field.GetTypeName()), ".") | |
dependencies[field.GetTypeName()] = true | |
default: | |
fatal(fmt.Errorf("Unknown field type %d for %+v in %s", field.GetType(), field, message.GetName()), "Error") | |
} | |
if !isMap && field.GetLabel() == descriptor.FieldDescriptorProto_LABEL_REPEATED { | |
t = fmt.Sprintf("%s[]", t) | |
} | |
return t | |
} | |
func (g *generator) collectDefinitions(request *plugin.CodeGeneratorRequest) (map[string]bool, messageMap, enumMap) { | |
mm := make(messageMap) | |
sm := make(serviceMap) | |
em := make(enumMap) | |
for _, fd := range request.ProtoFile { | |
ns := namespace(fd.GetName()) | |
for _, enum := range fd.EnumType { | |
em[qualifiedName(ns, enum.GetName())] = &enumDefinition{ | |
desc: enum, | |
location: fd.GetName(), | |
} | |
} | |
for _, message := range fd.MessageType { | |
dependencies := make(map[string]bool) | |
fields := []messageField{} | |
for _, field := range message.GetField() { | |
fields = append(fields, messageField{ | |
name: field.GetJsonName(), | |
t: parseField(message, field, dependencies), | |
}) | |
} | |
mm[qualifiedName(ns, message.GetName())] = &messageDefinition{ | |
desc: message, | |
dependencies: dependencies, | |
name: message.GetName(), | |
fields: fields, | |
location: fd.GetName(), | |
} | |
} | |
for _, service := range fd.Service { | |
serviceName := gen.CamelCase(service.GetName()) | |
method := service.GetMethod() | |
dependencies := make(map[string]bool) | |
for _, m := range method { | |
dependencies[m.GetInputType()] = true | |
dependencies[m.GetOutputType()] = true | |
} | |
sm[serviceName] = &serviceDefinition{ | |
desc: service, | |
dependencies: dependencies, | |
} | |
} | |
} | |
// Iterate through everything we've found, starting from our services, so we | |
// only write types for things we've defined or imported. | |
collected := make(map[string]bool) | |
seen := make(map[string]bool) | |
for _, definition := range sm { | |
for sd, _ := range definition.dependencies { | |
// These definitions are Requests and Responses | |
collected[sd] = true | |
collectDependencies(sd, mm, collected, seen) | |
} | |
} | |
return collected, mm, em | |
} | |
func collectDependencies(key string, mm messageMap, collected map[string]bool, seen map[string]bool) { | |
if _, repeated := seen[key]; repeated { | |
return | |
} | |
seen[key] = true | |
if next, ok := mm[key]; ok { | |
for dependency, _ := range next.dependencies { | |
collected[dependency] = true | |
collectDependencies(dependency, mm, collected, seen) | |
} | |
} | |
} | |
// filename maps a proto file name to the name of its generated TypeScript
// file. A trailing ".proto" extension is replaced by "_gw.ts", and any other
// name simply gets "_gw.ts" appended; the directory part is kept.
func filename(name string) string {
	const protoExt = ".proto"
	stem := name
	if path.Ext(name) == protoExt {
		stem = name[:len(name)-len(protoExt)]
	}
	return fmt.Sprintf("%s_gw.ts", stem)
}
func (g *generator) Generate() { | |
input, err := ioutil.ReadAll(g.reader) | |
if err != nil { | |
fatal(err, "Could not read input.") | |
} | |
request := g.Request | |
if err := proto.Unmarshal(input, request); err != nil { | |
fatal(err, "Could not parse input proto.") | |
} | |
if len(request.FileToGenerate) == 0 { | |
fatal(err, "No input files.") | |
} | |
g.CommandLineParameters(g.Request.GetParameter()) | |
g.WrapTypes() | |
g.SetPackageNames() | |
g.BuildTypeNameMap() | |
g.GenerateAllFiles() | |
g.Reset() | |
response := new(plugin.CodeGeneratorResponse) | |
messages, mm, em := g.collectDefinitions(request) | |
for _, fd := range request.ProtoFile { | |
file, err := g.processFile(fd, messages, mm, em) | |
if err != nil { | |
fatal(err, fmt.Sprintf("Couldn't write file for %s", fd.GetName())) | |
} | |
if file != nil { | |
response.File = append(response.File, file) | |
} | |
} | |
output, err := proto.Marshal(response) | |
if err != nil { | |
fatal(err, "Couldn't marshal output proto") | |
} | |
_, err = g.writer.Write(output) | |
if err != nil { | |
fatal(err, "Couldn't write files") | |
} | |
} | |
func main() { | |
g := New() | |
g.Generate() | |
} |