/* Given the protobuf below, generates the following funcs so that the message can be used as bitflags: func (this *User) UInt64() uint64 { b := uint64(0) if this.ScopeA { b |= uint64(1) << uint64(0) } if this.ScopeB { b |= uint64(1) << uint64(1) } if this.ScopeC { b |= uint64(1) << uint64(2) } return b } func (this *User) HighFlags() []string { var b []string if this.ScopeA { b = append(b, "scope_a") } if this.ScopeB { b = append(b, "scope_b") } if this.ScopeC { b = append(b, "scope_c") } return b } func (this *User) LowFlags() []string { var b []string if !this.ScopeA { b = append(b, "scope_a") } if !this.ScopeB { b = append(b, "scope_b") } if !this.ScopeC { b = append(b, "scope_c") } return b } func (this *User) SetFlag(flag string) error { switch flag { case "scope_a": this.ScopeA = true case "scope_b": this.ScopeB = true case "scope_c": this.ScopeC = true default: return fmt.Errorf("invalid flag: %v", flag) } return nil } func (this *User) ClearFlag(flag string) error { switch flag { case "scope_a": this.ScopeA = false case "scope_b": this.ScopeB = false case "scope_c": this.ScopeC = false default: return fmt.Errorf("invalid flag: %v", flag) } return nil } func (this *User) SetFlags(flags ...string) []error { var errs []error for _, f := range flags { if err := this.SetFlag(f); err != nil { errs = append(errs, err) } } return errs } func (this *User) ClearFlags(flags ...string) []error { var errs []error for _, f := range flags { if err := this.ClearFlag(f); err != nil { errs = append(errs, err) } } return errs } func (this *User) TestFlag(flag string) bool { switch flag { case "scope_a": return this.ScopeA case "scope_b": return this.ScopeB case "scope_c": return this.ScopeC } return false } func (this *User) TestFlags(flags ...string) bool { for _, f := range flags { if !this.TestFlag(f) { return false } } return true } func (this *User) Scan(i interface{}) error { switch v := i.(type) { case int: return this.FromUInt64(uint64(v)) case int32: return this.FromUInt64(uint64(v)) case float32: return this.FromUInt64(uint64(v)) case float64: return this.FromUInt64(uint64(v)) } return fmt.Errorf("invalid type: %T", i) } func (this *User) FromUInt64(b uint64) error { bb := b bb = b if bb&(uint64(1)<<uint(0)) > 0 { this.ScopeA = true } else { this.ScopeA = false } bb = b if bb&(uint64(1)<<uint(1)) > 0 { this.ScopeB = true } else { this.ScopeB = false } bb = b if bb&(uint64(1)<<uint(2)) > 0 { this.ScopeC = true } else { this.ScopeC = false } return nil } */ // Example protobuf /* syntax = "proto3"; import "github.com/gogo/protobuf/gogoproto/gogo.proto"; import "github.com/bitflags/bitflag.proto"; package flavortown.flags; message User { option (bitflagproto.bitflags) = true; bool scopeA = 1; bool scopeB = 2; bool scopeC = 3; } */ package bitflags import ( "bytes" "unicode" "github.com/gogo/protobuf/protoc-gen-gogo/generator" "github.com/bitflags/protobuf/bitflagproto" ) type plugin struct { *generator.Generator generator.PluginImports messages []*generator.Descriptor } func NewBitflags() *plugin { return &plugin{} } func (p *plugin) Name() string { return "bitflags" } func (p *plugin) Init(g *generator.Generator) { p.Generator = g } func (p *plugin) Generate(file *generator.FileDescriptor) { p.PluginImports = generator.NewPluginImports(p.Generator) p.messages = make([]*generator.Descriptor, 0) for _, message := range file.Messages() { if !bitflagproto.IsBitflags(file.FileDescriptorProto, message.DescriptorProto) { continue } p.messages = append(p.messages, message) baseTypeName := generator.CamelCaseSlice(message.TypeName()) // UInt64() // returns a bitflags uint64 representation of the structure p.P(`func (this *`, baseTypeName, `) UInt64() uint64 {`) p.In() p.P(`b := uint64(0)`) for bit, field := range message.Field { fieldname := p.GetFieldName(message, field) p.P(`if this.`, fieldname, ` {`) p.In() p.P(`b |= uint64(1) << uint64(`, bit, `)`) p.Out() p.P(`}`) } p.P() p.P(`return b`) p.P(`}`) // HighFlags() returns fields in struct set to 1 p.P(`func (this *`, baseTypeName, `) HighFlags() []string {`) p.In() p.P(`var b []string`) for _, field := range message.Field { fieldname := p.GetFieldName(message, field) p.P(`if this.`, fieldname, ` {`) p.In() p.P(`b = append(b, "`, snakeCase(fieldname), `")`) p.Out() p.P(`}`) } p.P(`return b`) p.P(`}`) p.P() // LowFlags() returns fields in struct set to 0 p.P(`func (this *`, baseTypeName, `) LowFlags() []string {`) p.In() p.P(`var b []string`) for _, field := range message.Field { fieldname := p.GetFieldName(message, field) p.P(`if !this.`, fieldname, ` {`) p.In() p.P(`b = append(b, "`, snakeCase(fieldname), `")`) p.Out() p.P(`}`) } p.P(`return b`) p.P(`}`) p.P() // Sets a flag or returns error p.P(`func (this *`, baseTypeName, `) SetFlag(flag string) error {`) p.In() p.P(`switch flag {`) p.In() for _, field := range message.Field { fieldname := p.GetFieldName(message, field) p.P(`case "`, snakeCase(fieldname), `":`) p.In() p.P(`this.`, fieldname, `= true `) p.Out() } p.P(`default:`) p.In() p.P(`return fmt.Errorf("invalid flag: %v", flag)`) p.Out() p.P(`}`) p.Out() p.P(`return nil`) p.P(`}`) // Sets a flag or returns error p.P(`func (this *`, baseTypeName, `) ClearFlag(flag string) error {`) p.In() p.P(`switch flag {`) p.In() for _, field := range message.Field { fieldname := p.GetFieldName(message, field) p.P(`case "`, snakeCase(fieldname), `":`) p.In() p.P(`this.`, fieldname, `= false `) p.Out() } p.P(`default:`) p.In() p.P(`return fmt.Errorf("invalid flag: %v", flag)`) p.Out() p.P(`}`) p.Out() p.P(`return nil`) p.P(`}`) // SetFlags(...string) []error // sets a number of flags which correspond to fields in the struct p.P(`func (this *`, baseTypeName, `) SetFlags(flags ...string) []error {`) p.In() p.P(`var errs []error`) p.P(`for _, f := range flags {`) p.In() p.P(`if err := this.SetFlag(f); err != nil {`) p.In() p.P(`errs = append(errs, err)`) p.Out() p.P(`}`) p.Out() p.P(`}`) p.P(`return errs`) p.Out() p.P(`}`) // ClearFlags(...string) []error // sets a number of flags which correspond to fields in the struct p.P(`func (this *`, baseTypeName, `) ClearFlags(flags ...string) []error {`) p.In() p.P(`var errs []error`) p.P(`for _, f := range flags {`) p.In() p.P(`if err := this.ClearFlag(f); err != nil {`) p.In() p.P(`errs = append(errs, err)`) p.Out() p.P(`}`) p.Out() p.P(`}`) p.P(`return errs`) p.Out() p.P(`}`) // Returns the value of a flag or false if the flag does not exist p.P(`func (this *`, baseTypeName, `) TestFlag(flag string) bool {`) p.In() p.P(`switch flag {`) p.In() for _, field := range message.Field { fieldname := p.GetFieldName(message, field) p.P(`case "`, snakeCase(fieldname), `":`) p.In() p.P(`return this.`, fieldname) p.Out() } p.P(`}`) p.Out() p.P(`return false`) p.P(`}`) // TestFlags(...string) []error // Returns Flag1 AND Flag2 AND ... p.P(`func (this *`, baseTypeName, `) TestFlags(flags ...string) bool {`) p.In() p.P(`for _, f := range flags {`) p.In() p.P(`if !this.TestFlag(f) {`) p.In() p.P(`return false`) p.Out() p.P(`}`) p.Out() p.P(`}`) p.P(`return true`) p.Out() p.P(`}`) // Represent in the database as int64() // returns a bitflags uint64 representation of the structure p.P(`func (this *`, baseTypeName, `) Scan(i interface{}) error {`) p.In() p.P(`switch v := i.(type) {`) types := []string{"int", "int32", "float32", "float64"} for _, t := range types { p.P(`case `, t, `:`) p.In() p.P(`return this.FromUInt64(uint64(v))`) p.Out() } p.P(`}`) p.P() p.P(`return fmt.Errorf("invalid type: %T", i)`) p.P(`}`) // FromInt64(b uint64) // TODO(dan) should return error if overflow p.P(`func (this *`, baseTypeName, `) FromUInt64(b uint64) error {`) p.In() p.P(`bb := b`) for i, field := range message.Field { p.P(`bb = b`) fieldname := p.GetFieldName(message, field) p.P(`if bb&(uint64(1)<<uint(`, i, `)) > 0 {`) p.In() p.P(`this.`, fieldname, ` = true`) p.Out() p.P(`} else {`) p.In() p.P(`this.`, fieldname, ` = false`) p.P(`}`) p.Out() } p.P() p.P(`return nil`) p.P(`}`) } } func snakeCase(in string) string { runes := []rune(in) length := len(runes) out := bytes.NewBuffer(make([]byte, 0, length)) for i := 0; i < length; i++ { if i > 0 && unicode.IsUpper(runes[i]) && ((i+1 < length && unicode.IsLower(runes[i+1])) || unicode.IsLower(runes[i-1])) { out.WriteRune('_') } out.WriteRune(unicode.ToLower(runes[i])) } return out.String() } func init() { generator.RegisterPlugin(NewBitflags()) }