Improvements to float intrinsics #531
Conversation
Your PR requires formatting changes to meet the project's style guidelines. The suggested changes:

diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl
index 78f5d2db..b6aace77 100644
--- a/test/device/intrinsics.jl
+++ b/test/device/intrinsics.jl
@@ -175,191 +175,191 @@ MATH_INTR_FUNCS_3_ARG = [
]
@testset "math" begin
-# 1-arg functions
-@testset "$(fun)()::$T" for fun in MATH_INTR_FUNCS_1_ARG, T in (Float32, Float16)
- cpuarr = if fun in [log, log2, log10, Metal.rsqrt, sqrt]
- rand(T, 4)
- else
- T[0.0, -0.0, rand(T), -rand(T)]
- end
+ # 1-arg functions
+ @testset "$(fun)()::$T" for fun in MATH_INTR_FUNCS_1_ARG, T in (Float32, Float16)
+ cpuarr = if fun in [log, log2, log10, Metal.rsqrt, sqrt]
+ rand(T, 4)
+ else
+ T[0.0, -0.0, rand(T), -rand(T)]
+ end
- mtlarr = MtlArray(cpuarr)
+ mtlarr = MtlArray(cpuarr)
- mtlout = fill!(similar(mtlarr), 0)
+ mtlout = fill!(similar(mtlarr), 0)
- function kernel(res, arr)
+ function kernel(res, arr)
idx = thread_position_in_grid_1d()
- res[idx] = fun(arr[idx])
+ res[idx] = fun(arr[idx])
return nothing
end
- Metal.@sync @metal threads = length(mtlout) kernel(mtlout, mtlarr)
- @eval @test Array($mtlout) ≈ $fun.($cpuarr)
-end
-# 2-arg functions
-@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_2_ARG
- N = 4
- arr1 = randn(T, N)
- arr2 = randn(T, N)
- mtlarr1 = MtlArray(arr1)
- mtlarr2 = MtlArray(arr2)
+ Metal.@sync @metal threads = length(mtlout) kernel(mtlout, mtlarr)
+ @eval @test Array($mtlout) ≈ $fun.($cpuarr)
+ end
+ # 2-arg functions
+ @testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_2_ARG
+ N = 4
+ arr1 = randn(T, N)
+ arr2 = randn(T, N)
+ mtlarr1 = MtlArray(arr1)
+ mtlarr2 = MtlArray(arr2)
- mtlout = fill!(similar(mtlarr1), 0)
+ mtlout = fill!(similar(mtlarr1), 0)
- function kernel(res, x, y)
+ function kernel(res, x, y)
idx = thread_position_in_grid_1d()
- res[idx] = fun(x[idx], y[idx])
+ res[idx] = fun(x[idx], y[idx])
return nothing
end
- Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
- @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2)
-end
-# 3-arg functions
-@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_3_ARG
- N = 4
- arr1 = randn(T, N)
- arr2 = randn(T, N)
- arr3 = randn(T, N)
+ Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
+ @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2)
+ end
+ # 3-arg functions
+ @testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_3_ARG
+ N = 4
+ arr1 = randn(T, N)
+ arr2 = randn(T, N)
+ arr3 = randn(T, N)
- mtlarr1 = MtlArray(arr1)
- mtlarr2 = MtlArray(arr2)
- mtlarr3 = MtlArray(arr3)
+ mtlarr1 = MtlArray(arr1)
+ mtlarr2 = MtlArray(arr2)
+ mtlarr3 = MtlArray(arr3)
- mtlout = fill!(similar(mtlarr1), 0)
+ mtlout = fill!(similar(mtlarr1), 0)
- function kernel(res, x, y, z)
+ function kernel(res, x, y, z)
idx = thread_position_in_grid_1d()
- res[idx] = fun(x[idx], y[idx], z[idx])
+ res[idx] = fun(x[idx], y[idx], z[idx])
return nothing
end
- Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2, mtlarr3)
- @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2, $arr3)
-end
+ Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2, mtlarr3)
+ @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2, $arr3)
+ end
end
@testset "unique math" begin
-@testset "$T" for T in (Float32, Float16)
- let # acosh
- arr = T[0, rand(T, 3)...] .+ T(1)
- buffer = MtlArray(arr)
- vec = acosh.(buffer)
- @test Array(vec) ≈ acosh.(arr)
- end
-
- let # sincos
- N = 4
- arr = rand(T, N)
- bufferA = MtlArray(arr)
- bufferB = MtlArray(arr)
- function intr_test3(arr_sin, arr_cos)
- idx = thread_position_in_grid_1d()
- sinres, cosres = sincos(arr_cos[idx])
- arr_sin[idx] = sinres
- arr_cos[idx] = cosres
- return nothing
- end
- # Broken with Float16
- if T == Float16
- @test_broken Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
- else
- Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
- @test Array(bufferA) ≈ sin.(arr)
- @test Array(bufferB) ≈ cos.(arr)
+ @testset "$T" for T in (Float32, Float16)
+ let # acosh
+ arr = T[0, rand(T, 3)...] .+ T(1)
+ buffer = MtlArray(arr)
+ vec = acosh.(buffer)
+ @test Array(vec) ≈ acosh.(arr)
+ end
+
+ let # sincos
+ N = 4
+ arr = rand(T, N)
+ bufferA = MtlArray(arr)
+ bufferB = MtlArray(arr)
+ function intr_test3(arr_sin, arr_cos)
+ idx = thread_position_in_grid_1d()
+ sinres, cosres = sincos(arr_cos[idx])
+ arr_sin[idx] = sinres
+ arr_cos[idx] = cosres
+ return nothing
+ end
+ # Broken with Float16
+ if T == Float16
+ @test_broken Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
+ else
+ Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
+ @test Array(bufferA) ≈ sin.(arr)
+ @test Array(bufferB) ≈ cos.(arr)
+ end
end
- end
- let # clamp
- N = 4
- in = randn(T, N)
- minval = fill(T(-0.6), N)
- maxval = fill(T(0.6), N)
+ let # clamp
+ N = 4
+ in = randn(T, N)
+ minval = fill(T(-0.6), N)
+ maxval = fill(T(0.6), N)
- mtlin = MtlArray(in)
- mtlminval = MtlArray(minval)
- mtlmaxval = MtlArray(maxval)
+ mtlin = MtlArray(in)
+ mtlminval = MtlArray(minval)
+ mtlmaxval = MtlArray(maxval)
- mtlout = fill!(similar(mtlin), 0)
+ mtlout = fill!(similar(mtlin), 0)
- function kernel(res, x, y, z)
- idx = thread_position_in_grid_1d()
- res[idx] = clamp(x[idx], y[idx], z[idx])
- return nothing
+ function kernel(res, x, y, z)
+ idx = thread_position_in_grid_1d()
+ res[idx] = clamp(x[idx], y[idx], z[idx])
+ return nothing
+ end
+ Metal.@sync @metal threads = N kernel(mtlout, mtlin, mtlminval, mtlmaxval)
+ @test Array(mtlout) == clamp.(in, minval, maxval)
end
- Metal.@sync @metal threads = N kernel(mtlout, mtlin, mtlminval, mtlmaxval)
- @test Array(mtlout) == clamp.(in, minval, maxval)
- end
- let #pow
- N = 4
- arr1 = rand(T, N)
- arr2 = rand(T, N)
- mtlarr1 = MtlArray(arr1)
- mtlarr2 = MtlArray(arr2)
+ let #pow
+ N = 4
+ arr1 = rand(T, N)
+ arr2 = rand(T, N)
+ mtlarr1 = MtlArray(arr1)
+ mtlarr2 = MtlArray(arr2)
- mtlout = fill!(similar(mtlarr1), 0)
+ mtlout = fill!(similar(mtlarr1), 0)
- function kernel(res, x, y)
- idx = thread_position_in_grid_1d()
- res[idx] = x[idx]^y[idx]
- return nothing
+ function kernel(res, x, y)
+ idx = thread_position_in_grid_1d()
+ res[idx] = x[idx]^y[idx]
+ return nothing
+ end
+ Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
+ @test Array(mtlout) ≈ arr1 .^ arr2
end
- Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
- @test Array(mtlout) ≈ arr1 .^ arr2
- end
- let #powr
- N = 4
- arr1 = rand(T, N)
- arr2 = rand(T, N)
- mtlarr1 = MtlArray(arr1)
- mtlarr2 = MtlArray(arr2)
+ let #powr
+ N = 4
+ arr1 = rand(T, N)
+ arr2 = rand(T, N)
+ mtlarr1 = MtlArray(arr1)
+ mtlarr2 = MtlArray(arr2)
- mtlout = fill!(similar(mtlarr1), 0)
+ mtlout = fill!(similar(mtlarr1), 0)
- function kernel(res, x, y)
- idx = thread_position_in_grid_1d()
- res[idx] = Metal.powr(x[idx], y[idx])
- return nothing
+ function kernel(res, x, y)
+ idx = thread_position_in_grid_1d()
+ res[idx] = Metal.powr(x[idx], y[idx])
+ return nothing
+ end
+ Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
+ @test Array(mtlout) ≈ arr1 .^ arr2
end
- Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
- @test Array(mtlout) ≈ arr1 .^ arr2
- end
- let # log1p
- arr = collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20))
- buffer = MtlArray(arr)
- vec = Array(log1p.(buffer))
- @test vec ≈ log1p.(arr)
- end
+ let # log1p
+ arr = collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20))
+ buffer = MtlArray(arr)
+ vec = Array(log1p.(buffer))
+ @test vec ≈ log1p.(arr)
+ end
- let # erf
- arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
- buffer = MtlArray(arr)
- vec = Array(SpecialFunctions.erf.(buffer))
- @test vec ≈ SpecialFunctions.erf.(arr)
- end
+ let # erf
+ arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
+ buffer = MtlArray(arr)
+ vec = Array(SpecialFunctions.erf.(buffer))
+ @test vec ≈ SpecialFunctions.erf.(arr)
+ end
- let # erfc
- arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
- buffer = MtlArray(arr)
- vec = Array(SpecialFunctions.erfc.(buffer))
- @test vec ≈ SpecialFunctions.erfc.(arr)
- end
+ let # erfc
+ arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
+ buffer = MtlArray(arr)
+ vec = Array(SpecialFunctions.erfc.(buffer))
+ @test vec ≈ SpecialFunctions.erfc.(arr)
+ end
- let # erfinv
- arr = collect(LinRange(-1.0f0, 1.0f0, 20))
- buffer = MtlArray(arr)
- vec = Array(SpecialFunctions.erfinv.(buffer))
- @test vec ≈ SpecialFunctions.erfinv.(arr)
- end
+ let # erfinv
+ arr = collect(LinRange(-1.0f0, 1.0f0, 20))
+ buffer = MtlArray(arr)
+ vec = Array(SpecialFunctions.erfinv.(buffer))
+ @test vec ≈ SpecialFunctions.erfinv.(arr)
+ end
- let # expm1
- arr = collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100))
- buffer = MtlArray(arr)
- vec = Array(expm1.(buffer))
- @test vec ≈ expm1.(arr)
+ let # expm1
+ arr = collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100))
+ buffer = MtlArray(arr)
+ vec = Array(expm1.(buffer))
+ @test vec ≈ expm1.(arr)
+ end
end
end
-end
############################################################################################
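For reference, the suggested formatting can usually be reproduced locally before pushing (a minimal sketch, assuming the project uses JuliaFormatter with its repository-level configuration, as the suggested diff implies):

```julia
# Run from the repository root; picks up the repo's JuliaFormatter settings automatically.
using JuliaFormatter
format(".")   # rewrites files in place according to the project style
```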
Metal Benchmarks
Benchmark suite | Current: 024b990 | Previous: ca092c8 | Ratio |
---|---|---|---|
private array/construct | 24663.166666666668 ns | 24829.916666666664 ns | 0.99 |
private array/broadcast | 460750 ns | 458500 ns | 1.00 |
private array/random/randn/Float32 | 799000 ns | 798750 ns | 1.00 |
private array/random/randn!/Float32 | 644458 ns | 615041.5 ns | 1.05 |
private array/random/rand!/Int64 | 573083 ns | 563000 ns | 1.02 |
private array/random/rand!/Float32 | 597167 ns | 598021 ns | 1.00 |
private array/random/rand/Int64 | 768125 ns | 774083 ns | 0.99 |
private array/random/rand/Float32 | 599250 ns | 611583 ns | 0.98 |
private array/copyto!/gpu_to_gpu | 591396 ns | 654250 ns | 0.90 |
private array/copyto!/cpu_to_gpu | 817042 ns | 624208 ns | 1.31 |
private array/copyto!/gpu_to_cpu | 584041.5 ns | 817708 ns | 0.71 |
private array/accumulate/1d | 1337813 ns | 1329333 ns | 1.01 |
private array/accumulate/2d | 1390083 ns | 1398375 ns | 0.99 |
private array/iteration/findall/int | 2066083.5 ns | 2103583.5 ns | 0.98 |
private array/iteration/findall/bool | 1837604.5 ns | 1824375 ns | 1.01 |
private array/iteration/findfirst/int | 1700958 ns | 1688792 ns | 1.01 |
private array/iteration/findfirst/bool | 1668792 ns | 1643000 ns | 1.02 |
private array/iteration/scalar | 3888833 ns | 3772458 ns | 1.03 |
private array/iteration/logical | 3191625 ns | 3187749.5 ns | 1.00 |
private array/iteration/findmin/1d | 1766208 ns | 1760708 ns | 1.00 |
private array/iteration/findmin/2d | 1351583 ns | 1344437.5 ns | 1.01 |
private array/reductions/reduce/1d | 1018542 ns | 1031583 ns | 0.99 |
private array/reductions/reduce/2d | 663667 ns | 654750 ns | 1.01 |
private array/reductions/mapreduce/1d | 1046750 ns | 1033875 ns | 1.01 |
private array/reductions/mapreduce/2d | 657792 ns | 659000 ns | 1.00 |
private array/permutedims/4d | 2536021 ns | 2503500 ns | 1.01 |
private array/permutedims/2d | 1003562.5 ns | 1028750 ns | 0.98 |
private array/permutedims/3d | 1579250 ns | 1580708 ns | 1.00 |
private array/copy | 608791 ns | 590270.5 ns | 1.03 |
latency/precompile | 8824701958 ns | 8811389416 ns | 1.00 |
latency/ttfp | 3604612541 ns | 3608628500 ns | 1.00 |
latency/import | 1232708333 ns | 1231898292 ns | 1.00 |
integration/metaldevrt | 707708.5 ns | 713792 ns | 0.99 |
integration/byval/slices=1 | 1599020.5 ns | 1617854.5 ns | 0.99 |
integration/byval/slices=3 | 8769500 ns | 9687812.5 ns | 0.91 |
integration/byval/reference | 1584937.5 ns | 1589625 ns | 1.00 |
integration/byval/slices=2 | 2688125 ns | 2675542 ns | 1.00 |
kernel/indexing | 447541.5 ns | 470792 ns | 0.95 |
kernel/indexing_checked | 458083 ns | 463208 ns | 0.99 |
kernel/launch | 9805.333333333334 ns | 9527.666666666666 ns | 1.03 |
metal/synchronization/stream | 14750 ns | 15125 ns | 0.98 |
metal/synchronization/context | 15167 ns | 14834 ns | 1.02 |
shared array/construct | 27840.333333333332 ns | 24604.166666666668 ns | 1.13 |
shared array/broadcast | 462041 ns | 461166 ns | 1.00 |
shared array/random/randn/Float32 | 810250 ns | 738958.5 ns | 1.10 |
shared array/random/randn!/Float32 | 629750 ns | 633292 ns | 0.99 |
shared array/random/rand!/Int64 | 569854.5 ns | 561625 ns | 1.01 |
shared array/random/rand!/Float32 | 603958 ns | 600416 ns | 1.01 |
shared array/random/rand/Int64 | 777520.5 ns | 778375 ns | 1.00 |
shared array/random/rand/Float32 | 598833 ns | 616000 ns | 0.97 |
shared array/copyto!/gpu_to_gpu | 79375 ns | 79250 ns | 1.00 |
shared array/copyto!/cpu_to_gpu | 79666.5 ns | 82084 ns | 0.97 |
shared array/copyto!/gpu_to_cpu | 79959 ns | 82750 ns | 0.97 |
shared array/accumulate/1d | 1337458 ns | 1335833 ns | 1.00 |
shared array/accumulate/2d | 1375541.5 ns | 1388833 ns | 0.99 |
shared array/iteration/findall/int | 1819666.5 ns | 1871833 ns | 0.97 |
shared array/iteration/findall/bool | 1554229.5 ns | 1569500 ns | 0.99 |
shared array/iteration/findfirst/int | 1397958 ns | 1396916 ns | 1.00 |
shared array/iteration/findfirst/bool | 1370833 ns | 1367500 ns | 1.00 |
shared array/iteration/scalar | 158666 ns | 154834 ns | 1.02 |
shared array/iteration/logical | 2981104 ns | 2987020.5 ns | 1.00 |
shared array/iteration/findmin/1d | 1468459 ns | 1477062.5 ns | 0.99 |
shared array/iteration/findmin/2d | 1366791.5 ns | 1364708 ns | 1.00 |
shared array/reductions/reduce/1d | 723541.5 ns | 731750 ns | 0.99 |
shared array/reductions/reduce/2d | 670166 ns | 666250 ns | 1.01 |
shared array/reductions/mapreduce/1d | 743375 ns | 736667 ns | 1.01 |
shared array/reductions/mapreduce/2d | 664542 ns | 672459 ns | 0.99 |
shared array/permutedims/4d | 2543896 ns | 2493333 ns | 1.02 |
shared array/permutedims/2d | 1021646 ns | 1024646 ns | 1.00 |
shared array/permutedims/3d | 1576084 ns | 1576667 ns | 1.00 |
shared array/copy | 244291.5 ns | 244000 ns | 1.00 |
This comment was automatically generated by workflow using github-action-benchmark.
src/device/intrinsics/math.jl (outdated diff)

@device_function trunc_fast(x::Float32) = ccall("extern air.fast_trunc.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.trunc(x::Float32) = ccall("extern air.trunc.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.trunc(x::Float16) = ccall("extern air.trunc.f16", llvmcall, Float16, (Float16,), x)

@static if Metal.is_macos(v"14")
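For context, a minimal sketch of how one of these overrides would be exercised from device code, in the same style as the tests above (hypothetical snippet, not part of this PR):

```julia
using Metal, Test

function trunc_kernel(res, arr)
    idx = thread_position_in_grid_1d()
    res[idx] = trunc(arr[idx])  # dispatches to the air.trunc.* intrinsic on device
    return nothing
end

arr = Float32[1.7, -2.3, 0.5, -0.5]
mtlarr = MtlArray(arr)
mtlout = fill!(similar(mtlarr), 0)
Metal.@sync @metal threads = length(mtlout) trunc_kernel(mtlout, mtlarr)
@test Array(mtlout) == trunc.(arr)
```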
Same comment as on the other PR wrt. tying this to the macOS version.
If anything, we can punt on the intrinsic itself being version-specific and just ensure that the tests are. Maybe file an issue on GPUCompiler for compile-time error reporting from user code.
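For illustration, a minimal sketch of keeping the intrinsic unconditional while gating only its tests on the OS version (hypothetical test code; the guarded testset is just a placeholder):

```julia
# Sketch: the intrinsic stays defined unconditionally in src/device/intrinsics/math.jl;
# only the tests that depend on newer OS behaviour are guarded.
if Metal.is_macos(v"14")
    @testset "nextafter" begin
        # ... tests exercising the macOS 14+ behaviour go here ...
    end
end
```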
Sounds good. In that case, if the rest is good, I'll take the `nextafter` commit out of this PR and rebase #529 on this PR so that we can merge the new tests and fixes.
That's been done so I'll merge once tests pass.
Force-pushed ad73350 to 0a78ec7

`tanpi` is in Julia since 1.10 so all supported versions have it. Also clean up the different tests.

Force-pushed 5c22476 to 404d23a

Force-pushed 404d23a to 024b990
Replaces #529

Description of each commit, in order:

- `tanpi` has been in base Julia since 1.10, so switch from `@device_function` to `@device_override` (see the sketch below).
- `clamp` & `sign`
- `nextafter`
- `atan`
- `max` & `min`
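A minimal sketch of the `tanpi` change described in the first item (simplified to Float32; the exact AIR symbol name and the real definitions in src/device/intrinsics/math.jl are assumed here, not quoted from the PR):

```julia
# Before: exposed only as a Metal-specific device function.
@device_function tanpi(x::Float32) =
    ccall("extern air.tanpi.f32", llvmcall, Cfloat, (Cfloat,), x)

# After: override Base.tanpi (in Julia since 1.10), so a plain `tanpi(x)`
# inside a kernel hits the hardware intrinsic.
@device_override Base.tanpi(x::Float32) =
    ccall("extern air.tanpi.f32", llvmcall, Cfloat, (Cfloat,), x)
```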