From 3fc6240c949511dc6b17080731a949c18477e160 Mon Sep 17 00:00:00 2001 From: Sunny Aggarwal Date: Sat, 4 Jan 2020 15:16:12 -0500 Subject: [PATCH] Merge PR #5447: Added nth root function to sdk.Decimal type --- CHANGELOG.md | 2 ++ types/decimal.go | 75 ++++++++++++++++++++++++++++++++++--------- types/decimal_test.go | 45 +++++++++++++++++++++++++- types/int.go | 21 ++++++++++++ types/int_test.go | 8 +++++ 5 files changed, 134 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0dea7da2b2bc..b489664dedd6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -174,6 +174,8 @@ that allows for arbitrary vesting periods. * Introduces cli commands and rest routes to query historical information at a given height * (modules) [\#5249](https://github.com/cosmos/cosmos-sdk/pull/5249) Funds are now allowed to be directly sent to the community pool (via the distribution module account). * (keys) [\#4941](https://github.com/cosmos/cosmos-sdk/issues/4941) Introduce keybase option to allow overriding the default private key implementation of a key generated through the `keys add` cli command. +* (types) [\#5447](https://github.com/cosmos/cosmos-sdk/pull/5447) Added `ApproxRoot` function to sdk.Decimal type in order to get the nth root for a decimal number, where n is a positive integer. + * An `ApproxSqrt` function was also added for convenience around the common case of n=2. ### Improvements diff --git a/types/decimal.go b/types/decimal.go index 63aecf6cab18..839523080656 100644 --- a/types/decimal.go +++ b/types/decimal.go @@ -269,7 +269,6 @@ func (d Dec) MulInt64(i int64) Dec { // quotient func (d Dec) Quo(d2 Dec) Dec { - // multiply precision twice mul := new(big.Int).Mul(d.Int, precisionReuse) mul.Mul(mul, precisionReuse) @@ -326,30 +325,74 @@ func (d Dec) QuoInt64(i int64) Dec { return Dec{mul} } -// ApproxSqrt returns an approximate sqrt estimation using Newton's method to -// compute square roots x=√d for d > 0. The algorithm starts with some guess and +// ApproxRoot returns an approximate estimation of a Dec's positive real nth root +// using Newton's method (where n is positive). The algorithm starts with some guess and // computes the sequence of improved guesses until an answer converges to an -// approximate answer. It returns -(sqrt(abs(d)) if input is negative. -func (d Dec) ApproxSqrt() Dec { +// approximate answer. It returns `|d|.ApproxRoot() * -1` if input is negative. +func (d Dec) ApproxRoot(root uint64) (guess Dec, err error) { + defer func() { + if r := recover(); r != nil { + var ok bool + err, ok = r.(error) + if !ok { + err = errors.New("out of bounds") + } + } + }() + if d.IsNegative() { - return d.MulInt64(-1).ApproxSqrt().MulInt64(-1) + absRoot, err := d.MulInt64(-1).ApproxRoot(root) + return absRoot.MulInt64(-1), err + } + + if root == 1 || d.IsZero() || d.Equal(OneDec()) { + return d, nil + } + + if root == 0 { + return OneDec(), nil } - if d.IsZero() { - return ZeroDec() + rootInt := NewIntFromUint64(root) + guess, delta := OneDec(), OneDec() + + for delta.Abs().GT(SmallestDec()) { + prev := guess.Power(root - 1) + if prev.IsZero() { + prev = SmallestDec() + } + delta = d.Quo(prev) + delta = delta.Sub(guess) + delta = delta.QuoInt(rootInt) + + guess = guess.Add(delta) } - z := OneDec() - // first guess - z = z.Sub((z.Mul(z).Sub(d)).Quo(z.MulInt64(2))) + return guess, nil +} - // iterate until change is very small - for zNew, delta := z, z; delta.GT(SmallestDec()); z = zNew { - zNew = zNew.Sub((zNew.Mul(zNew).Sub(d)).Quo(zNew.MulInt64(2))) - delta = z.Sub(zNew) +// Power returns a the result of raising to a positive integer power +func (d Dec) Power(power uint64) Dec { + if power == 0 { + return OneDec() } + tmp := OneDec() + for i := power; i > 1; { + if i%2 == 0 { + i /= 2 + } else { + tmp = tmp.Mul(d) + i = (i - 1) / 2 + } + d = d.Mul(d) + } + return d.Mul(tmp) +} - return z +// ApproxSqrt is a wrapper around ApproxRoot for the common special case +// of finding the square root of a number. It returns -(sqrt(abs(d)) if input is negative. +func (d Dec) ApproxSqrt() (Dec, error) { + return d.ApproxRoot(2) } // is integer, e.g. decimals are zero diff --git a/types/decimal_test.go b/types/decimal_test.go index ea97a1029a5a..48c691a100dc 100644 --- a/types/decimal_test.go +++ b/types/decimal_test.go @@ -425,6 +425,48 @@ func TestDecCeil(t *testing.T) { } } +func TestPower(t *testing.T) { + testCases := []struct { + input Dec + power uint64 + expected Dec + }{ + {OneDec(), 10, OneDec()}, // 1.0 ^ (10) => 1.0 + {NewDecWithPrec(5, 1), 2, NewDecWithPrec(25, 2)}, // 0.5 ^ 2 => 0.25 + {NewDecWithPrec(2, 1), 2, NewDecWithPrec(4, 2)}, // 0.2 ^ 2 => 0.04 + {NewDecFromInt(NewInt(3)), 3, NewDecFromInt(NewInt(27))}, // 3 ^ 3 => 27 + {NewDecFromInt(NewInt(-3)), 4, NewDecFromInt(NewInt(81))}, // -3 ^ 4 = 81 + {NewDecWithPrec(1414213562373095049, 18), 2, NewDecFromInt(NewInt(2))}, // 1.414213562373095049 ^ 2 = 2 + } + + for i, tc := range testCases { + res := tc.input.Power(tc.power) + require.True(t, tc.expected.Sub(res).Abs().LTE(SmallestDec()), "unexpected result for test case %d, input: %v", i, tc.input) + } +} + +func TestApproxRoot(t *testing.T) { + testCases := []struct { + input Dec + root uint64 + expected Dec + }{ + {OneDec(), 10, OneDec()}, // 1.0 ^ (0.1) => 1.0 + {NewDecWithPrec(25, 2), 2, NewDecWithPrec(5, 1)}, // 0.25 ^ (0.5) => 0.5 + {NewDecWithPrec(4, 2), 2, NewDecWithPrec(2, 1)}, // 0.04 => 0.2 + {NewDecFromInt(NewInt(27)), 3, NewDecFromInt(NewInt(3))}, // 27 ^ (1/3) => 3 + {NewDecFromInt(NewInt(-81)), 4, NewDecFromInt(NewInt(-3))}, // -81 ^ (0.25) => -3 + {NewDecFromInt(NewInt(2)), 2, NewDecWithPrec(1414213562373095049, 18)}, // 2 ^ (0.5) => 1.414213562373095049 + {NewDecWithPrec(1005, 3), 31536000, MustNewDecFromStr("1.000000000158153904")}, + } + + for i, tc := range testCases { + res, err := tc.input.ApproxRoot(tc.root) + require.NoError(t, err) + require.True(t, tc.expected.Sub(res).Abs().LTE(SmallestDec()), "unexpected result for test case %d, input: %v", i, tc.input) + } +} + func TestApproxSqrt(t *testing.T) { testCases := []struct { input Dec @@ -439,7 +481,8 @@ func TestApproxSqrt(t *testing.T) { } for i, tc := range testCases { - res := tc.input.ApproxSqrt() + res, err := tc.input.ApproxSqrt() + require.NoError(t, err) require.Equal(t, tc.expected, res, "unexpected result for test case %d, input: %v", i, tc.input) } } diff --git a/types/int.go b/types/int.go index 41a6efef74c7..caabdb11318c 100644 --- a/types/int.go +++ b/types/int.go @@ -114,6 +114,13 @@ func NewInt(n int64) Int { return Int{big.NewInt(n)} } +// NewIntFromUint64 constructs an Int from a uint64. +func NewIntFromUint64(n uint64) Int { + b := big.NewInt(0) + b.SetUint64(n) + return Int{b} +} + // NewIntFromBigInt constructs Int from big.Int func NewIntFromBigInt(i *big.Int) Int { if i.BitLen() > maxBitLen { @@ -178,6 +185,20 @@ func (i Int) IsInt64() bool { return i.i.IsInt64() } +// Uint64 converts Int to uint64 +// Panics if the value is out of range +func (i Int) Uint64() uint64 { + if !i.i.IsUint64() { + panic("Uint64() out of bounds") + } + return i.i.Uint64() +} + +// IsUint64 returns true if Uint64() not panics +func (i Int) IsUint64() bool { + return i.i.IsUint64() +} + // IsZero returns true if Int is zero func (i Int) IsZero() bool { return i.i.Sign() == 0 diff --git a/types/int_test.go b/types/int_test.go index f6dbc1b407e9..072b2f47bb9a 100644 --- a/types/int_test.go +++ b/types/int_test.go @@ -16,6 +16,14 @@ func TestFromInt64(t *testing.T) { } } +func TestFromUint64(t *testing.T) { + for n := 0; n < 20; n++ { + r := rand.Uint64() + require.True(t, NewIntFromUint64(r).IsUint64()) + require.Equal(t, r, NewIntFromUint64(r).Uint64()) + } +} + func TestIntPanic(t *testing.T) { // Max Int = 2^255-1 = 5.789e+76 // Min Int = -(2^255-1) = -5.789e+76