Last active: September 20, 2018 12:36
-
-
Save hysios/7b5db0815ce094b9f1b194808befdf71 to your computer and use it in GitHub Desktop.
YAML 做定制解析功能 (custom YAML parsing: one field may hold either a string or a struct)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bytes" | |
"encoding/json" | |
"errors" | |
"fmt" | |
"log" | |
"os" | |
"reflect" | |
"strconv" | |
"github.com/ghodss/yaml" | |
"github.com/kr/pretty" | |
) | |
type (
	// Config is the root of the decoded configuration document.
	Config struct {
		Train Train
	}

	// Train holds the "train" section of the configuration.
	Train struct {
		Framework DefaultFramework
	}

	// DefaultFramework accepts a framework given either as a plain string
	// (stored in String) or as a full object (stored in Raw). Decoding is
	// provided by the embedded Any, whose targets must be registered (see
	// bindFramework) before the value is unmarshalled.
	DefaultFramework struct {
		Any
		String string
		Raw    Framework
	}

	// Framework is the fully specified form of a framework entry.
	Framework struct {
		Lang, Name, Version, Sense string
	}

	// Any is an either-or value: it holds a list of candidate targets (for
	// example a string and a struct), and decoding stores the incoming value
	// into whichever candidate matches the value's shape.
	Any struct {
		contains []reflect.Value
	}
)
func AnyOf(objs ...interface{}) Any { | |
var a Any | |
for _, obj := range objs { | |
a.contains = append(a.contains, reflect.ValueOf(obj)) | |
} | |
return a | |
} | |
func bindFramework(boundTo *DefaultFramework) { | |
boundTo.Any = AnyOf(&boundTo.String, &boundTo.Raw) | |
} | |
func (a *Any) UnmarshalJSON(b []byte) error { | |
fromAnon := reflect.TypeOf(a) | |
fmt.Printf("PkgPath: %s\n", fromAnon.PkgPath()) | |
fmt.Printf("anonymous: %v", fromAnon.Anonymous) | |
if scanString(b) { | |
if v, ok := a.findType(reflect.TypeOf("")); !ok { | |
return errors.New("don't have string Type") | |
} else { | |
var val string | |
json.Unmarshal(b, &val) | |
v.SetString(val) | |
} | |
} else if scanInt(b) { | |
if v, ok := a.findType(reflect.TypeOf(0)); !ok { | |
return errors.New("don't have int Type") | |
} else { | |
i, _ := strconv.Atoi(string(b)) | |
v.SetInt(int64(i)) | |
} | |
} else if scanObject(b) { | |
if v, ok := a.findStruct(); !ok { | |
return errors.New("don't have struct Type") | |
} else { | |
var ( | |
t = v.Elem().Type() | |
val = reflect.New(t) | |
) | |
if err := json.Unmarshal(b, val.Interface()); err != nil { | |
log.Printf("json.Unmarshal error %s", err) | |
} | |
v.Elem().Set(reflect.ValueOf(val.Elem().Interface())) | |
} | |
} | |
return nil | |
} | |
func (a *Any) findType(t reflect.Type) (reflect.Value, bool) { | |
for _, v := range a.contains { | |
if v.Kind() == reflect.Ptr { | |
v = v.Elem() | |
} | |
if v.Type() == t { | |
return v, true | |
} | |
} | |
return reflect.Value{}, false | |
} | |
func (a *Any) findStruct() (reflect.Value, bool) { | |
for _, v := range a.contains { | |
if reflect.Indirect(v).Kind() == reflect.Struct { | |
return v, true | |
} | |
} | |
return reflect.Value{}, false | |
} | |
func yaml2json(data []byte) []byte { | |
j2, err := yaml.YAMLToJSON(data) | |
if err != nil { | |
log.Fatalf("error: %v", err) | |
} | |
return j2 | |
} | |
// initializeStruct walks the struct value v (of type t) and assigns a freshly
// allocated value to every settable map, slice, channel and pointer field,
// overwriting whatever was there. It recurses into nested struct fields and
// into the structs that newly allocated pointers point to.
//
// v should be addressable (e.g. reflect.ValueOf(&x).Elem()); fields that
// cannot be set, such as unexported ones, are left untouched instead of
// panicking. Pointers to non-struct types are allocated without recursion,
// and a pointer whose element type is already being initialized further up
// the recursion is left nil, so self-referential types terminate.
func initializeStruct(t reflect.Type, v reflect.Value) {
	initializeStructVisiting(t, v, map[reflect.Type]bool{})
}

// initializeStructVisiting does the work of initializeStruct; visiting holds
// the struct types on the current recursion path, used for cycle detection.
func initializeStructVisiting(t reflect.Type, v reflect.Value, visiting map[reflect.Type]bool) {
	if t.Kind() != reflect.Struct {
		return
	}
	visiting[t] = true
	defer delete(visiting, t)

	for i := 0; i < v.NumField(); i++ {
		f := v.Field(i)
		ft := t.Field(i).Type

		// Nested structs are initialized in place; their own fields carry
		// the settability checks.
		if ft.Kind() == reflect.Struct {
			initializeStructVisiting(ft, f, visiting)
			continue
		}
		if !f.CanSet() {
			continue
		}
		switch ft.Kind() {
		case reflect.Map:
			f.Set(reflect.MakeMap(ft))
		case reflect.Slice:
			f.Set(reflect.MakeSlice(ft, 0, 0))
		case reflect.Chan:
			f.Set(reflect.MakeChan(ft, 0))
		case reflect.Ptr:
			elem := ft.Elem()
			if visiting[elem] {
				// Self-referential type: leave nil to avoid infinite recursion.
				continue
			}
			fv := reflect.New(elem)
			initializeStructVisiting(elem, fv.Elem(), visiting)
			f.Set(fv)
		}
	}
}
// Sample documents used by main. The "framework" key must be nested under
// "train" (two-space indent) so that it decodes into cfg.Train.Framework.
var (
	// data1 gives train.framework as a bare scalar. 1.8.0 is not a valid
	// YAML number, so it is expected to arrive as a JSON string and fill
	// DefaultFramework.String.
	data1 = []byte(`train:
  framework: 1.8.0
`)
	// data2 gives train.framework as a mapping, which fills
	// DefaultFramework.Raw.
	data2 = []byte(`train:
  framework:
    lang: python
    name: tensorflow
    version: 1.8.0
    sense: gpu
`)
)
// scanString reports whether b looks like a JSON string literal: at least two
// bytes long, beginning and ending with a double quote. Escapes in between are
// not validated. Empty input reports false instead of panicking.
func scanString(b []byte) bool {
	return len(b) >= 2 && b[0] == '"' && b[len(b)-1] == '"'
}
// scanObject reports whether b looks like a JSON object: at least two bytes
// long, beginning with '{' and ending with '}'. The contents are not
// validated. Empty input reports false instead of panicking.
func scanObject(b []byte) bool {
	return len(b) >= 2 && b[0] == '{' && b[len(b)-1] == '}'
}
// scanInt reports whether b starts like a numeric literal: a decimal digit,
// a sign ('-' or '+'), or a leading '.'. Only the first byte is examined, so
// floats such as 1.5 also match and callers must still parse the value.
// Empty input reports false instead of panicking.
func scanInt(b []byte) bool {
	if len(b) == 0 {
		return false
	}
	c := b[0]
	return (c >= '0' && c <= '9') || c == '-' || c == '+' || c == '.'
}
// scanBool reports whether b is exactly one of the JSON boolean literals
// "true" or "false". The original condition joined the two comparisons with
// && and therefore could never be true.
func scanBool(b []byte) bool {
	switch string(b) {
	case "true", "false":
		return true
	}
	return false
}
func build() *Config { | |
var cfg Config | |
bindFramework(&cfg.Train.Framework) | |
return &cfg | |
} | |
func main() { | |
var cfg *Config = build() | |
data := yaml2json(data1) | |
var out bytes.Buffer | |
json.Indent(&out, data, "", " ") | |
out.WriteTo(os.Stdout) | |
println() | |
if err := json.Unmarshal(data, &cfg); err != nil { | |
log.Fatalln(err) | |
} | |
log.Printf("String: %# v\n", pretty.Formatter(cfg.Train.Framework.String)) | |
cfg = build() | |
data = yaml2json(data2) | |
out.Reset() | |
json.Indent(&out, data, "", " ") | |
out.WriteTo(os.Stdout) | |
println() | |
if err := json.Unmarshal(data, &cfg); err != nil { | |
log.Fatalln(err) | |
} | |
log.Printf("framework: %# v\n", pretty.Formatter(cfg.Train.Framework.Raw)) | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.