Skip to content

Commit 9e993be

Browse files
sam0410simonbyrnesam0410
authored
Fix invmod giving wrong results for moduli close to typemax (#30515)
Fixes #29971 Co-authored-by: Simon Byrne <simon.byrne@gmail.com> Co-authored-by: sam0410 <samikshya.chand.ece15@iitbhu.ac.in>
1 parent 665279a commit 9e993be

File tree

2 files changed

+58
-16
lines changed

2 files changed

+58
-16
lines changed

base/intfuncs.jl

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -134,21 +134,21 @@ lcm(abc::AbstractArray{<:Real}) = reduce(lcm, abc; init=one(eltype(abc)))
134134
function gcd(abc::AbstractArray{<:Integer})
135135
a = zero(eltype(abc))
136136
for b in abc
137-
a = gcd(a,b)
137+
a = gcd(a, b)
138138
if a == 1
139139
return a
140140
end
141141
end
142142
return a
143143
end
144144

145-
# return (gcd(a,b),x,y) such that ax+by == gcd(a,b)
145+
# return (gcd(a, b), x, y) such that ax+by == gcd(a, b)
146146
"""
147-
gcdx(x,y)
147+
gcdx(a, b)
148148
149-
Computes the greatest common (positive) divisor of `x` and `y` and their Bézout
149+
Computes the greatest common (positive) divisor of `a` and `b` and their Bézout
150150
coefficients, i.e. the integer coefficients `u` and `v` that satisfy
151-
``ux+vy = d = gcd(x,y)``. ``gcdx(x,y)`` returns ``(d,u,v)``.
151+
``ua+vb = d = gcd(a, b)``. ``gcdx(a, b)`` returns ``(d, u, v)``.
152152
153153
The arguments may be integer and rational numbers.
154154
@@ -175,8 +175,8 @@ julia> gcdx(240, 46)
175175
their `typemax`, and the identity then holds only via the unsigned
176176
integers' modulo arithmetic.
177177
"""
178-
function gcdx(a::U, b::V) where {U<:Integer, V<:Integer}
179-
T = promote_type(U, V)
178+
function gcdx(a::Integer, b::Integer)
179+
T = promote_type(typeof(a), typeof(b))
180180
# a0, b0 = a, b
181181
s0, s1 = oneunit(T), zero(T)
182182
t0, t1 = s1, s0
@@ -197,11 +197,11 @@ gcdx(a::T, b::T) where T<:Real = throw(MethodError(gcdx, (a,b)))
197197
# multiplicative inverse of n mod m, error if none
198198

199199
"""
200-
invmod(x,m)
200+
invmod(n, m)
201201
202-
Take the inverse of `x` modulo `m`: `y` such that ``x y = 1 \\pmod m``,
203-
with ``div(x,y) = 0``. This is undefined for ``m = 0``, or if
204-
``gcd(x,m) \\neq 1``.
202+
Take the inverse of `n` modulo `m`: `y` such that ``n y = 1 \\pmod m``,
203+
and ``div(y,m) = 0``. This will throw an error if ``m = 0``, or if
204+
``gcd(n,m) \\neq 1``.
205205
206206
# Examples
207207
```jldoctest
@@ -216,14 +216,24 @@ julia> invmod(5,6)
216216
```
217217
"""
218218
function invmod(n::Integer, m::Integer)
219+
iszero(m) && throw(DomainError(m, "`m` must not be 0."))
220+
if n isa Signed
221+
# work around inconsistencies in gcdx
222+
# https://github.com/JuliaLang/julia/issues/33781
223+
T = promote_type(typeof(n), typeof(m))
224+
n == typemin(typeof(n)) && m == typeof(n)(-1) && return T(0)
225+
n == typeof(n)(-1) && m == typemin(typeof(n)) && return T(-1)
226+
end
219227
g, x, y = gcdx(n, m)
220228
g != 1 && throw(DomainError((n, m), "Greatest common divisor is $g."))
221-
m == 0 && throw(DomainError(m, "`m` must not be 0."))
222229
# Note that m might be negative here.
223-
# For unsigned T, x might be close to typemax; add m to force a wrap-around.
224-
r = mod(x + m, m)
225-
# The postcondition is: mod(r * n, m) == mod(T(1), m) && div(r, m) == 0
226-
r
230+
if n isa Unsigned && hastypemax(typeof(n)) && x > typemax(n)>>1
231+
# x might have wrapped if it would have been negative
232+
# adding back m forces a correction
233+
x += m
234+
end
235+
# The postcondition is: mod(result * n, m) == mod(T(1), m) && div(result, m) == 0
236+
return mod(x, m)
227237
end
228238

229239
# ^ for any x supporting *

test/intfuncs.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,38 @@ end
212212
@test invmod(2, 0x3) == 2
213213
@test invmod(0x8, -3) == -1
214214
@test_throws DomainError invmod(0, 3)
215+
216+
# For issue 29971
217+
@test invmod(UInt8(1), typemax(UInt8)) == invmod(UInt16(1), UInt16(typemax(UInt8)))
218+
@test invmod(UInt16(1), typemax(UInt16)) == invmod(UInt32(1), UInt32(typemax(UInt16)))
219+
@test invmod(UInt32(1), typemax(UInt32)) == invmod(UInt64(1), UInt64(typemax(UInt32)))
220+
@test invmod(UInt64(1), typemax(UInt64)) == invmod(UInt128(1), UInt128(typemax(UInt64)))
221+
222+
@test invmod(UInt8(3), UInt8(124)) == invmod(3, 124)
223+
@test invmod(UInt16(3), UInt16(124)) == invmod(3, 124)
224+
@test invmod(UInt32(3), UInt32(124)) == invmod(3, 124)
225+
@test invmod(UInt64(3), UInt64(124)) == invmod(3, 124)
226+
@test invmod(UInt128(3), UInt128(124)) == invmod(3, 124)
227+
228+
@test invmod(Int8(3), Int8(124)) == invmod(3, 124)
229+
@test invmod(Int16(3), Int16(124)) == invmod(3, 124)
230+
@test invmod(Int32(3), Int32(124)) == invmod(3, 124)
231+
@test invmod(Int64(3), Int64(124)) == invmod(3, 124)
232+
@test invmod(Int128(3), Int128(124)) == invmod(3, 124)
233+
234+
for T in (Int8, UInt8)
235+
for x in typemin(T) : typemax(T)
236+
for m in typemin(T) : typemax(T)
237+
if m != 0 && try gcdx(x, m)[1] == 1 catch _ true end
238+
y = invmod(x, m)
239+
@test mod(widemul(y, x), m) == mod(1,m)
240+
@test div(y,m) == 0
241+
else
242+
@test_throws DomainError invmod(x, m)
243+
end
244+
end
245+
end
246+
end
215247
end
216248

217249
@testset "powermod" begin

0 commit comments

Comments
 (0)