Skip to content

Commit 78af72f

Browse files
authored
Faster Rationals by avoiding unnecessary divgcd (#35492)
1 parent 5ec0608 commit 78af72f

File tree

2 files changed

+94
-47
lines changed

2 files changed

+94
-47
lines changed

base/rational.jl

Lines changed: 80 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,34 @@ struct Rational{T<:Integer} <: Real
1010
num::T
1111
den::T
1212

13-
function Rational{T}(num::Integer, den::Integer) where T<:Integer
14-
num == den == zero(T) && __throw_rational_argerror_zero(T)
15-
num2, den2 = divgcd(num, den)
16-
if T<:Signed && signbit(den2)
17-
den2 = -den2
18-
signbit(den2) && __throw_rational_argerror_typemin(T)
19-
num2 = -num2
20-
end
21-
return new(num2, den2)
13+
# Unexported inner constructor of Rational that bypasses all checks
14+
global unsafe_rational(::Type{T}, num, den) where {T} = new{T}(num, den)
15+
end
16+
17+
unsafe_rational(num::T, den::T) where {T<:Integer} = unsafe_rational(T, num, den)
18+
unsafe_rational(num::Integer, den::Integer) = unsafe_rational(promote(num, den)...)
19+
20+
@noinline __throw_rational_argerror_typemin(T) = throw(ArgumentError("invalid rational: denominator can't be typemin($T)"))
21+
function checked_den(num::T, den::T) where T<:Integer
22+
if signbit(den)
23+
den = -den
24+
signbit(den) && __throw_rational_argerror_typemin(T)
25+
num = -num
2226
end
27+
return unsafe_rational(T, num, den)
2328
end
29+
checked_den(num::Integer, den::Integer) = checked_den(promote(num, den)...)
30+
2431
@noinline __throw_rational_argerror_zero(T) = throw(ArgumentError("invalid rational: zero($T)//zero($T)"))
25-
@noinline __throw_rational_argerror_typemin(T) = throw(ArgumentError("invalid rational: denominator can't be typemin($T)"))
32+
function Rational{T}(num::Integer, den::Integer) where T<:Integer
33+
iszero(den) && iszero(num) && __throw_rational_argerror_zero(T)
34+
num, den = divgcd(num, den)
35+
return checked_den(T(num), T(den))
36+
end
2637

27-
Rational(n::T, d::T) where {T<:Integer} = Rational{T}(n,d)
28-
Rational(n::Integer, d::Integer) = Rational(promote(n,d)...)
29-
Rational(n::Integer) = Rational(n,one(n))
38+
Rational(n::T, d::T) where {T<:Integer} = Rational{T}(n, d)
39+
Rational(n::Integer, d::Integer) = Rational(promote(n, d)...)
40+
Rational(n::Integer) = unsafe_rational(n, one(n))
3041

3142
function divgcd(x::Integer,y::Integer)
3243
g = gcd(x,y)
@@ -50,20 +61,20 @@ julia> (3 // 5) // (2 // 1)
5061
//(n::Integer, d::Integer) = Rational(n,d)
5162

5263
function //(x::Rational, y::Integer)
53-
xn,yn = divgcd(x.num,y)
54-
xn//checked_mul(x.den,yn)
64+
xn, yn = divgcd(x.num,y)
65+
checked_den(xn, checked_mul(x.den, yn))
5566
end
5667
function //(x::Integer, y::Rational)
57-
xn,yn = divgcd(x,y.num)
58-
checked_mul(xn,y.den)//yn
68+
xn, yn = divgcd(x,y.num)
69+
checked_den(checked_mul(xn, y.den), yn)
5970
end
6071
function //(x::Rational, y::Rational)
6172
xn,yn = divgcd(x.num,y.num)
6273
xd,yd = divgcd(x.den,y.den)
63-
checked_mul(xn,yd)//checked_mul(xd,yn)
74+
checked_den(checked_mul(xn, yd), checked_mul(xd, yn))
6475
end
6576

66-
//(x::Complex, y::Real) = complex(real(x)//y,imag(x)//y)
77+
//(x::Complex, y::Real) = complex(real(x)//y, imag(x)//y)
6778
//(x::Number, y::Complex) = x*conj(y)//abs2(y)
6879

6980

@@ -84,8 +95,12 @@ function write(s::IO, z::Rational)
8495
write(s,numerator(z),denominator(z))
8596
end
8697

87-
Rational{T}(x::Rational) where {T<:Integer} = Rational{T}(convert(T,x.num), convert(T,x.den))
88-
Rational{T}(x::Integer) where {T<:Integer} = Rational{T}(convert(T,x), convert(T,1))
98+
function Rational{T}(x::Rational) where T<:Integer
99+
unsafe_rational(T, convert(T, x.num), convert(T, x.den))
100+
end
101+
function Rational{T}(x::Integer) where T<:Integer
102+
unsafe_rational(T, convert(T, x), one(T))
103+
end
89104

90105
Rational(x::Rational) = x
91106

@@ -108,7 +123,7 @@ end
108123
Rational(x::Float64) = Rational{Int64}(x)
109124
Rational(x::Float32) = Rational{Int}(x)
110125

111-
big(q::Rational) = big(numerator(q))//big(denominator(q))
126+
big(q::Rational) = unsafe_rational(big(numerator(q)), big(denominator(q)))
112127

113128
big(z::Complex{<:Rational{<:Integer}}) = Complex{Rational{BigInt}}(z)
114129

@@ -118,6 +133,8 @@ promote_rule(::Type{Rational{T}}, ::Type{S}) where {T<:Integer,S<:AbstractFloat}
118133

119134
widen(::Type{Rational{T}}) where {T} = Rational{widen(T)}
120135

136+
@noinline __throw_negate_unsigned() = throw(OverflowError("cannot negate unsigned number"))
137+
121138
"""
122139
rationalize([T<:Integer=Int,] x; tol::Real=eps(x))
123140
@@ -140,8 +157,9 @@ function rationalize(::Type{T}, x::AbstractFloat, tol::Real) where T<:Integer
140157
if tol < 0
141158
throw(ArgumentError("negative tolerance $tol"))
142159
end
160+
T<:Unsigned && x < 0 && __throw_negate_unsigned()
143161
isnan(x) && return T(x)//one(T)
144-
isinf(x) && return (x < 0 ? -one(T) : one(T))//zero(T)
162+
isinf(x) && return unsafe_rational(x < 0 ? -one(T) : one(T), zero(T))
145163

146164
p, q = (x < 0 ? -one(T) : one(T)), zero(T)
147165
pp, qq = zero(T), one(T)
@@ -234,59 +252,74 @@ denominator(x::Rational) = x.den
234252

235253
sign(x::Rational) = oftype(x, sign(x.num))
236254
signbit(x::Rational) = signbit(x.num)
237-
copysign(x::Rational, y::Real) = copysign(x.num,y) // x.den
238-
copysign(x::Rational, y::Rational) = copysign(x.num,y.num) // x.den
255+
copysign(x::Rational, y::Real) = unsafe_rational(copysign(x.num, y), x.den)
256+
copysign(x::Rational, y::Rational) = unsafe_rational(copysign(x.num, y.num), x.den)
239257

240258
abs(x::Rational) = Rational(abs(x.num), x.den)
241259

242-
typemin(::Type{Rational{T}}) where {T<:Integer} = -one(T)//zero(T)
243-
typemax(::Type{Rational{T}}) where {T<:Integer} = one(T)//zero(T)
260+
typemin(::Type{Rational{T}}) where {T<:Signed} = unsafe_rational(T, -one(T), zero(T))
261+
typemin(::Type{Rational{T}}) where {T<:Integer} = unsafe_rational(T, zero(T), one(T))
262+
typemax(::Type{Rational{T}}) where {T<:Integer} = unsafe_rational(T, one(T), zero(T))
244263

245264
isinteger(x::Rational) = x.den == 1
246265

247-
+(x::Rational) = (+x.num) // x.den
248-
-(x::Rational) = (-x.num) // x.den
266+
+(x::Rational) = unsafe_rational(+x.num, x.den)
267+
-(x::Rational) = unsafe_rational(-x.num, x.den)
249268

250269
function -(x::Rational{T}) where T<:BitSigned
251-
x.num == typemin(T) && throw(OverflowError("rational numerator is typemin(T)"))
252-
(-x.num) // x.den
270+
x.num == typemin(T) && __throw_rational_numerator_typemin(T)
271+
unsafe_rational(-x.num, x.den)
253272
end
273+
@noinline __throw_rational_numerator_typemin(T) = throw(OverflowError("rational numerator is typemin($T)"))
274+
254275
function -(x::Rational{T}) where T<:Unsigned
255-
x.num != zero(T) && throw(OverflowError("cannot negate unsigned number"))
276+
x.num != zero(T) && __throw_negate_unsigned()
256277
x
257278
end
258279

259-
for (op,chop) in ((:+,:checked_add), (:-,:checked_sub),
260-
(:rem,:rem), (:mod,:mod))
280+
for (op,chop) in ((:+,:checked_add), (:-,:checked_sub), (:rem,:rem), (:mod,:mod))
261281
@eval begin
262282
function ($op)(x::Rational, y::Rational)
263283
xd, yd = divgcd(x.den, y.den)
264284
Rational(($chop)(checked_mul(x.num,yd), checked_mul(y.num,xd)), checked_mul(x.den,yd))
265285
end
266286

267287
function ($op)(x::Rational, y::Integer)
268-
Rational(($chop)(x.num, checked_mul(x.den, y)), x.den)
288+
unsafe_rational(($chop)(x.num, checked_mul(x.den, y)), x.den)
269289
end
270-
290+
end
291+
end
292+
for (op,chop) in ((:+,:checked_add), (:-,:checked_sub))
293+
@eval begin
294+
function ($op)(y::Integer, x::Rational)
295+
unsafe_rational(($chop)(checked_mul(x.den, y), x.num), x.den)
296+
end
297+
end
298+
end
299+
for (op,chop) in ((:rem,:rem), (:mod,:mod))
300+
@eval begin
271301
function ($op)(y::Integer, x::Rational)
272302
Rational(($chop)(checked_mul(x.den, y), x.num), x.den)
273303
end
274304
end
275305
end
276306

277307
function *(x::Rational, y::Rational)
278-
xn,yd = divgcd(x.num,y.den)
279-
xd,yn = divgcd(x.den,y.num)
280-
checked_mul(xn,yn) // checked_mul(xd,yd)
308+
xn, yd = divgcd(x.num, y.den)
309+
xd, yn = divgcd(x.den, y.num)
310+
unsafe_rational(checked_mul(xn, yn), checked_mul(xd, yd))
281311
end
282312
function *(x::Rational, y::Integer)
283313
xd, yn = divgcd(x.den, y)
284-
checked_mul(x.num, yn) // xd
314+
unsafe_rational(checked_mul(x.num, yn), xd)
315+
end
316+
function *(y::Integer, x::Rational)
317+
yn, xd = divgcd(y, x.den)
318+
unsafe_rational(checked_mul(yn, x.num), xd)
285319
end
286-
*(x::Integer, y::Rational) = *(y, x)
287-
/(x::Rational, y::Rational) = x//y
288-
/(x::Rational, y::Complex{<:Union{Integer,Rational}}) = x//y
289-
inv(x::Rational) = Rational(x.den, x.num)
320+
/(x::Rational, y::Union{Rational, Integer, Complex{<:Union{Integer,Rational}}}) = x//y
321+
/(x::Union{Integer, Complex{<:Union{Integer,Rational}}}, y::Rational) = x//y
322+
inv(x::Rational{T}) where {T} = checked_den(x.den, x.num)
290323

291324
fma(x::Rational, y::Rational, z::Rational) = x*y+z
292325

@@ -403,7 +436,7 @@ round(x::Rational, r::RoundingMode=RoundNearest) = round(typeof(x), x, r)
403436

404437
function round(::Type{T}, x::Rational{Tr}, r::RoundingMode=RoundNearest) where {T,Tr}
405438
if iszero(denominator(x)) && !(T <: Integer)
406-
return convert(T, copysign(one(Tr)//zero(Tr), numerator(x)))
439+
return convert(T, copysign(unsafe_rational(one(Tr), zero(Tr)), numerator(x)))
407440
end
408441
convert(T, div(numerator(x), denominator(x), r))
409442
end
@@ -437,8 +470,8 @@ end
437470

438471
float(::Type{Rational{T}}) where {T<:Integer} = float(T)
439472

440-
gcd(x::Rational, y::Rational) = gcd(x.num, y.num) // lcm(x.den, y.den)
441-
lcm(x::Rational, y::Rational) = lcm(x.num, y.num) // gcd(x.den, y.den)
473+
gcd(x::Rational, y::Rational) = unsafe_rational(gcd(x.num, y.num), lcm(x.den, y.den))
474+
lcm(x::Rational, y::Rational) = unsafe_rational(lcm(x.num, y.num), gcd(x.den, y.den))
442475
function gcdx(x::Rational, y::Rational)
443476
c = gcd(x, y)
444477
if iszero(c.num)

test/rational.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ using Test
1717
@test 5//0 == 1//0
1818
@test -1//0 == -1//0
1919
@test -7//0 == -1//0
20+
@test (-1//2) // (-2//5) == 5//4
2021

2122
@test_throws OverflowError -(0x01//0x0f)
2223
@test_throws OverflowError -(typemin(Int)//1)
@@ -26,9 +27,13 @@ using Test
2627
@test (typemax(Int)//1) / (typemax(Int)//1) == 1
2728
@test (1//typemax(Int)) / (1//typemax(Int)) == 1
2829
@test_throws OverflowError (1//2)^63
30+
@test inv((1+typemin(Int))//typemax(Int)) == -1
31+
@test_throws ArgumentError inv(typemin(Int)//typemax(Int))
32+
@test_throws ArgumentError Rational(0x1, typemin(Int32))
2933

3034
@test @inferred(rationalize(Int, 3.0, 0.0)) === 3//1
3135
@test @inferred(rationalize(Int, 3.0, 0)) === 3//1
36+
@test_throws OverflowError rationalize(UInt, -2.0)
3237
@test_throws ArgumentError rationalize(Int, big(3.0), -1.)
3338
# issue 26823
3439
@test_throws InexactError rationalize(Int, NaN)
@@ -38,6 +43,11 @@ using Test
3843
@test -2 // typemin(Int) == -1 // (typemin(Int) >> 1)
3944
@test 2 // typemin(Int) == 1 // (typemin(Int) >> 1)
4045

46+
@test_throws InexactError Rational(UInt(1), typemin(Int32))
47+
@test iszero(Rational{Int}(UInt(0), 1))
48+
@test Rational{BigInt}(UInt(1), Int(-1)) == -1
49+
@test_broken Rational{Int64}(UInt(1), typemin(Int32)) == Int64(1) // Int64(typemin(Int32))
50+
4151
for a = -5:5, b = -5:5
4252
if a == b == 0; continue; end
4353
if ispow2(b)
@@ -120,6 +130,8 @@ end
120130
@test widen(Rational{T}) == Rational{widen(T)}
121131
end
122132

133+
@test iszero(typemin(Rational{UInt}))
134+
123135
@test Rational(Float32(rand_int)) == Rational(rand_int)
124136

125137
@test Rational(Rational(rand_int)) == Rational(rand_int)
@@ -548,6 +560,8 @@ end
548560
end
549561
@test 1//2 * 3 == 3//2
550562
@test -3 * (1//2) == -3//2
563+
@test (6//5) // -3 == -2//5
564+
@test -4 // (-6//5) == 10//3
551565

552566
@test_throws OverflowError UInt(1)//2 - 1
553567
@test_throws OverflowError 1 - UInt(5)//2

0 commit comments

Comments
 (0)