Skip to content

Commit 9634652

Browse files
authored
Fix fastmath for vararg +, *, min, max methods (#54513)
Currently using the fastmath vararg +, *, min, max methods only actually sets fastmath if they are specifically overloaded even when the correct 2 argument methods have been defined. As such, `ComplexF32, ComplexF64` do not currently set fastmath when using the vararg methods. This will also fix any other types, such as those in SIMD.jl, which don't overload the vararg methods. E.g. ```julia x = ComplexF64(1) f(x) = @fastmath x + x + x ``` now works correctly. I see no reason why the vararg methods shouldn't default to using the fastmath 2 argument methods instead of the non fastmath ones, which is the current behaviour. I also switched the implementation to use `afoldl` as that's what the non fastmath vararg methods use. Fixes #54456 and eschnett/SIMD.jl#108.
1 parent fe35189 commit 9634652

File tree

2 files changed

+50
-9
lines changed

2 files changed

+50
-9
lines changed

base/fastmath.jl

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ export @fastmath
3030
import Core.Intrinsics: sqrt_llvm_fast, neg_float_fast,
3131
add_float_fast, sub_float_fast, mul_float_fast, div_float_fast,
3232
eq_float_fast, ne_float_fast, lt_float_fast, le_float_fast
33+
import Base: afoldl
3334

3435
const fast_op =
3536
Dict(# basic arithmetic
@@ -168,11 +169,6 @@ sub_fast(x::T, y::T) where {T<:FloatTypes} = sub_float_fast(x, y)
168169
mul_fast(x::T, y::T) where {T<:FloatTypes} = mul_float_fast(x, y)
169170
div_fast(x::T, y::T) where {T<:FloatTypes} = div_float_fast(x, y)
170171

171-
add_fast(x::T, y::T, zs::T...) where {T<:FloatTypes} =
172-
add_fast(add_fast(x, y), zs...)
173-
mul_fast(x::T, y::T, zs::T...) where {T<:FloatTypes} =
174-
mul_fast(mul_fast(x, y), zs...)
175-
176172
@fastmath begin
177173
cmp_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(x==y, 0, ifelse(x<y, -1, +1))
178174
log_fast(b::T, x::T) where {T<:FloatTypes} = log_fast(x)/log_fast(b)
@@ -245,9 +241,6 @@ ComplexTypes = Union{ComplexF32, ComplexF64}
245241
max_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, y, x)
246242
min_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, x, y)
247243
minmax_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, (x,y), (y,x))
248-
249-
max_fast(x::T, y::T, z::T...) where {T<:FloatTypes} = max_fast(max_fast(x, y), z...)
250-
min_fast(x::T, y::T, z::T...) where {T<:FloatTypes} = min_fast(min_fast(x, y), z...)
251244
end
252245

253246
# fall-back implementations and type promotion
@@ -260,7 +253,7 @@ for op in (:abs, :abs2, :conj, :inv, :sign)
260253
end
261254
end
262255

263-
for op in (:+, :-, :*, :/, :(==), :!=, :<, :<=, :cmp, :rem, :min, :max, :minmax)
256+
for op in (:-, :/, :(==), :!=, :<, :<=, :cmp, :rem, :minmax)
264257
op_fast = fast_op[op]
265258
@eval begin
266259
# fall-back implementation for non-numeric types
@@ -273,6 +266,31 @@ for op in (:+, :-, :*, :/, :(==), :!=, :<, :<=, :cmp, :rem, :min, :max, :minmax)
273266
end
274267
end
275268

269+
for op in (:+, :*, :min, :max)
270+
op_fast = fast_op[op]
271+
@eval begin
272+
$op_fast(x) = $op(x)
273+
# fall-back implementation for non-numeric types
274+
$op_fast(x, y) = $op(x, y)
275+
# type promotion
276+
$op_fast(x::Number, y::Number) =
277+
$op_fast(promote(x,y)...)
278+
# fall-back implementation that applies after promotion
279+
$op_fast(x::T,y::T) where {T<:Number} = $op(x,y)
280+
# note: these definitions must not cause a dispatch loop when +(a,b) is
281+
# not defined, and must only try to call 2-argument definitions, so
282+
# that defining +(a,b) is sufficient for full functionality.
283+
($op_fast)(a, b, c, xs...) = (@inline; afoldl($op_fast, ($op_fast)(($op_fast)(a,b),c), xs...))
284+
# a further concern is that it's easy for a type like (Int,Int...)
285+
# to match many definitions, so we need to keep the number of
286+
# definitions down to avoid losing type information.
287+
# type promotion
288+
$op_fast(a::Number, b::Number, c::Number, xs::Number...) =
289+
$op_fast(promote(x,y,c,xs...)...)
290+
# fall-back implementation that applies after promotion
291+
$op_fast(a::T, b::T, c::T, xs::T...) where {T<:Number} = (@inline; afoldl($op_fast, ($op_fast)(($op_fast)(a,b),c), xs...))
292+
end
293+
end
276294

277295
# Math functions
278296
exp2_fast(x::Union{Float32,Float64}) = Base.Math.exp2_fast(x)

test/fastmath.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,30 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

3+
using InteractiveUtils: code_llvm
34
# fast math
45

6+
@testset "check fast present in LLVM" begin
7+
for T in (Float16, Float32, Float64, ComplexF32, ComplexF64)
8+
f(x) = @fastmath x + x + x
9+
llvm = sprint(code_llvm, f, (T,))
10+
@test occursin("fast", llvm)
11+
12+
g(x) = @fastmath x * x * x
13+
llvm = sprint(code_llvm, g, (T,))
14+
@test occursin("fast", llvm)
15+
end
16+
17+
for T in (Float16, Float32, Float64)
18+
f(x, y, z) = @fastmath min(x, y, z)
19+
llvm = sprint(code_llvm, f, (T,T,T))
20+
@test occursin("fast", llvm)
21+
22+
g(x, y, z) = @fastmath max(x, y, z)
23+
llvm = sprint(code_llvm, g, (T,T,T))
24+
@test occursin("fast", llvm)
25+
end
26+
end
27+
528
@testset "check expansions" begin
629
@test macroexpand(Main, :(@fastmath 1+2)) == :(Base.FastMath.add_fast(1,2))
730
@test macroexpand(Main, :(@fastmath +)) == :(Base.FastMath.add_fast)

0 commit comments

Comments
 (0)