/*
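This file implements a protoc-gen-gogo plugin. Given a message that enables the bitflags option (see the example protobuf below), it generates methods like the following so that the message can be used as a set of bit flags: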

func (this *User) UInt64() uint64 {
	b := uint64(0)
	if this.ScopeA {
		b |= uint64(1) << uint64(0)
	}
	if this.ScopeB {
		b |= uint64(1) << uint64(1)
	}
	if this.ScopeC {
		b |= uint64(1) << uint64(2)
	}

	return b
}
func (this *User) HighFlags() []string {
	var b []string
	if this.ScopeA {
		b = append(b, "scope_a")
	}
	if this.ScopeB {
		b = append(b, "scope_b")
	}
	if this.ScopeC {
		b = append(b, "scope_c")
	}
	return b
}

func (this *User) LowFlags() []string {
	var b []string
	if !this.ScopeA {
		b = append(b, "scope_a")
	}
	if !this.ScopeB {
		b = append(b, "scope_b")
	}
	if !this.ScopeC {
		b = append(b, "scope_c")
	}
	return b
}

func (this *User) SetFlag(flag string) error {
	switch flag {
	case "scope_a":
		this.ScopeA = true
	case "scope_b":
		this.ScopeB = true
	case "scope_c":
		this.ScopeC = true
	default:
		return fmt.Errorf("invalid flag: %v", flag)
	}
	return nil
}
func (this *User) ClearFlag(flag string) error {
	switch flag {
	case "scope_a":
		this.ScopeA = false
	case "scope_b":
		this.ScopeB = false
	case "scope_c":
		this.ScopeC = false
	default:
		return fmt.Errorf("invalid flag: %v", flag)
	}
	return nil
}
func (this *User) SetFlags(flags ...string) []error {
	var errs []error
	for _, f := range flags {
		if err := this.SetFlag(f); err != nil {
			errs = append(errs, err)
		}
	}
	return errs
}
func (this *User) ClearFlags(flags ...string) []error {
	var errs []error
	for _, f := range flags {
		if err := this.ClearFlag(f); err != nil {
			errs = append(errs, err)
		}
	}
	return errs
}
func (this *User) TestFlag(flag string) bool {
	switch flag {
	case "scope_a":
		return this.ScopeA
	case "scope_b":
		return this.ScopeB
	case "scope_c":
		return this.ScopeC
	}
	return false
}
func (this *User) TestFlags(flags ...string) bool {
	for _, f := range flags {
		if !this.TestFlag(f) {
			return false
		}
	}
	return true
}
func (this *User) Scan(i interface{}) error {
	switch v := i.(type) {
	case int:
		return this.FromUInt64(uint64(v))
	case int32:
		return this.FromUInt64(uint64(v))
	case float32:
		return this.FromUInt64(uint64(v))
	case float64:
		return this.FromUInt64(uint64(v))
	}

	return fmt.Errorf("invalid type: %T", i)
}
func (this *User) FromUInt64(b uint64) error {
	bb := b
	bb = b
	if bb&(uint64(1)<<uint(0)) > 0 {
		this.ScopeA = true
	} else {
		this.ScopeA = false
	}
	bb = b
	if bb&(uint64(1)<<uint(1)) > 0 {
		this.ScopeB = true
	} else {
		this.ScopeB = false
	}
	bb = b
	if bb&(uint64(1)<<uint(2)) > 0 {
		this.ScopeC = true
	} else {
		this.ScopeC = false
	}

	return nil
}
*/

// Example protobuf
/*
syntax = "proto3";

import "github.com/gogo/protobuf/gogoproto/gogo.proto";
import "github.com/bitflags/bitflag.proto";

package flavortown.flags;

message User {
    option (bitflagproto.bitflags) = true;
    bool scopeA = 1;
    bool scopeB = 2;
    bool scopeC = 3;
}
*/

package bitflags

import (
	"bytes"
	"unicode"

	"github.com/gogo/protobuf/protoc-gen-gogo/generator"
	"github.com/bitflags/protobuf/bitflagproto"
)

// plugin is the generator plugin that prints the bitflag helper methods. It
// embeds the generator stored by Init and the import tracker created at the
// start of Generate.
type plugin struct {
	*generator.Generator
	generator.PluginImports
	// messages holds the descriptors of the bitflag messages found by the
	// most recent Generate call; Generate resets it on entry.
	messages []*generator.Descriptor
}

// NewBitflags returns a zero-valued bitflags plugin. The embedded generator
// is not set here; it is supplied later through Init.
func NewBitflags() *plugin {
	return new(plugin)
}

// Name returns the name of this plugin, "bitflags".
func (p *plugin) Name() string {
	const name = "bitflags"
	return name
}

// Init stores g as the plugin's embedded generator, which Generate later uses
// for all of its output (P, In, Out, GetFieldName).
func (p *plugin) Init(g *generator.Generator) {
	p.Generator = g
}

// Generate prints the bitflag helper methods for every message in file for
// which bitflagproto.IsBitflags reports true. All other messages are skipped.
//
// For each selected message the following methods are emitted, all on a
// pointer receiver named "this":
//
//	UInt64, FromUInt64, Scan    conversion to and from a uint64 bit mask
//	HighFlags, LowFlags         names of the fields that are true / false
//	SetFlag, ClearFlag          assign one field by flag name
//	SetFlags, ClearFlags        assign several fields, collecting errors
//	TestFlag, TestFlags         read one field / AND several fields
//
// Bit i of the mask corresponds to the i-th entry of message.Field (the index
// in the descriptor, not the protobuf field number). Flag names are the
// snake_case form of the Go field name. Every field is used as a bool in the
// emitted code and nothing here checks the field type, so a bitflags message
// is expected to contain only bool fields.
//
// The emitted code calls fmt.Errorf; this plugin does not register that
// import itself.
// NOTE(review): presumably the generator's standard preamble imports fmt —
// confirm.
func (p *plugin) Generate(file *generator.FileDescriptor) {
	p.PluginImports = generator.NewPluginImports(p.Generator)
	p.messages = make([]*generator.Descriptor, 0)

	for _, message := range file.Messages() {
		if !bitflagproto.IsBitflags(file.FileDescriptorProto, message.DescriptorProto) {
			continue
		}
		p.messages = append(p.messages, message)
		baseTypeName := generator.CamelCaseSlice(message.TypeName())

		// Resolve the Go field name and flag name once per field instead of
		// once per emitted method.
		flags := make([]bitflag, 0, len(message.Field))
		for _, field := range message.Field {
			fieldname := p.GetFieldName(message, field)
			flags = append(flags, bitflag{field: fieldname, name: snakeCase(fieldname)})
		}

		p.genUInt64(baseTypeName, flags)
		p.genFlagList(baseTypeName, `HighFlags`, `if this.`, flags)
		p.genFlagList(baseTypeName, `LowFlags`, `if !this.`, flags)
		p.genAssignFlag(baseTypeName, `SetFlag`, `true`, flags)
		p.genAssignFlag(baseTypeName, `ClearFlag`, `false`, flags)
		p.genAssignFlags(baseTypeName, `SetFlags`, `SetFlag`)
		p.genAssignFlags(baseTypeName, `ClearFlags`, `ClearFlag`)
		p.genTestFlag(baseTypeName, flags)
		p.genTestFlags(baseTypeName)
		p.genScan(baseTypeName)
		p.genFromUInt64(baseTypeName, flags)
	}
}

// bitflag pairs the Go name of a message field with the flag name exposed by
// the generated string-based methods.
type bitflag struct {
	field string // Go struct field name, as returned by GetFieldName
	name  string // snake_case flag name derived from field
}

// genUInt64 emits UInt64, which ORs bit i into the result when the i-th field
// is true.
func (p *plugin) genUInt64(recv string, flags []bitflag) {
	p.P(`func (this *`, recv, `) UInt64() uint64 {`)
	p.In()
	p.P(`b := uint64(0)`)
	for bit, f := range flags {
		p.P(`if this.`, f.field, ` {`)
		p.In()
		p.P(`b |= uint64(1) << uint64(`, bit, `)`)
		p.Out()
		p.P(`}`)
	}
	p.P()
	p.P(`return b`)
	p.Out()
	p.P(`}`)
}

// genFlagList emits a method returning the flag names of all fields matching
// a condition. cond is the text placed directly before the field name:
// `if this.` lists true fields (HighFlags), `if !this.` lists false fields
// (LowFlags).
func (p *plugin) genFlagList(recv, method, cond string, flags []bitflag) {
	p.P(`func (this *`, recv, `) `, method, `() []string {`)
	p.In()
	p.P(`var b []string`)
	for _, f := range flags {
		p.P(cond, f.field, ` {`)
		p.In()
		p.P(`b = append(b, "`, f.name, `")`)
		p.Out()
		p.P(`}`)
	}
	p.P(`return b`)
	p.Out()
	p.P(`}`)
	p.P()
}

// genAssignFlag emits SetFlag or ClearFlag: a switch over the flag name that
// assigns value ("true" or "false") to the matching field and returns an
// error for an unknown name.
func (p *plugin) genAssignFlag(recv, method, value string, flags []bitflag) {
	p.P(`func (this *`, recv, `) `, method, `(flag string) error {`)
	p.In()
	p.P(`switch flag {`)
	for _, f := range flags {
		p.P(`case "`, f.name, `":`)
		p.In()
		p.P(`this.`, f.field, ` = `, value)
		p.Out()
	}
	p.P(`default:`)
	p.In()
	p.P(`return fmt.Errorf("invalid flag: %v", flag)`)
	p.Out()
	p.P(`}`)
	p.P(`return nil`)
	p.Out()
	p.P(`}`)
}

// genAssignFlags emits SetFlags or ClearFlags: a loop that applies the
// single-flag method named by single to every argument and collects the
// errors it returns.
func (p *plugin) genAssignFlags(recv, method, single string) {
	p.P(`func (this *`, recv, `) `, method, `(flags ...string) []error {`)
	p.In()
	p.P(`var errs []error`)
	p.P(`for _, f := range flags {`)
	p.In()
	p.P(`if err := this.`, single, `(f); err != nil {`)
	p.In()
	p.P(`errs = append(errs, err)`)
	p.Out()
	p.P(`}`)
	p.Out()
	p.P(`}`)
	p.P(`return errs`)
	p.Out()
	p.P(`}`)
}

// genTestFlag emits TestFlag, which returns the value of the named field, or
// false when the name matches no field.
func (p *plugin) genTestFlag(recv string, flags []bitflag) {
	p.P(`func (this *`, recv, `) TestFlag(flag string) bool {`)
	p.In()
	p.P(`switch flag {`)
	for _, f := range flags {
		p.P(`case "`, f.name, `":`)
		p.In()
		p.P(`return this.`, f.field)
		p.Out()
	}
	p.P(`}`)
	p.P(`return false`)
	p.Out()
	p.P(`}`)
}

// genTestFlags emits TestFlags, which returns true only when TestFlag is true
// for every argument (Flag1 AND Flag2 AND ...).
func (p *plugin) genTestFlags(recv string) {
	p.P(`func (this *`, recv, `) TestFlags(flags ...string) bool {`)
	p.In()
	p.P(`for _, f := range flags {`)
	p.In()
	p.P(`if !this.TestFlag(f) {`)
	p.In()
	p.P(`return false`)
	p.Out()
	p.P(`}`)
	p.Out()
	p.P(`}`)
	p.P(`return true`)
	p.Out()
	p.P(`}`)
}

// genScan emits Scan, which converts a numeric value to uint64 and loads it
// through FromUInt64, returning an error for any other dynamic type.
func (p *plugin) genScan(recv string) {
	p.P(`func (this *`, recv, `) Scan(i interface{}) error {`)
	p.In()
	p.P(`switch v := i.(type) {`)
	// int64 is the integer type database/sql hands to a Scanner; without it
	// scanning an integer column always ends in the "invalid type" error.
	for _, t := range []string{"int", "int32", "int64", "float32", "float64"} {
		p.P(`case `, t, `:`)
		p.In()
		p.P(`return this.FromUInt64(uint64(v))`)
		p.Out()
	}
	p.P(`}`)
	p.P()
	p.P(`return fmt.Errorf("invalid type: %T", i)`)
	p.Out()
	p.P(`}`)
}

// genFromUInt64 emits FromUInt64, which sets the i-th field to whether bit i
// of the argument is set. The emitted method currently always returns nil.
// TODO(dan) should return error if overflow
func (p *plugin) genFromUInt64(recv string, flags []bitflag) {
	p.P(`func (this *`, recv, `) FromUInt64(b uint64) error {`)
	p.In()
	// With no fields the scratch variable would be declared but never used,
	// which makes the generated file fail to compile, so only declare it
	// when at least one field reads it.
	if len(flags) > 0 {
		p.P(`bb := b`)
	}
	for bit, f := range flags {
		p.P(`bb = b`)
		p.P(`if bb&(uint64(1)<<uint(`, bit, `)) > 0 {`)
		p.In()
		p.P(`this.`, f.field, ` = true`)
		p.Out()
		p.P(`} else {`)
		p.In()
		p.P(`this.`, f.field, ` = false`)
		p.Out()
		p.P(`}`)
	}
	p.P()
	p.P(`return nil`)
	p.Out()
	p.P(`}`)
}

// snakeCase converts a CamelCase identifier to lower-case snake_case.
//
// An underscore is written before an upper-case rune (other than the first
// rune) when the rune after it is lower-case or the rune before it is
// lower-case. Runs of capitals therefore stay together: "HTTPServer" becomes
// "http_server" and "UserID" becomes "user_id". Every rune is lower-cased.
func snakeCase(in string) string {
	chars := []rune(in)
	last := len(chars) - 1

	var buf bytes.Buffer
	buf.Grow(len(chars))

	for idx, r := range chars {
		if idx > 0 && unicode.IsUpper(r) {
			// Word boundary: either a new word starts here ("rS" in
			// "ServerScope") or an acronym ends here ("PS" in "HTTPServer").
			startsWord := idx < last && unicode.IsLower(chars[idx+1])
			followsLower := unicode.IsLower(chars[idx-1])
			if startsWord || followsLower {
				buf.WriteByte('_')
			}
		}
		buf.WriteRune(unicode.ToLower(r))
	}

	return buf.String()
}
func init() {
	generator.RegisterPlugin(NewBitflags())
}