Created
December 3, 2013 11:34
-
-
Save jaekwon/7767747 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// RowScanner is the minimal result-row abstraction used by this file: anything
// that can copy the columns of the current row into the supplied destinations.
type RowScanner interface {
	// Scan copies the current row's columns into dest and reports any failure.
	Scan(dest ...interface{}) error
}
// ModelInfo holds the reflection metadata and pre-rendered SQL fragments for
// one struct type whose fields are mapped to columns through `db` struct tags.
type ModelInfo struct {
	// Type is the struct type this metadata describes.
	Type reflect.Type
	// TableName is the lower-cased name of Type.
	TableName string
	// Fields lists the struct fields that carry a non-empty `db` tag, in
	// declaration order.
	Fields []*reflect.StructField
	// FieldsSimple is the comma-separated list of column names ("id, email").
	FieldsSimple string
	// FieldsPrefixed is the same list with every column qualified by
	// TableName ("user.id, user.email").
	FieldsPrefixed string
	// Placeholders is one "?" per column, comma-separated ("?, ?").
	Placeholders string
}
var allModelInfos = map[string]*ModelInfo{} | |
func (m *ModelInfo) FieldValues(i interface{}) []interface{} { | |
v := reflect.ValueOf(i) | |
if v.Kind() == reflect.Ptr { | |
v = v.Elem() | |
} | |
if v.Type() != m.Type { | |
log.Panicf("Invalid argument for FieldValues: Type mismatch. Expected %v but got %v", | |
v.Type(), m.Type) | |
} | |
fvs := []interface{}{} | |
for _, field := range m.Fields { | |
name := field.Name | |
fieldValue := v.FieldByName(name) | |
fvs = append(fvs, fieldValue.Interface()) | |
} | |
return fvs | |
} | |
func GetModelInfo(i interface{}) *ModelInfo { | |
t := reflect.TypeOf(i) | |
return GetModelInfoFromType(t) | |
} | |
func GetModelInfoFromType(modelType reflect.Type) *ModelInfo { | |
if modelType.Kind() == reflect.Ptr { | |
modelType = modelType.Elem() | |
} | |
if modelType.Kind() != reflect.Struct { | |
return nil | |
} | |
modelName := modelType.Name() | |
// Check cache | |
if allModelInfos[modelName] != nil { | |
return allModelInfos[modelName] | |
} | |
// Construct | |
m := &ModelInfo{} | |
allModelInfos[modelName] = m | |
m.Type = modelType | |
m.TableName = strings.ToLower(modelName) | |
// Fields | |
numFields := m.Type.NumField() | |
for i:=0; i<numFields; i++ { | |
field := m.Type.Field(i) | |
if field.Tag.Get("db") != "" { | |
m.Fields = append(m.Fields, &field) | |
} | |
} | |
// Simple & Prefixed | |
fieldNames := []string{} | |
ph := []string{} | |
for _, field := range m.Fields { | |
fieldNames = append(fieldNames, field.Tag.Get("db")) | |
ph = append(ph, "?") | |
} | |
m.FieldsSimple = strings.Join(fieldNames, ", ") | |
m.FieldsPrefixed = m.TableName+"."+strings.Join(fieldNames, ", "+m.TableName+".") | |
m.Placeholders = strings.Join(ph, ", ") | |
return m | |
} | |
func expandArgs(args []interface{}) []interface{} { | |
a := []interface{}{} | |
for _, arg := range args { | |
modelInfo := GetModelInfo(arg) | |
if modelInfo == nil { | |
a = append(a, arg) | |
} else { | |
a = append(a, modelInfo.FieldValues(arg)...) | |
} | |
} | |
return a | |
} | |
func Exec(query string, args ...interface{}) (sql.Result, error) { | |
return GetDB().Exec(query, expandArgs(args)...) | |
} | |
func QueryRow(query string, args ...interface{}) RowScanner { | |
return &StructScanner{GetDB().QueryRow(query, expandArgs(args)...)} | |
} | |
func Query(query string, args ...interface{}) (RowScanner, error) { | |
rows, err := GetDB().Query(query, expandArgs(args)...) | |
if err != nil { return nil, err } | |
return &StructScanner{rows}, nil | |
} | |
type StructScanner struct { | |
Scanner RowScanner | |
} | |
func (s *StructScanner) Scan(dest ...interface{}) error { | |
destValuesP := []interface{}{} | |
for _, d := range dest { | |
dValueP := reflect.ValueOf(d) | |
dValue := dValueP.Elem() | |
if dValue.Kind() != reflect.Struct { | |
destValuesP = append(destValuesP, dValueP.Addr().Interface()) | |
} else { | |
m := GetModelInfoFromType(dValue.Type()) | |
for _, field := range m.Fields { | |
dField := dValue.FieldByName(field.Name) | |
destValuesP = append(destValuesP, dField.Addr().Interface()) | |
} | |
} | |
} | |
return s.Scanner.Scan(destValuesP...) | |
} | |
////////////// USAGE
// User is an example model: each field's `db` tag names the column it maps to.
type User struct {
	Id    string `db:"id"`
	Email string `db:"email"`
}
var UserModel = GetModelIfno(new(User)) | |
func test() { | |
// inserting a struct | |
db.Exec(`INSERT INTO user(`+UserModel.FieldsSimple+`) VALUES (`+UserModel.Placeholders+`)`, user) | |
// loading a struct | |
var user User | |
db.QueryRow(`SELECT `+UserModel.FieldsSimple+` FROM user WHERE email=?`, email).Scan(&user) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment