Skip to content

Commit

Permalink
Fix invmod giving wrong results for moduli close to typemax (JuliaLan…
Browse files Browse the repository at this point in the history
…g#30515)

Fixes JuliaLang#29971

Co-authored-by: Simon Byrne <[email protected]>
Co-authored-by: sam0410 <[email protected]>
  • Loading branch information
3 people authored Oct 23, 2020
1 parent 665279a commit 9e993be
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 16 deletions.
42 changes: 26 additions & 16 deletions base/intfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,21 +134,21 @@ lcm(abc::AbstractArray{<:Real}) = reduce(lcm, abc; init=one(eltype(abc)))
function gcd(abc::AbstractArray{<:Integer})
a = zero(eltype(abc))
for b in abc
a = gcd(a,b)
a = gcd(a, b)
if a == 1
return a
end
end
return a
end

# return (gcd(a,b),x,y) such that ax+by == gcd(a,b)
# return (gcd(a, b), x, y) such that ax+by == gcd(a, b)
"""
gcdx(x,y)
gcdx(a, b)
Computes the greatest common (positive) divisor of `x` and `y` and their Bézout
Computes the greatest common (positive) divisor of `a` and `b` and their Bézout
coefficients, i.e. the integer coefficients `u` and `v` that satisfy
``ux+vy = d = gcd(x,y)``. ``gcdx(x,y)`` returns ``(d,u,v)``.
``ua+vb = d = gcd(a, b)``. ``gcdx(a, b)`` returns ``(d, u, v)``.
The arguments may be integer and rational numbers.
Expand All @@ -175,8 +175,8 @@ julia> gcdx(240, 46)
their `typemax`, and the identity then holds only via the unsigned
integers' modulo arithmetic.
"""
function gcdx(a::U, b::V) where {U<:Integer, V<:Integer}
T = promote_type(U, V)
function gcdx(a::Integer, b::Integer)
T = promote_type(typeof(a), typeof(b))
# a0, b0 = a, b
s0, s1 = oneunit(T), zero(T)
t0, t1 = s1, s0
Expand All @@ -197,11 +197,11 @@ gcdx(a::T, b::T) where T<:Real = throw(MethodError(gcdx, (a,b)))
# multiplicative inverse of n mod m, error if none

"""
invmod(x,m)
invmod(n, m)
Take the inverse of `x` modulo `m`: `y` such that ``x y = 1 \\pmod m``,
with ``div(x,y) = 0``. This is undefined for ``m = 0``, or if
``gcd(x,m) \\neq 1``.
Take the inverse of `n` modulo `m`: `y` such that ``n y = 1 \\pmod m``,
and ``div(y,m) = 0``. This will throw an error if ``m = 0``, or if
``gcd(n,m) \\neq 1``.
# Examples
```jldoctest
Expand All @@ -216,14 +216,24 @@ julia> invmod(5,6)
```
"""
function invmod(n::Integer, m::Integer)
iszero(m) && throw(DomainError(m, "`m` must not be 0."))
if n isa Signed
# work around inconsistencies in gcdx
# https://github.com/JuliaLang/julia/issues/33781
T = promote_type(typeof(n), typeof(m))
n == typemin(typeof(n)) && m == typeof(n)(-1) && return T(0)
n == typeof(n)(-1) && m == typemin(typeof(n)) && return T(-1)
end
g, x, y = gcdx(n, m)
g != 1 && throw(DomainError((n, m), "Greatest common divisor is $g."))
m == 0 && throw(DomainError(m, "`m` must not be 0."))
# Note that m might be negative here.
# For unsigned T, x might be close to typemax; add m to force a wrap-around.
r = mod(x + m, m)
# The postcondition is: mod(r * n, m) == mod(T(1), m) && div(r, m) == 0
r
if n isa Unsigned && hastypemax(typeof(n)) && x > typemax(n)>>1
# x might have wrapped if it would have been negative
# adding back m forces a correction
x += m
end
# The postcondition is: mod(result * n, m) == mod(T(1), m) && div(result, m) == 0
return mod(x, m)
end

# ^ for any x supporting *
Expand Down
32 changes: 32 additions & 0 deletions test/intfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,38 @@ end
@test invmod(2, 0x3) == 2
@test invmod(0x8, -3) == -1
@test_throws DomainError invmod(0, 3)

# For issue 29971
@test invmod(UInt8(1), typemax(UInt8)) == invmod(UInt16(1), UInt16(typemax(UInt8)))
@test invmod(UInt16(1), typemax(UInt16)) == invmod(UInt32(1), UInt32(typemax(UInt16)))
@test invmod(UInt32(1), typemax(UInt32)) == invmod(UInt64(1), UInt64(typemax(UInt32)))
@test invmod(UInt64(1), typemax(UInt64)) == invmod(UInt128(1), UInt128(typemax(UInt64)))

@test invmod(UInt8(3), UInt8(124)) == invmod(3, 124)
@test invmod(UInt16(3), UInt16(124)) == invmod(3, 124)
@test invmod(UInt32(3), UInt32(124)) == invmod(3, 124)
@test invmod(UInt64(3), UInt64(124)) == invmod(3, 124)
@test invmod(UInt128(3), UInt128(124)) == invmod(3, 124)

@test invmod(Int8(3), Int8(124)) == invmod(3, 124)
@test invmod(Int16(3), Int16(124)) == invmod(3, 124)
@test invmod(Int32(3), Int32(124)) == invmod(3, 124)
@test invmod(Int64(3), Int64(124)) == invmod(3, 124)
@test invmod(Int128(3), Int128(124)) == invmod(3, 124)

for T in (Int8, UInt8)
for x in typemin(T) : typemax(T)
for m in typemin(T) : typemax(T)
if m != 0 && try gcdx(x, m)[1] == 1 catch _ true end
y = invmod(x, m)
@test mod(widemul(y, x), m) == mod(1,m)
@test div(y,m) == 0
else
@test_throws DomainError invmod(x, m)
end
end
end
end
end

@testset "powermod" begin
Expand Down

0 comments on commit 9e993be

Please sign in to comment.