Skip to content

Instantly share code, notes, and snippets.

@majelbstoat
Created March 23, 2020 00:17
Show Gist options
  • Save majelbstoat/5f73f58a057123c8d46db66ebbaae987 to your computer and use it in GitHub Desktop.
Save majelbstoat/5f73f58a057123c8d46db66ebbaae987 to your computer and use it in GitHub Desktop.
protoc-gen-grpc-gateway-ts

Example Usage:

$ go build -i -o protoc-gen-grpc-gateway-ts .

$ protoc -I ./lib/proto ./lib/proto/*.proto \
    --plugin=protoc-gen-grpc_gateway_ts="$(which protoc-gen-grpc-gateway-ts)" \
    --grpc_gateway_ts_out=lib/ts/src

package main
import (
"fmt"
"io"
"io/ioutil"
"log"
"os"
"path"
"sort"
"strings"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/protoc-gen-go/descriptor"
gen "github.com/golang/protobuf/protoc-gen-go/generator"
plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
)
// generator wraps protoc-gen-go's Generator. Its WriteString, String and Reset
// methods serve as the output buffer for the TypeScript file currently being
// rendered. reader supplies the serialized CodeGeneratorRequest and writer
// receives the serialized CodeGeneratorResponse; New sets them to os.Stdin and
// os.Stdout.
type generator struct {
	*gen.Generator
	reader io.Reader
	writer io.Writer
}
// messageField is one property of a generated TypeScript type.
type messageField struct {
	name string // property name; collectDefinitions fills it with the field's JSON name
	t    string // TypeScript type expression produced by parseField
}
// namespace derives the dotted protobuf namespace of a .proto file from its
// path, e.g. "foo/bar/baz.proto" -> ".foo.bar". It assumes the package mirrors
// the directory layout (NOTE(review): fd.GetPackage() would not need that
// assumption). A file at the root has no directory component and yields "",
// so qualifiedName("", "Msg") gives ".Msg" rather than "..Msg".
func namespace(name string) string {
	dir := path.Dir(name)
	if dir == "." {
		// Root-level file: no namespace.
		return ""
	}
	return fmt.Sprintf(".%s", strings.ReplaceAll(dir, "/", "."))
}
// qualifiedName joins a namespace (as produced by namespace) and a type name
// with a dot, yielding the fully-qualified form used as map keys throughout.
func qualifiedName(ns, name string) string {
	// Sprint inserts no separators between string operands.
	return fmt.Sprint(ns, ".", name)
}
// serviceDefinition records a service together with the types its RPCs use.
type serviceDefinition struct {
	desc         *descriptor.ServiceDescriptorProto
	dependencies map[string]bool // fully-qualified input and output type names of every method
}

// serviceMap holds the service definitions gathered by collectDefinitions.
type serviceMap map[string]*serviceDefinition
// messageDefinition is a top-level message prepared for TypeScript output.
type messageDefinition struct {
	desc         *descriptor.DescriptorProto
	dependencies map[string]bool // fully-qualified names of message and enum types its fields reference
	name         string          // unqualified message name, emitted as the TypeScript type name
	fields       []messageField  // fields in the order returned by GetField
	location     string          // name of the .proto file that declares the message
}

// messageMap indexes message definitions by fully-qualified name (see qualifiedName).
type messageMap map[string]*messageDefinition
// enumDefinition records an enum and the file that declares it.
type enumDefinition struct {
	desc     *descriptor.EnumDescriptorProto
	location string // name of the .proto file that declares the enum
}

// enumMap indexes enum definitions by fully-qualified name (see qualifiedName).
type enumMap map[string]*enumDefinition
// fatal logs msg (plus err's text when err is non-nil) and terminates the
// process with exit status 1. It tolerates a nil err so that call sites which
// detect a problem without an underlying error cannot panic on the error path.
func fatal(err error, msg string) {
	if err == nil {
		log.Printf("protoc-gen-grpc-gateway-ts: error: %s", msg)
	} else {
		log.Printf("protoc-gen-grpc-gateway-ts: error: %s: %s", msg, err.Error())
	}
	os.Exit(1)
}
// New returns a generator wired to the process's standard input and output,
// which is how protoc exchanges the request and response with a plugin.
func New() *generator {
	g := &generator{Generator: gen.New()}
	g.reader, g.writer = os.Stdin, os.Stdout
	return g
}
// print writes one line of generated output. It emits indent levels of
// two-space indentation, then each argument in order, then a newline.
//
// Arguments may be strings, booleans, integers, or floats. Pointers to string,
// bool and the fixed-size numeric types are also accepted, since protobuf
// descriptor getters and raw fields hand back either form. Any other type
// aborts generation via g.Fail.
func (g *generator) print(indent int, str ...interface{}) {
	g.WriteString(strings.Repeat(" ", 2*indent))
	for _, v := range str {
		switch s := v.(type) {
		case string:
			g.WriteString(s)
		case *string:
			g.WriteString(*s)
		case bool:
			fmt.Fprintf(g, "%t", s)
		case *bool:
			fmt.Fprintf(g, "%t", *s)
		case int, int32, int64, uint32, uint64:
			fmt.Fprintf(g, "%d", s)
		case *int32:
			fmt.Fprintf(g, "%d", *s)
		case *int64:
			fmt.Fprintf(g, "%d", *s)
		case *uint32:
			fmt.Fprintf(g, "%d", *s)
		case *uint64:
			fmt.Fprintf(g, "%d", *s)
		case float32, float64:
			fmt.Fprintf(g, "%g", s)
		case *float32:
			fmt.Fprintf(g, "%g", *s)
		case *float64:
			fmt.Fprintf(g, "%g", *s)
		default:
			g.Fail(fmt.Sprintf("unknown type in printer: %T", v))
		}
	}
	g.WriteString("\n")
}
// processFile renders the TypeScript definitions for a single .proto file.
//
// Only enums and top-level messages whose fully-qualified names appear in
// definitions are written. Messages and enums they reference from other files
// are brought in with relative `import { ... } from '...'` statements.
//
// It returns (nil, nil) when the file declares nothing that needs writing. The
// output is accumulated in g's buffer, which is reset before returning so the
// next file starts empty.
func (g *generator) processFile(fd *descriptor.FileDescriptorProto, definitions map[string]bool, mm messageMap, em enumMap) (*plugin.CodeGeneratorResponse_File, error) {
	fn := fd.GetName()
	ns := namespace(fn)
	needsWrite := false

	// importMap maps the .proto file declaring a dependency to the set of
	// fully-qualified type names this file needs from it.
	importMap := make(map[string]map[string]bool)
	addImport := func(location, dep string) {
		s, ok := importMap[location]
		if !ok {
			s = make(map[string]bool)
			importMap[location] = s
		}
		s[dep] = true
	}

	// Iterate message and enum dependencies to build the import statements.
	for _, message := range fd.MessageType {
		qn := qualifiedName(ns, message.GetName())
		if _, ok := definitions[qn]; !ok {
			continue
		}
		needsWrite = true
		for dep := range mm[qn].dependencies {
			if def := mm[dep]; def != nil && def.location != fn {
				// A message declared in another file.
				addImport(def.location, dep)
				continue
			}
			if def := em[dep]; def != nil && def.location != fn {
				// An enum declared in another file.
				addImport(def.location, dep)
			}
		}
	}
	if !needsWrite {
		// No message from this file is needed, but one of its enums may be.
		for _, enum := range fd.EnumType {
			if _, ok := definitions[qualifiedName(ns, enum.GetName())]; ok {
				needsWrite = true
				break
			}
		}
	}
	if !needsWrite {
		// There's nothing for us to do for this file.
		return nil, nil
	}

	// Generated files mirror the .proto layout, so an import must first climb
	// from this file's directory back to the output root. That takes one "../"
	// per directory in fn; a file at the root refers to the root as "./".
	prefix := "./"
	if depth := strings.Count(fn, "/"); depth > 0 {
		prefix = strings.Repeat("../", depth)
	}

	// i counts the items written so far; it decides when a blank separator
	// line is needed between sections.
	i := 0
	var multiImports, singleImports []string
	for location, imports := range importMap {
		// TypeScript module specifiers omit the ".ts" extension.
		loc := prefix + strings.TrimSuffix(filename(location), ".ts")
		names := make([]string, 0, len(imports))
		for imp := range imports {
			// The imported identifier is the unqualified type name.
			names = append(names, strings.Trim(path.Ext(imp), "."))
		}
		sort.Strings(names)
		statement := fmt.Sprintf("import { %s } from '%s'", strings.Join(names, ", "), loc)
		if len(names) > 1 {
			multiImports = append(multiImports, statement)
		} else {
			singleImports = append(singleImports, statement)
		}
		i++
	}
	// Sort for deterministic output: map iteration order is random.
	sort.Strings(multiImports)
	sort.Strings(singleImports)
	for _, statement := range multiImports {
		g.print(0, statement)
	}
	for _, statement := range singleImports {
		g.print(0, statement)
	}

	for _, enum := range fd.EnumType {
		if _, ok := definitions[qualifiedName(ns, enum.GetName())]; !ok {
			continue
		}
		// This enum is used by at least one message or service, so write it out.
		if i != 0 {
			g.print(0, "")
		}
		g.print(0, "export enum ", enum.GetName(), " {")
		for _, v := range enum.GetValue() {
			g.print(1, v.GetName(), " = ", v.GetNumber(), ",")
		}
		g.print(0, "}")
		i++
	}

	for _, message := range fd.MessageType {
		qn := qualifiedName(ns, message.GetName())
		if _, ok := definitions[qn]; !ok {
			continue
		}
		// This message is used by at least one message or service, so write it out.
		if i != 0 {
			g.print(0, "")
		}
		md := mm[qn]
		if len(md.fields) == 0 {
			g.print(0, "export type ", md.name, " = ", "{}")
		} else {
			g.print(0, "export type ", md.name, " = ", "{")
			for _, field := range md.fields {
				g.print(1, field.name, ": ", field.t)
			}
			g.print(0, "}")
		}
		i++
	}

	file := &plugin.CodeGeneratorResponse_File{
		Name:    proto.String(filename(fn)),
		Content: proto.String(g.String()),
	}
	g.Reset()
	return file, nil
}
// parseField returns the TypeScript type expression for field, following the
// proto3 JSON mapping:
//   - 32-bit integers and floating-point values become number.
//   - 64-bit integers and bytes become string.
//   - bool becomes boolean.
//   - Enum and message fields use the unqualified type name.
//   - Map fields, detected via message's nested map-entry types, become an
//     index signature `{ [key: K]: V }`.
//
// Repeated non-map fields get a `[]` suffix. The fully-qualified names of the
// enums and messages referenced (including map values) are added to
// dependencies. Unsupported types (e.g. groups) abort via fatal.
func parseField(message *descriptor.DescriptorProto, field *descriptor.FieldDescriptorProto, dependencies map[string]bool) string {
	isMap := false
	t := ""
	switch field.GetType() {
	case descriptor.FieldDescriptorProto_TYPE_INT32,
		descriptor.FieldDescriptorProto_TYPE_UINT32,
		descriptor.FieldDescriptorProto_TYPE_SINT32,
		descriptor.FieldDescriptorProto_TYPE_FIXED32,
		descriptor.FieldDescriptorProto_TYPE_SFIXED32,
		descriptor.FieldDescriptorProto_TYPE_DOUBLE,
		descriptor.FieldDescriptorProto_TYPE_FLOAT:
		t = "number"
	case descriptor.FieldDescriptorProto_TYPE_INT64,
		descriptor.FieldDescriptorProto_TYPE_UINT64,
		descriptor.FieldDescriptorProto_TYPE_SINT64,
		descriptor.FieldDescriptorProto_TYPE_FIXED64,
		descriptor.FieldDescriptorProto_TYPE_SFIXED64:
		// proto3 JSON encodes 64-bit integers as strings to avoid precision loss.
		t = "string"
	case descriptor.FieldDescriptorProto_TYPE_BOOL:
		t = "boolean"
	case descriptor.FieldDescriptorProto_TYPE_STRING,
		descriptor.FieldDescriptorProto_TYPE_BYTES:
		// bytes are base64-encoded strings in proto3 JSON.
		t = "string"
	case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
		messageType := strings.Trim(path.Ext(field.GetTypeName()), ".")
		for _, nested := range message.GetNestedType() {
			if nested.GetName() != messageType || !nested.GetOptions().GetMapEntry() {
				continue
			}
			// This is a map, not a message: map<K, V> is modelled as a nested
			// entry message with "key" and "value" fields.
			isMap = true
			keyType, valueType := "", ""
			for _, nestedField := range nested.GetField() {
				switch nestedField.GetName() {
				case "key":
					keyType = parseField(message, nestedField, dependencies)
				case "value":
					valueType = parseField(message, nestedField, dependencies)
				}
			}
			messageType = fmt.Sprintf("{ [key: %s]: %s }", keyType, valueType)
			break
		}
		if !isMap {
			dependencies[field.GetTypeName()] = true
		}
		t = messageType
	case descriptor.FieldDescriptorProto_TYPE_ENUM:
		t = strings.Trim(path.Ext(field.GetTypeName()), ".")
		dependencies[field.GetTypeName()] = true
	default:
		fatal(fmt.Errorf("Unknown field type %d for %+v in %s", field.GetType(), field, message.GetName()), "Error")
	}
	// Map entries are themselves "repeated", but the index signature already
	// represents the collection.
	if !isMap && field.GetLabel() == descriptor.FieldDescriptorProto_LABEL_REPEATED {
		t = fmt.Sprintf("%s[]", t)
	}
	return t
}
// collectDefinitions indexes every enum and top-level message in the request
// by fully-qualified name. It returns the set of type names that must be
// generated, followed by the message and enum indexes. That set contains each
// service method's input and output types plus everything those types
// transitively reference.
func (g *generator) collectDefinitions(request *plugin.CodeGeneratorRequest) (map[string]bool, messageMap, enumMap) {
	mm := make(messageMap)
	sm := make(serviceMap)
	em := make(enumMap)
	for _, fd := range request.ProtoFile {
		ns := namespace(fd.GetName())
		for _, enum := range fd.EnumType {
			em[qualifiedName(ns, enum.GetName())] = &enumDefinition{
				desc:     enum,
				location: fd.GetName(),
			}
		}
		for _, message := range fd.MessageType {
			dependencies := make(map[string]bool)
			fields := make([]messageField, 0, len(message.GetField()))
			for _, field := range message.GetField() {
				fields = append(fields, messageField{
					name: field.GetJsonName(),
					t:    parseField(message, field, dependencies),
				})
			}
			mm[qualifiedName(ns, message.GetName())] = &messageDefinition{
				desc:         message,
				dependencies: dependencies,
				name:         message.GetName(),
				fields:       fields,
				location:     fd.GetName(),
			}
		}
		for _, service := range fd.Service {
			dependencies := make(map[string]bool)
			for _, m := range service.GetMethod() {
				dependencies[m.GetInputType()] = true
				dependencies[m.GetOutputType()] = true
			}
			// Key by fully-qualified name so that same-named services in
			// different packages don't overwrite each other's dependencies.
			sm[qualifiedName(ns, service.GetName())] = &serviceDefinition{
				desc:         service,
				dependencies: dependencies,
			}
		}
	}
	// Walk outward from the services, so we only write types that some RPC
	// reaches directly or transitively.
	collected := make(map[string]bool)
	seen := make(map[string]bool)
	for _, definition := range sm {
		for sd := range definition.dependencies {
			// These definitions are the request and response types.
			collected[sd] = true
			collectDependencies(sd, mm, collected, seen)
		}
	}
	return collected, mm, em
}
// collectDependencies adds to collected every type transitively referenced by
// the message named key. seen records visited keys, so cycles (such as
// self-referential messages) terminate. A key absent from mm (for example an
// enum) is marked seen but contributes no further dependencies.
func collectDependencies(key string, mm messageMap, collected map[string]bool, seen map[string]bool) {
	if _, visited := seen[key]; visited {
		return
	}
	seen[key] = true

	md, found := mm[key]
	if !found {
		return
	}
	for dep := range md.dependencies {
		collected[dep] = true
		collectDependencies(dep, mm, collected, seen)
	}
}
// filename maps a .proto path to the path of its generated TypeScript file.
// A trailing ".proto" extension is replaced with "_gw.ts"; any other name
// simply has "_gw.ts" appended.
func filename(name string) string {
	const protoExt = ".proto"
	stem := name
	if path.Ext(stem) == protoExt {
		stem = stem[:len(stem)-len(protoExt)]
	}
	return fmt.Sprint(stem, "_gw.ts")
}
func (g *generator) Generate() {
input, err := ioutil.ReadAll(g.reader)
if err != nil {
fatal(err, "Could not read input.")
}
request := g.Request
if err := proto.Unmarshal(input, request); err != nil {
fatal(err, "Could not parse input proto.")
}
if len(request.FileToGenerate) == 0 {
fatal(err, "No input files.")
}
g.CommandLineParameters(g.Request.GetParameter())
g.WrapTypes()
g.SetPackageNames()
g.BuildTypeNameMap()
g.GenerateAllFiles()
g.Reset()
response := new(plugin.CodeGeneratorResponse)
messages, mm, em := g.collectDefinitions(request)
for _, fd := range request.ProtoFile {
file, err := g.processFile(fd, messages, mm, em)
if err != nil {
fatal(err, fmt.Sprintf("Couldn't write file for %s", fd.GetName()))
}
if file != nil {
response.File = append(response.File, file)
}
}
output, err := proto.Marshal(response)
if err != nil {
fatal(err, "Couldn't marshal output proto")
}
_, err = g.writer.Write(output)
if err != nil {
fatal(err, "Couldn't write files")
}
}
// main runs the plugin over protoc's stdin/stdout protocol.
func main() {
	New().Generate()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment