Created
December 2, 2019 17:16
-
-
Save dancompton/908416a9f4fe2c38f688f25355b770d1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package bitflags | |
import ( | |
"bytes" | |
"unicode" | |
"github.com/gogo/protobuf/protoc-gen-gogo/generator" | |
"github.com/opsee/protobuf/opseeproto" | |
) | |
type plugin struct { | |
*generator.Generator | |
generator.PluginImports | |
messages []*generator.Descriptor | |
} | |
func NewBitflags() *plugin { | |
return &plugin{} | |
} | |
func (p *plugin) Name() string { | |
return "bitflags" | |
} | |
func (p *plugin) Init(g *generator.Generator) { | |
p.Generator = g | |
} | |
func (p *plugin) Generate(file *generator.FileDescriptor) { | |
p.PluginImports = generator.NewPluginImports(p.Generator) | |
p.messages = make([]*generator.Descriptor, 0) | |
driverPkg := p.NewImport("database/sql/driver") | |
for _, message := range file.Messages() { | |
if !opseeproto.IsBitflags(file.FileDescriptorProto, message.DescriptorProto) { | |
continue | |
} | |
p.messages = append(p.messages, message) | |
baseTypeName := generator.CamelCaseSlice(message.TypeName()) | |
// UInt64() | |
// returns a bitflags uint64 representation of the structure | |
p.P(`func (this *`, baseTypeName, `) UInt64() uint64 {`) | |
p.In() | |
p.P(`b := uint64(0)`) | |
for bit, field := range message.Field { | |
fieldname := p.GetFieldName(message, field) | |
p.P(`if this.`, fieldname, ` {`) | |
p.In() | |
p.P(`b |= uint64(1) << uint64(`, bit, `)`) | |
p.Out() | |
p.P(`}`) | |
} | |
p.P() | |
p.P(`return b`) | |
p.P(`}`) | |
// HighFlags() returns fields in struct set to 1 | |
p.P(`func (this *`, baseTypeName, `) HighFlags() []string {`) | |
p.In() | |
p.P(`var b []string`) | |
for _, field := range message.Field { | |
fieldname := p.GetFieldName(message, field) | |
p.P(`if this.`, fieldname, ` {`) | |
p.In() | |
p.P(`b = append(b, "`, snakeCase(fieldname), `")`) | |
p.Out() | |
p.P(`}`) | |
} | |
p.P(`return b`) | |
p.P(`}`) | |
p.P() | |
// LowFlags() returns fields in struct set to 0 | |
p.P(`func (this *`, baseTypeName, `) LowFlags() []string {`) | |
p.In() | |
p.P(`var b []string`) | |
for _, field := range message.Field { | |
fieldname := p.GetFieldName(message, field) | |
p.P(`if !this.`, fieldname, ` {`) | |
p.In() | |
p.P(`b = append(b, "`, snakeCase(fieldname), `")`) | |
p.Out() | |
p.P(`}`) | |
} | |
p.P(`return b`) | |
p.P(`}`) | |
p.P() | |
// Sets a flag or returns error | |
p.P(`func (this *`, baseTypeName, `) SetFlag(flag string) error {`) | |
p.In() | |
p.P(`switch flag {`) | |
p.In() | |
for _, field := range message.Field { | |
fieldname := p.GetFieldName(message, field) | |
p.P(`case "`, snakeCase(fieldname), `":`) | |
p.In() | |
p.P(`this.`, fieldname, `= true `) | |
p.Out() | |
} | |
p.P(`default:`) | |
p.In() | |
p.P(`return fmt.Errorf("invalid flag: %v", flag)`) | |
p.Out() | |
p.P(`}`) | |
p.Out() | |
p.P(`return nil`) | |
p.P(`}`) | |
// Sets a flag or returns error | |
p.P(`func (this *`, baseTypeName, `) ClearFlag(flag string) error {`) | |
p.In() | |
p.P(`switch flag {`) | |
p.In() | |
for _, field := range message.Field { | |
fieldname := p.GetFieldName(message, field) | |
p.P(`case "`, snakeCase(fieldname), `":`) | |
p.In() | |
p.P(`this.`, fieldname, `= false `) | |
p.Out() | |
} | |
p.P(`default:`) | |
p.In() | |
p.P(`return fmt.Errorf("invalid flag: %v", flag)`) | |
p.Out() | |
p.P(`}`) | |
p.Out() | |
p.P(`return nil`) | |
p.P(`}`) | |
// SetFlags(...string) []error | |
// sets a number of flags which correspond to fields in the struct | |
p.P(`func (this *`, baseTypeName, `) SetFlags(flags ...string) []error {`) | |
p.In() | |
p.P(`var errs []error`) | |
p.P(`for _, f := range flags {`) | |
p.In() | |
p.P(`if err := this.SetFlag(f); err != nil {`) | |
p.In() | |
p.P(`errs = append(errs, err)`) | |
p.Out() | |
p.P(`}`) | |
p.Out() | |
p.P(`}`) | |
p.P(`return errs`) | |
p.Out() | |
p.P(`}`) | |
// ClearFlags(...string) []error | |
// sets a number of flags which correspond to fields in the struct | |
p.P(`func (this *`, baseTypeName, `) ClearFlags(flags ...string) []error {`) | |
p.In() | |
p.P(`var errs []error`) | |
p.P(`for _, f := range flags {`) | |
p.In() | |
p.P(`if err := this.ClearFlag(f); err != nil {`) | |
p.In() | |
p.P(`errs = append(errs, err)`) | |
p.Out() | |
p.P(`}`) | |
p.Out() | |
p.P(`}`) | |
p.P(`return errs`) | |
p.Out() | |
p.P(`}`) | |
// Returns the value of a flag or false if the flag does not exist | |
p.P(`func (this *`, baseTypeName, `) TestFlag(flag string) bool {`) | |
p.In() | |
p.P(`switch flag {`) | |
p.In() | |
for _, field := range message.Field { | |
fieldname := p.GetFieldName(message, field) | |
p.P(`case "`, snakeCase(fieldname), `":`) | |
p.In() | |
p.P(`return this.`, fieldname) | |
p.Out() | |
} | |
p.P(`}`) | |
p.Out() | |
p.P(`return false`) | |
p.P(`}`) | |
// TestFlags(...string) []error | |
// Returns Flag1 AND Flag2 AND ... | |
p.P(`func (this *`, baseTypeName, `) TestFlags(flags ...string) bool {`) | |
p.In() | |
p.P(`for _, f := range flags {`) | |
p.In() | |
p.P(`if !this.TestFlag(f) {`) | |
p.In() | |
p.P(`return false`) | |
p.Out() | |
p.P(`}`) | |
p.Out() | |
p.P(`}`) | |
p.P(`return true`) | |
p.Out() | |
p.P(`}`) | |
// FromInt64(b uint64) | |
// TODO(dan) should return error if overflow | |
p.P(`func (this *`, baseTypeName, `) FromUInt64(b uint64) error {`) | |
p.In() | |
p.P(`bb := b`) | |
for i, field := range message.Field { | |
p.P(`bb = b`) | |
fieldname := p.GetFieldName(message, field) | |
p.P(`if bb&(uint64(1)<<uint(`, i, `)) > 0 {`) | |
p.In() | |
p.P(`this.`, fieldname, ` = true`) | |
p.Out() | |
p.P(`} else {`) | |
p.In() | |
p.P(`this.`, fieldname, ` = false`) | |
p.P(`}`) | |
p.Out() | |
} | |
p.P() | |
p.P(`return nil`) | |
p.P(`}`) | |
// returns a bitflags uint64 representation of the structure from database | |
p.P(`func (this *`, baseTypeName, `) Scan(i interface{}) error {`) | |
p.In() | |
p.P(`switch v := i.(type) {`) | |
types := []string{"int", "int32", "int64", "float32", "float64"} | |
for _, t := range types { | |
p.P(`case `, t, `:`) | |
p.In() | |
p.P(`return this.FromUInt64(uint64(v))`) | |
p.Out() | |
} | |
p.P(`}`) | |
p.P() | |
p.P(`return fmt.Errorf("invalid type: %T", i)`) | |
p.P(`}`) | |
// Represent in the database as int64() | |
// allows the type to be encoded by database | |
p.P(`func (this *`, baseTypeName, `) Value() (`, driverPkg.Use(), `.Value, error) {`) | |
p.In() | |
p.P(`return int64(this.UInt64()), nil`) | |
p.P(`}`) | |
} | |
} | |
func snakeCase(in string) string { | |
runes := []rune(in) | |
length := len(runes) | |
out := bytes.NewBuffer(make([]byte, 0, length)) | |
for i := 0; i < length; i++ { | |
if i > 0 && unicode.IsUpper(runes[i]) && ((i+1 < length && unicode.IsLower(runes[i+1])) || unicode.IsLower(runes[i-1])) { | |
out.WriteRune('_') | |
} | |
out.WriteRune(unicode.ToLower(runes[i])) | |
} | |
return out.String() | |
} | |
func init() { | |
generator.RegisterPlugin(NewBitflags()) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment