diff --git a/types/bench_test.go b/types/bench_test.go index 6c1cbc88cb09..e4d8a07e1ca0 100644 --- a/types/bench_test.go +++ b/types/bench_test.go @@ -28,3 +28,53 @@ func BenchmarkParseCoin(b *testing.B) { } } } + +func BenchmarkUintMarshal(b *testing.B) { + var values = []uint64{ + 0, + 1, + 1 << 10, + 1<<10 - 3, + 1<<63 - 1, + 1<<32 - 7, + 1<<22 - 8, + } + + var scratch [20]byte + b.ReportAllocs() + for i := 0; i < b.N; i++ { + for _, value := range values { + u := types.NewUint(value) + n, err := u.MarshalTo(scratch[:]) + if err != nil { + b.Fatal(err) + } + b.SetBytes(int64(n)) + } + } +} + +func BenchmarkIntMarshal(b *testing.B) { + var values = []int64{ + 0, + 1, + 1 << 10, + 1<<10 - 3, + 1<<63 - 1, + 1<<32 - 7, + 1<<22 - 8, + } + + var scratch [20]byte + b.ReportAllocs() + for i := 0; i < b.N; i++ { + for _, value := range values { + in := types.NewInt(value) + n, err := in.MarshalTo(scratch[:]) + if err != nil { + b.Fatal(err) + } + b.SetBytes(int64(n)) + } + } +} diff --git a/types/int.go b/types/int.go index 81f6b1c04ddb..c33ff040a51e 100644 --- a/types/int.go +++ b/types/int.go @@ -375,7 +375,7 @@ func (i *Int) MarshalTo(data []byte) (n int, err error) { if i.i == nil { i.i = new(big.Int) } - if len(i.i.Bytes()) == 0 { + if i.i.BitLen() == 0 { // The value 0 copy(data, []byte{0x30}) return 1, nil } diff --git a/types/int_test.go b/types/int_test.go index 4b26ddb8ada5..2992f938b384 100644 --- a/types/int_test.go +++ b/types/int_test.go @@ -1,6 +1,7 @@ package types_test import ( + "fmt" "math/big" "math/rand" "strconv" @@ -385,3 +386,36 @@ func (s *intTestSuite) TestIntEq() { _, resp, _, _, _ = sdk.IntEq(s.T(), sdk.OneInt(), sdk.ZeroInt()) s.Require().False(resp) } + +func TestRoundTripMarshalToInt(t *testing.T) { + var values = []int64{ + 0, + 1, + 1 << 10, + 1<<10 - 3, + 1<<63 - 1, + 1<<32 - 7, + 1<<22 - 8, + } + + for _, value := range values { + value := value + t.Run(fmt.Sprintf("%d", value), func(t *testing.T) { + t.Parallel() + + var scratch [20]byte + iv := sdk.NewInt(value) + n, err := iv.MarshalTo(scratch[:]) + if err != nil { + t.Fatal(err) + } + rt := new(sdk.Int) + if err := rt.Unmarshal(scratch[:n]); err != nil { + t.Fatal(err) + } + if !rt.Equal(iv) { + t.Fatalf("roundtrip=%q != original=%q", rt, iv) + } + }) + } +} diff --git a/types/uint.go b/types/uint.go index 2305d7386728..d449c5262b19 100644 --- a/types/uint.go +++ b/types/uint.go @@ -161,7 +161,7 @@ func (u *Uint) MarshalTo(data []byte) (n int, err error) { if u.i == nil { u.i = new(big.Int) } - if len(u.i.Bytes()) == 0 { + if u.i.BitLen() == 0 { // The value 0 copy(data, []byte{0x30}) return 1, nil } diff --git a/types/uint_test.go b/types/uint_test.go index 36705c890a19..b91f9ab4a96a 100644 --- a/types/uint_test.go +++ b/types/uint_test.go @@ -1,6 +1,7 @@ package types_test import ( + "fmt" "math" "math/big" "math/rand" @@ -290,3 +291,36 @@ func maxuint(i1, i2 uint64) uint64 { } return i2 } + +func TestRoundTripMarshalToUint(t *testing.T) { + var values = []uint64{ + 0, + 1, + 1 << 10, + 1<<10 - 3, + 1<<63 - 1, + 1<<32 - 7, + 1<<22 - 8, + } + + for _, value := range values { + value := value + t.Run(fmt.Sprintf("%d", value), func(t *testing.T) { + t.Parallel() + + var scratch [20]byte + uv := sdk.NewUint(value) + n, err := uv.MarshalTo(scratch[:]) + if err != nil { + t.Fatal(err) + } + rt := new(sdk.Uint) + if err := rt.Unmarshal(scratch[:n]); err != nil { + t.Fatal(err) + } + if !rt.Equal(uv) { + t.Fatalf("roundtrip=%q != original=%q", rt, uv) + } + }) + } +}