Last active
August 29, 2015 13:59
-
-
Save bprosnitz/10744101 to your computer and use it in GitHub Desktop.
Debuggable Deep Equal
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package build | |
import ( | |
"fmt" | |
"reflect" | |
) | |
// visit is the key type of the visited set used by debuggableDeepValueEqual:
// a pair of addresses, stored with the smaller address first, together with
// the type of the values found at those addresses.
type visit struct {
	a1, a2 uintptr
	typ    reflect.Type
}
func debuggableDeepValueEqual(v1, v2 reflect.Value, visited map[visit]bool) (res bool) { | |
var msg = "" | |
defer func() { | |
if !res { | |
fmt.Printf("Not equal: %s %v %v\n", msg, v1, v2) | |
} | |
}() | |
if !v1.IsValid() || !v2.IsValid() { | |
msg = "validity" | |
res = v1.IsValid() == v2.IsValid() | |
return | |
} | |
if v1.Type() != v2.Type() { | |
msg = "type" | |
res = false | |
return | |
} | |
hard := func(k reflect.Kind) bool { | |
switch k { | |
case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct: | |
return true | |
} | |
return false | |
} | |
if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) { | |
addr1 := v1.UnsafeAddr() | |
addr2 := v2.UnsafeAddr() | |
if addr1 > addr2 { | |
// Canonicalize order to reduce number of entries in visited. | |
addr1, addr2 = addr2, addr1 | |
} | |
// Short circuit if references are identical ... | |
if addr1 == addr2 { | |
res = true | |
return | |
} | |
// ... or already seen | |
typ := v1.Type() | |
v := visit{addr1, addr2, typ} | |
if visited[v] { | |
res = true | |
} | |
// Remember for later. | |
visited[v] = true | |
} | |
switch v1.Kind() { | |
case reflect.Array: | |
if v1.Len() != v2.Len() { | |
msg = "array len" | |
res = false | |
return | |
} | |
for i := 0; i < v1.Len(); i++ { | |
if !debuggableDeepValueEqual(v1.Index(i), v2.Index(i), visited) { | |
msg = "array recurse" | |
res = false | |
return | |
} | |
} | |
return true | |
case reflect.Slice: | |
if v1.IsNil() != v2.IsNil() { | |
msg = "slice nility" | |
res = false | |
return | |
} | |
if v1.Len() != v2.Len() { | |
msg = "slice len" | |
res = false | |
return | |
} | |
if v1.Pointer() == v2.Pointer() { | |
res = true | |
return | |
} | |
for i := 0; i < v1.Len(); i++ { | |
if !debuggableDeepValueEqual(v1.Index(i), v2.Index(i), visited) { | |
msg = "slice recurse" | |
res = false | |
return | |
} | |
} | |
res = true | |
return | |
case reflect.Interface: | |
if v1.IsNil() || v2.IsNil() { | |
res = v1.IsNil() == v2.IsNil() | |
msg = "interface nility" | |
return | |
} | |
res = debuggableDeepValueEqual(v1.Elem(), v2.Elem(), visited) | |
msg = "interface recurse" | |
return | |
case reflect.Ptr: | |
res = debuggableDeepValueEqual(v1.Elem(), v2.Elem(), visited) | |
msg = "ptr recurse" | |
return | |
case reflect.Struct: | |
for i, n := 0, v1.NumField(); i < n; i++ { | |
if !debuggableDeepValueEqual(v1.Field(i), v2.Field(i), visited) { | |
msg = "struct recurse field '" + v1.Type().Field(i).Name + "'" | |
res = false | |
return | |
} | |
} | |
res = true | |
return | |
case reflect.Map: | |
if v1.IsNil() != v2.IsNil() { | |
msg = "map nility" | |
res = false | |
return | |
} | |
if v1.Len() != v2.Len() { | |
msg = "map len" | |
res = false | |
return | |
} | |
if v1.Pointer() == v2.Pointer() { | |
res = true | |
return | |
} | |
for _, k := range v1.MapKeys() { | |
if !debuggableDeepValueEqual(v1.MapIndex(k), v2.MapIndex(k), visited) { | |
msg = "map recurse key '" + k.String() + "'" | |
res = false | |
return | |
} | |
} | |
res = true | |
return | |
case reflect.Func: | |
if v1.IsNil() && v2.IsNil() { | |
res = true | |
return | |
} | |
// Can't do better than this: | |
msg = "func" | |
res = false | |
return | |
case reflect.String: | |
msg = "string" | |
res = v1.String() == v2.String() | |
return | |
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |
msg = "int" | |
res = v1.Int() == v2.Int() | |
return | |
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: | |
msg = "uint" | |
res = v1.Uint() == v2.Uint() | |
return | |
case reflect.Float32, reflect.Float64: | |
msg = "float" | |
res = v1.Float() == v2.Float() | |
return | |
case reflect.Complex64, reflect.Complex128: | |
msg = "complex" | |
res = v1.Complex() == v2.Complex() | |
return | |
case reflect.Bool: | |
msg = "bool" | |
res = v1.Bool() == v2.Bool() | |
return | |
default: | |
panic(fmt.Sprintf("Unhandled kind: %v", v1.Kind())) | |
} | |
} | |
// DebuggableDeepEqual has identical behavior to go's DeepEqual but includes debugging info. | |
func DebuggableDeepEqual(a1, a2 interface{}) bool { | |
if a1 == nil || a2 == nil { | |
return a1 == a2 | |
} | |
v1 := reflect.ValueOf(a1) | |
v2 := reflect.ValueOf(a2) | |
if v1.Type() != v2.Type() { | |
return false | |
} | |
fmt.Printf("Comparing %v %v\n", a1, a2) | |
return debuggableDeepValueEqual(v1, v2, make(map[visit]bool)) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment