Created
June 28, 2022 10:26
-
-
Save ulexxander/bf20ef508c2177ccef0de5b5613fa4ee to your computer and use it in GitHub Desktop.
Go workaround for decimal.Decimal equality assertions in tests: UnscaleDecimals.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package requirex | |
import ( | |
"testing" | |
"github.com/stretchr/testify/require" | |
) | |
// Equal wraps require.Equal and adds unscaling of decimal.Decimal values. | |
// It prevents tests from failing if expected and actual decimal.Decimal values | |
// have same decimal value, but different internal value and exponent. | |
// It requires indirect value to be passed in (pointer, slice, map). | |
func Equal(t *testing.T, expected interface{}, actual interface{}, msgAndArgs ...interface{}) { | |
t.Helper() | |
UnscaleDecimals(expected) | |
UnscaleDecimals(actual) | |
require.Equal(t, expected, actual, msgAndArgs...) | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package requirex | |
import ( | |
"reflect" | |
"github.com/shopspring/decimal" | |
) | |
var decimalType = reflect.TypeOf(decimal.Decimal{}) | |
// UnscaleDecimals accepts decimals, struct pointers or maps and slices of those. | |
// It will traverse all values and modify decimal.Decimal ones | |
// so all of them will have default scale. | |
func UnscaleDecimals(i interface{}) { | |
v := reflect.ValueOf(i) | |
switch v.Kind() { | |
case reflect.Ptr, reflect.Map, reflect.Slice: | |
unscaleDecimals(v) | |
default: | |
panic("unsupported decimals unscaling value kind: " + v.Kind().String()) | |
} | |
} | |
func unscaleDecimals(v reflect.Value) { | |
if v.Kind() == reflect.Ptr { | |
v = v.Elem() | |
} | |
switch v.Kind() { | |
case reflect.Struct: | |
if v.Type() == decimalType && v.CanSet() { | |
dec := v.Interface().(decimal.Decimal) | |
unscaled := decimal.RequireFromString(dec.String()) | |
v.Set(reflect.ValueOf(unscaled)) | |
} | |
for i := 0; i < v.NumField(); i++ { | |
field := v.Field(i) | |
if field.CanSet() { | |
unscaleDecimals(field.Addr()) | |
} | |
} | |
case reflect.Slice: | |
for i := 0; i < v.Len(); i++ { | |
elem := v.Index(i) | |
unscaleDecimals(elem.Addr()) | |
} | |
case reflect.Map: | |
for _, k := range v.MapKeys() { | |
elem := v.MapIndex(k) | |
elemCopy := reflect.New(elem.Type()).Elem() | |
elemCopy.Set(elem) | |
unscaleDecimals(elemCopy) | |
v.SetMapIndex(k, elemCopy) | |
} | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package requirex_test | |
import ( | |
"testing" | |
"github.com/shopspring/decimal" | |
"github.com/stretchr/testify/require" | |
"gitlab.dev.limitlex.io/external-exchanges/market-maker-bot/testutil/requirex" | |
) | |
type order struct { | |
ID string | |
Price decimal.Decimal | |
Amount int | |
unexported decimal.Decimal | |
Fee fee | |
Trades []trade | |
} | |
type fee struct { | |
Amount decimal.Decimal | |
Currency string | |
} | |
type trade struct { | |
ID string | |
Amount decimal.Decimal | |
} | |
func TestUnscaleDecimals(t *testing.T) { | |
t.Run("decimal", func(t *testing.T) { | |
dec := decimal.NewFromInt(1000).Mul(decimal.NewFromFloat(0.5)) | |
requirex.UnscaleDecimals(&dec) | |
require.Equal(t, decimal.NewFromInt(500), dec) | |
}) | |
t.Run("struct", func(t *testing.T) { | |
o := order{ | |
ID: "123", | |
// Calling Mul wil change the scale of decimal. | |
Price: decimal.NewFromInt(1000).Mul(decimal.NewFromFloat(0.5)), | |
Amount: 10, | |
// Unexported fields will not be touched. | |
unexported: decimal.NewFromInt(123).Mul(decimal.NewFromFloat(10)), | |
// Nested struct. | |
Fee: fee{ | |
Amount: decimal.NewFromFloat(0.2).Div(decimal.NewFromInt(100)), | |
Currency: "USDT", | |
}, | |
// Slice inside struct. | |
Trades: []trade{ | |
{ | |
ID: "999", | |
Amount: decimal.NewFromFloat(10).Div(decimal.NewFromInt(8)), | |
}, | |
}, | |
} | |
requirex.UnscaleDecimals(&o) | |
require.Equal(t, order{ | |
ID: "123", | |
Price: decimal.NewFromInt(500), | |
Amount: 10, | |
unexported: decimal.NewFromInt(123).Mul(decimal.NewFromFloat(10)), | |
Fee: fee{ | |
Amount: decimal.NewFromFloat(0.002), | |
Currency: "USDT", | |
}, | |
Trades: []trade{ | |
{ | |
ID: "999", | |
Amount: decimal.NewFromFloat(1.25), | |
}, | |
}, | |
}, o) | |
}) | |
t.Run("slice of structs", func(t *testing.T) { | |
s := []order{ | |
{ | |
Price: decimal.NewFromInt(1000).Mul(decimal.NewFromFloat(0.5)), | |
Fee: fee{ | |
Amount: decimal.NewFromFloat(0.1).Div(decimal.NewFromInt(100)), | |
}, | |
}, | |
{ | |
Price: decimal.NewFromInt(2000).Mul(decimal.NewFromFloat(0.5)), | |
Fee: fee{ | |
Amount: decimal.NewFromFloat(0.2).Div(decimal.NewFromInt(100)), | |
}, | |
}, | |
} | |
requirex.UnscaleDecimals(&s) | |
require.Equal(t, []order{ | |
{ | |
Price: decimal.NewFromInt(500), | |
Fee: fee{ | |
Amount: decimal.NewFromFloat(0.001), | |
}, | |
}, | |
{ | |
Price: decimal.NewFromInt(1000), | |
Fee: fee{ | |
Amount: decimal.NewFromFloat(0.002), | |
}, | |
}, | |
}, s) | |
}) | |
t.Run("map of structs", func(t *testing.T) { | |
m := map[string]order{ | |
"123": { | |
Price: decimal.NewFromInt(1000).Mul(decimal.NewFromFloat(0.5)), | |
Fee: fee{ | |
Amount: decimal.NewFromFloat(0.1).Div(decimal.NewFromInt(100)), | |
}, | |
}, | |
"456": { | |
Price: decimal.NewFromInt(2000).Mul(decimal.NewFromFloat(0.5)), | |
Fee: fee{ | |
Amount: decimal.NewFromFloat(0.2).Div(decimal.NewFromInt(100)), | |
}, | |
}, | |
} | |
requirex.UnscaleDecimals(&m) | |
require.Equal(t, map[string]order{ | |
"123": { | |
Price: decimal.NewFromInt(500), | |
Fee: fee{ | |
Amount: decimal.NewFromFloat(0.001), | |
}, | |
}, | |
"456": { | |
Price: decimal.NewFromInt(1000), | |
Fee: fee{ | |
Amount: decimal.NewFromFloat(0.002), | |
}, | |
}, | |
}, m) | |
}) | |
t.Run("not addressable decimal", func(t *testing.T) { | |
require.Panics(t, func() { | |
original := decimal.NewFromInt(1000).Mul(decimal.NewFromFloat(0.5)) | |
dec := original.Add(decimal.Zero) | |
requirex.UnscaleDecimals(dec) | |
require.Equal(t, original, dec) | |
}) | |
}) | |
t.Run("not addressable struct", func(t *testing.T) { | |
require.Panics(t, func() { | |
originalPrice := decimal.NewFromInt(1000).Mul(decimal.NewFromFloat(0.5)) | |
o := order{ | |
ID: "123", | |
Price: originalPrice.Add(decimal.Zero), | |
} | |
requirex.UnscaleDecimals(o) | |
require.Equal(t, order{ | |
ID: "123", | |
Price: originalPrice, | |
}, o) | |
}) | |
}) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment