Skip to content

Instantly share code, notes, and snippets.

@ulexxander
Created June 28, 2022 10:26
Show Gist options
  • Save ulexxander/bf20ef508c2177ccef0de5b5613fa4ee to your computer and use it in GitHub Desktop.
Save ulexxander/bf20ef508c2177ccef0de5b5613fa4ee to your computer and use it in GitHub Desktop.
Workaround for comparing Go decimal.Decimal values in test equality assertions: UnscaleDecimals.
package requirex
import (
"testing"
"github.com/stretchr/testify/require"
)
// Equal is a drop-in replacement for require.Equal that first normalizes every
// decimal.Decimal reachable from expected and actual (via UnscaleDecimals).
// Decimals that are numerically equal but carry a different internal
// coefficient and exponent therefore no longer make the assertion fail.
//
// Both expected and actual are modified in place. Each must be a pointer,
// slice or map; any other kind makes UnscaleDecimals panic.
func Equal(t *testing.T, expected interface{}, actual interface{}, msgAndArgs ...interface{}) {
	t.Helper()
	// Normalize expected first, then actual, before delegating to testify.
	for _, operand := range [...]interface{}{expected, actual} {
		UnscaleDecimals(operand)
	}
	require.Equal(t, expected, actual, msgAndArgs...)
}
package requirex
import (
"reflect"
"github.com/shopspring/decimal"
)
// decimalType is the reflect.Type of decimal.Decimal. It is compared against
// struct values during the reflective walk to recognise decimals.
var decimalType = reflect.TypeOf((*decimal.Decimal)(nil)).Elem()
// UnscaleDecimals rewrites, in place, the decimal.Decimal values reachable from
// i. Each one is replaced with the value obtained by re-parsing its string form,
// so numerically equal decimals end up with the same internal representation.
//
// i must be a pointer, map or slice so that the values it reaches can be
// modified. Decimals and structs are therefore passed by address, e.g.
// UnscaleDecimals(&dec). Any other kind, including a decimal.Decimal or struct
// passed by value, causes a panic.
func UnscaleDecimals(i interface{}) {
	root := reflect.ValueOf(i)
	if kind := root.Kind(); kind != reflect.Ptr && kind != reflect.Map && kind != reflect.Slice {
		panic("unsupported decimals unscaling value kind: " + kind.String())
	}
	unscaleDecimals(root)
}
// unscaleDecimals walks v with reflection and replaces every settable
// decimal.Decimal it reaches with the value obtained by re-parsing its string
// form. That gives numerically equal decimals one canonical coefficient and
// exponent, so reflect.DeepEqual-based assertions treat them as equal.
//
// Traversal covers:
//   - pointers, including chains such as []*T, **T, or *decimal.Decimal fields
//   - interfaces
//   - structs (settable, i.e. exported, fields only)
//   - arrays and slices
//   - map values (keys are left untouched)
//
// Pointers, maps and slices that were already visited are skipped, so cyclic
// data structures terminate. Values that cannot be set, such as those reached
// through unexported fields, are left unchanged.
func unscaleDecimals(v reflect.Value) {
	// visitKey identifies a reference value (pointer, map or slice header) that
	// has already been walked. The length is part of the key so two slices that
	// share a backing array but differ in length are both walked.
	type visitKey struct {
		ptr uintptr
		typ reflect.Type
		n   int
	}
	visited := make(map[visitKey]struct{})

	// firstVisit records rv and reports whether it had not been walked before.
	firstVisit := func(rv reflect.Value, n int) bool {
		key := visitKey{ptr: rv.Pointer(), typ: rv.Type(), n: n}
		if _, ok := visited[key]; ok {
			return false
		}
		visited[key] = struct{}{}
		return true
	}

	var walk func(rv reflect.Value)
	walk = func(rv reflect.Value) {
		switch rv.Kind() {
		case reflect.Ptr:
			if !rv.IsNil() && firstVisit(rv, 0) {
				// The pointee is addressable, so anything inside it can be set.
				walk(rv.Elem())
			}

		case reflect.Interface:
			if rv.IsNil() {
				return
			}
			elem := rv.Elem()
			switch elem.Kind() {
			case reflect.Ptr, reflect.Map, reflect.Slice:
				// Reference kinds are modified through their shared underlying
				// data, so the interface itself does not need to be reassigned.
				walk(elem)
			default:
				// A value boxed in an interface is not addressable: walk a
				// settable copy and store it back.
				if !rv.CanSet() {
					return
				}
				cp := reflect.New(elem.Type()).Elem()
				cp.Set(elem)
				walk(cp)
				rv.Set(cp)
			}

		case reflect.Struct:
			if rv.Type() == decimalType {
				if rv.CanSet() {
					dec := rv.Interface().(decimal.Decimal)
					rv.Set(reflect.ValueOf(decimal.RequireFromString(dec.String())))
				}
				// decimal.Decimal's own fields are unexported; nothing more to do.
				return
			}
			for i := 0; i < rv.NumField(); i++ {
				if field := rv.Field(i); field.CanSet() {
					walk(field)
				}
			}

		case reflect.Array:
			for i := 0; i < rv.Len(); i++ {
				walk(rv.Index(i))
			}

		case reflect.Slice:
			if rv.Len() == 0 || !firstVisit(rv, rv.Len()) {
				return
			}
			for i := 0; i < rv.Len(); i++ {
				// Slice elements are addressable, so they are modified in place.
				walk(rv.Index(i))
			}

		case reflect.Map:
			if rv.Len() == 0 || !firstVisit(rv, 0) {
				return
			}
			for _, key := range rv.MapKeys() {
				val := rv.MapIndex(key)
				if !val.IsValid() {
					// Keys such as NaN cannot be looked up again; leave them alone.
					continue
				}
				// Map values are not addressable: walk a settable copy and store it back.
				cp := reflect.New(rv.Type().Elem()).Elem()
				cp.Set(val)
				walk(cp)
				rv.SetMapIndex(key, cp)
			}
		}
	}

	walk(v)
}
package requirex_test
import (
"testing"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"
"gitlab.dev.limitlex.io/external-exchanges/market-maker-bot/testutil/requirex"
)
// Fixture types for the UnscaleDecimals tests. Decimals appear as direct
// fields, inside a nested struct, inside a slice of structs, and in an
// unexported field that is expected to be left untouched.
type (
	// order is the top-level fixture.
	order struct {
		ID         string
		Price      decimal.Decimal
		Amount     int
		unexported decimal.Decimal
		Fee        fee
		Trades     []trade
	}

	// fee is embedded by value in order.
	fee struct {
		Amount   decimal.Decimal
		Currency string
	}

	// trade is the element type of order.Trades.
	trade struct {
		ID     string
		Amount decimal.Decimal
	}
)
// TestUnscaleDecimals checks the following:
//   - UnscaleDecimals normalizes decimals reached directly, through struct
//     fields, nested structs, slices and maps.
//   - Unexported fields are left as they are.
//   - The Equal wrapper tolerates scale differences.
//   - Values passed by value are rejected with a specific panic.
func TestUnscaleDecimals(t *testing.T) {
	t.Run("decimal", func(t *testing.T) {
		dec := decimal.NewFromInt(1000).Mul(decimal.NewFromFloat(0.5))
		requirex.UnscaleDecimals(&dec)
		require.Equal(t, decimal.NewFromInt(500), dec)
	})

	t.Run("struct", func(t *testing.T) {
		o := order{
			ID: "123",
			// Calling Mul changes the scale (exponent) of the decimal.
			Price:  decimal.NewFromInt(1000).Mul(decimal.NewFromFloat(0.5)),
			Amount: 10,
			// Unexported fields will not be touched.
			unexported: decimal.NewFromInt(123).Mul(decimal.NewFromFloat(10)),
			// Nested struct.
			Fee: fee{
				Amount:   decimal.NewFromFloat(0.2).Div(decimal.NewFromInt(100)),
				Currency: "USDT",
			},
			// Slice inside struct.
			Trades: []trade{
				{
					ID:     "999",
					Amount: decimal.NewFromFloat(10).Div(decimal.NewFromInt(8)),
				},
			},
		}
		requirex.UnscaleDecimals(&o)
		require.Equal(t, order{
			ID:         "123",
			Price:      decimal.NewFromInt(500),
			Amount:     10,
			unexported: decimal.NewFromInt(123).Mul(decimal.NewFromFloat(10)),
			Fee: fee{
				Amount:   decimal.NewFromFloat(0.002),
				Currency: "USDT",
			},
			Trades: []trade{
				{
					ID:     "999",
					Amount: decimal.NewFromFloat(1.25),
				},
			},
		}, o)
	})

	t.Run("slice of structs", func(t *testing.T) {
		s := []order{
			{
				Price: decimal.NewFromInt(1000).Mul(decimal.NewFromFloat(0.5)),
				Fee: fee{
					Amount: decimal.NewFromFloat(0.1).Div(decimal.NewFromInt(100)),
				},
			},
			{
				Price: decimal.NewFromInt(2000).Mul(decimal.NewFromFloat(0.5)),
				Fee: fee{
					Amount: decimal.NewFromFloat(0.2).Div(decimal.NewFromInt(100)),
				},
			},
		}
		requirex.UnscaleDecimals(&s)
		require.Equal(t, []order{
			{
				Price: decimal.NewFromInt(500),
				Fee: fee{
					Amount: decimal.NewFromFloat(0.001),
				},
			},
			{
				Price: decimal.NewFromInt(1000),
				Fee: fee{
					Amount: decimal.NewFromFloat(0.002),
				},
			},
		}, s)
	})

	t.Run("map of structs", func(t *testing.T) {
		m := map[string]order{
			"123": {
				Price: decimal.NewFromInt(1000).Mul(decimal.NewFromFloat(0.5)),
				Fee: fee{
					Amount: decimal.NewFromFloat(0.1).Div(decimal.NewFromInt(100)),
				},
			},
			"456": {
				Price: decimal.NewFromInt(2000).Mul(decimal.NewFromFloat(0.5)),
				Fee: fee{
					Amount: decimal.NewFromFloat(0.2).Div(decimal.NewFromInt(100)),
				},
			},
		}
		requirex.UnscaleDecimals(&m)
		require.Equal(t, map[string]order{
			"123": {
				Price: decimal.NewFromInt(500),
				Fee: fee{
					Amount: decimal.NewFromFloat(0.001),
				},
			},
			"456": {
				Price: decimal.NewFromInt(1000),
				Fee: fee{
					Amount: decimal.NewFromFloat(0.002),
				},
			},
		}, m)
	})

	t.Run("Equal wrapper", func(t *testing.T) {
		// Same numeric price, different internal scale on each side.
		expected := []order{{ID: "1", Price: decimal.NewFromInt(500)}}
		actual := []order{{ID: "1", Price: decimal.NewFromInt(1000).Mul(decimal.NewFromFloat(0.5))}}
		requirex.Equal(t, expected, actual)
	})

	t.Run("not addressable decimal", func(t *testing.T) {
		dec := decimal.NewFromInt(1000).Mul(decimal.NewFromFloat(0.5))
		// Only the call under test goes inside the closure, so the assertion
		// pins exactly which panic is raised and by what.
		require.PanicsWithValue(t, "unsupported decimals unscaling value kind: struct", func() {
			requirex.UnscaleDecimals(dec)
		})
	})

	t.Run("not addressable struct", func(t *testing.T) {
		o := order{
			ID:    "123",
			Price: decimal.NewFromInt(1000).Mul(decimal.NewFromFloat(0.5)),
		}
		require.PanicsWithValue(t, "unsupported decimals unscaling value kind: struct", func() {
			requirex.UnscaleDecimals(o)
		})
	})
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment