Skip to content

Improvements to float intrinsics #531

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 4, 2025
Merged

Improvements to float intrinsics #531

merged 6 commits into from
Feb 4, 2025

Conversation

christiangnrd
Copy link
Member

@christiangnrd christiangnrd commented Jan 30, 2025

Replaces #529

Description of each commit in order:

  • tanpi has been in base Julia since 1.10, so switch from @device_function to @device_override.
  • Test all currently-defined float intrinsics and slightly refactor the old ones.
  • List the float intrinsics from the metal shading language in the tests.
  • Add intrinsics for clamp & sign
  • Add intrinsics for nextafter
  • Add intrinsics for 2-arg atan
  • Add intrinsics for 3-arg max & 'min'

Copy link
Contributor

github-actions bot commented Jan 30, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl
index 78f5d2db..b6aace77 100644
--- a/test/device/intrinsics.jl
+++ b/test/device/intrinsics.jl
@@ -175,191 +175,191 @@ MATH_INTR_FUNCS_3_ARG = [
 ]
 
 @testset "math" begin
-# 1-arg functions
-@testset "$(fun)()::$T" for fun in MATH_INTR_FUNCS_1_ARG, T in (Float32, Float16)
-    cpuarr = if fun in [log, log2, log10, Metal.rsqrt, sqrt]
-        rand(T, 4)
-    else
-        T[0.0, -0.0, rand(T), -rand(T)]
-    end
+    # 1-arg functions
+    @testset "$(fun)()::$T" for fun in MATH_INTR_FUNCS_1_ARG, T in (Float32, Float16)
+        cpuarr = if fun in [log, log2, log10, Metal.rsqrt, sqrt]
+            rand(T, 4)
+        else
+            T[0.0, -0.0, rand(T), -rand(T)]
+        end
 
-    mtlarr = MtlArray(cpuarr)
+        mtlarr = MtlArray(cpuarr)
 
-    mtlout = fill!(similar(mtlarr), 0)
+        mtlout = fill!(similar(mtlarr), 0)
 
-    function kernel(res, arr)
+        function kernel(res, arr)
         idx = thread_position_in_grid_1d()
-        res[idx] = fun(arr[idx])
+            res[idx] = fun(arr[idx])
         return nothing
     end
-    Metal.@sync @metal threads = length(mtlout) kernel(mtlout, mtlarr)
-    @eval @test Array($mtlout) ≈ $fun.($cpuarr)
-end
-# 2-arg functions
-@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_2_ARG
-    N = 4
-    arr1 = randn(T, N)
-    arr2 = randn(T, N)
-    mtlarr1 = MtlArray(arr1)
-    mtlarr2 = MtlArray(arr2)
+        Metal.@sync @metal threads = length(mtlout) kernel(mtlout, mtlarr)
+        @eval @test Array($mtlout) ≈ $fun.($cpuarr)
+    end
+    # 2-arg functions
+    @testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_2_ARG
+        N = 4
+        arr1 = randn(T, N)
+        arr2 = randn(T, N)
+        mtlarr1 = MtlArray(arr1)
+        mtlarr2 = MtlArray(arr2)
 
-    mtlout = fill!(similar(mtlarr1), 0)
+        mtlout = fill!(similar(mtlarr1), 0)
 
-    function kernel(res, x, y)
+        function kernel(res, x, y)
         idx = thread_position_in_grid_1d()
-        res[idx] = fun(x[idx], y[idx])
+            res[idx] = fun(x[idx], y[idx])
         return nothing
     end
-    Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
-    @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2)
-end
-# 3-arg functions
-@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_3_ARG
-    N = 4
-    arr1 = randn(T, N)
-    arr2 = randn(T, N)
-    arr3 = randn(T, N)
+        Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
+        @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2)
+    end
+    # 3-arg functions
+    @testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_3_ARG
+        N = 4
+        arr1 = randn(T, N)
+        arr2 = randn(T, N)
+        arr3 = randn(T, N)
 
-    mtlarr1 = MtlArray(arr1)
-    mtlarr2 = MtlArray(arr2)
-    mtlarr3 = MtlArray(arr3)
+        mtlarr1 = MtlArray(arr1)
+        mtlarr2 = MtlArray(arr2)
+        mtlarr3 = MtlArray(arr3)
 
-    mtlout = fill!(similar(mtlarr1), 0)
+        mtlout = fill!(similar(mtlarr1), 0)
 
-    function kernel(res, x, y, z)
+        function kernel(res, x, y, z)
         idx = thread_position_in_grid_1d()
-        res[idx] = fun(x[idx], y[idx], z[idx])
+            res[idx] = fun(x[idx], y[idx], z[idx])
         return nothing
     end
-    Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2, mtlarr3)
-    @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2, $arr3)
-end
+        Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2, mtlarr3)
+        @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2, $arr3)
+    end
 end
 
 @testset "unique math" begin
-@testset "$T" for T in (Float32, Float16)
-    let # acosh
-        arr = T[0, rand(T, 3)...] .+ T(1)
-        buffer = MtlArray(arr)
-        vec = acosh.(buffer)
-        @test Array(vec) ≈ acosh.(arr)
-    end
-
-    let # sincos
-        N = 4
-        arr = rand(T, N)
-        bufferA = MtlArray(arr)
-        bufferB = MtlArray(arr)
-        function intr_test3(arr_sin, arr_cos)
-            idx = thread_position_in_grid_1d()
-            sinres, cosres = sincos(arr_cos[idx])
-            arr_sin[idx] = sinres
-            arr_cos[idx] = cosres
-            return nothing
-        end
-        # Broken with Float16
-        if T == Float16
-            @test_broken Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
-        else
-            Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
-            @test Array(bufferA) ≈ sin.(arr)
-            @test Array(bufferB) ≈ cos.(arr)
+    @testset "$T" for T in (Float32, Float16)
+        let # acosh
+            arr = T[0, rand(T, 3)...] .+ T(1)
+            buffer = MtlArray(arr)
+            vec = acosh.(buffer)
+            @test Array(vec) ≈ acosh.(arr)
+        end
+
+        let # sincos
+            N = 4
+            arr = rand(T, N)
+            bufferA = MtlArray(arr)
+            bufferB = MtlArray(arr)
+            function intr_test3(arr_sin, arr_cos)
+                idx = thread_position_in_grid_1d()
+                sinres, cosres = sincos(arr_cos[idx])
+                arr_sin[idx] = sinres
+                arr_cos[idx] = cosres
+                return nothing
+            end
+            # Broken with Float16
+            if T == Float16
+                @test_broken Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
+            else
+                Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
+                @test Array(bufferA) ≈ sin.(arr)
+                @test Array(bufferB) ≈ cos.(arr)
+            end
         end
-    end
 
-    let # clamp
-        N = 4
-        in = randn(T, N)
-        minval = fill(T(-0.6), N)
-        maxval = fill(T(0.6), N)
+        let # clamp
+            N = 4
+            in = randn(T, N)
+            minval = fill(T(-0.6), N)
+            maxval = fill(T(0.6), N)
 
-        mtlin = MtlArray(in)
-        mtlminval = MtlArray(minval)
-        mtlmaxval = MtlArray(maxval)
+            mtlin = MtlArray(in)
+            mtlminval = MtlArray(minval)
+            mtlmaxval = MtlArray(maxval)
 
-        mtlout = fill!(similar(mtlin), 0)
+            mtlout = fill!(similar(mtlin), 0)
 
-        function kernel(res, x, y, z)
-            idx = thread_position_in_grid_1d()
-            res[idx] = clamp(x[idx], y[idx], z[idx])
-            return nothing
+            function kernel(res, x, y, z)
+                idx = thread_position_in_grid_1d()
+                res[idx] = clamp(x[idx], y[idx], z[idx])
+                return nothing
+            end
+            Metal.@sync @metal threads = N kernel(mtlout, mtlin, mtlminval, mtlmaxval)
+            @test Array(mtlout) == clamp.(in, minval, maxval)
         end
-        Metal.@sync @metal threads = N kernel(mtlout, mtlin, mtlminval, mtlmaxval)
-        @test Array(mtlout) == clamp.(in, minval, maxval)
-    end
 
-    let #pow
-        N = 4
-        arr1 = rand(T, N)
-        arr2 = rand(T, N)
-        mtlarr1 = MtlArray(arr1)
-        mtlarr2 = MtlArray(arr2)
+        let #pow
+            N = 4
+            arr1 = rand(T, N)
+            arr2 = rand(T, N)
+            mtlarr1 = MtlArray(arr1)
+            mtlarr2 = MtlArray(arr2)
 
-        mtlout = fill!(similar(mtlarr1), 0)
+            mtlout = fill!(similar(mtlarr1), 0)
 
-        function kernel(res, x, y)
-            idx = thread_position_in_grid_1d()
-            res[idx] = x[idx]^y[idx]
-            return nothing
+            function kernel(res, x, y)
+                idx = thread_position_in_grid_1d()
+                res[idx] = x[idx]^y[idx]
+                return nothing
+            end
+            Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
+            @test Array(mtlout) ≈ arr1 .^ arr2
         end
-        Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
-        @test Array(mtlout) ≈ arr1 .^ arr2
-    end
 
-    let #powr
-        N = 4
-        arr1 = rand(T, N)
-        arr2 = rand(T, N)
-        mtlarr1 = MtlArray(arr1)
-        mtlarr2 = MtlArray(arr2)
+        let #powr
+            N = 4
+            arr1 = rand(T, N)
+            arr2 = rand(T, N)
+            mtlarr1 = MtlArray(arr1)
+            mtlarr2 = MtlArray(arr2)
 
-        mtlout = fill!(similar(mtlarr1), 0)
+            mtlout = fill!(similar(mtlarr1), 0)
 
-        function kernel(res, x, y)
-            idx = thread_position_in_grid_1d()
-            res[idx] = Metal.powr(x[idx], y[idx])
-            return nothing
+            function kernel(res, x, y)
+                idx = thread_position_in_grid_1d()
+                res[idx] = Metal.powr(x[idx], y[idx])
+                return nothing
+            end
+            Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
+            @test Array(mtlout) ≈ arr1 .^ arr2
         end
-        Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
-        @test Array(mtlout) ≈ arr1 .^ arr2
-    end
 
-    let # log1p
-        arr = collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20))
-        buffer = MtlArray(arr)
-        vec = Array(log1p.(buffer))
-        @test vec ≈ log1p.(arr)
-    end
+        let # log1p
+            arr = collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20))
+            buffer = MtlArray(arr)
+            vec = Array(log1p.(buffer))
+            @test vec ≈ log1p.(arr)
+        end
 
-    let # erf
-        arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
-        buffer = MtlArray(arr)
-        vec = Array(SpecialFunctions.erf.(buffer))
-        @test vec ≈ SpecialFunctions.erf.(arr)
-    end
+        let # erf
+            arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
+            buffer = MtlArray(arr)
+            vec = Array(SpecialFunctions.erf.(buffer))
+            @test vec ≈ SpecialFunctions.erf.(arr)
+        end
 
-    let # erfc
-        arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
-        buffer = MtlArray(arr)
-        vec = Array(SpecialFunctions.erfc.(buffer))
-        @test vec ≈ SpecialFunctions.erfc.(arr)
-    end
+        let # erfc
+            arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
+            buffer = MtlArray(arr)
+            vec = Array(SpecialFunctions.erfc.(buffer))
+            @test vec ≈ SpecialFunctions.erfc.(arr)
+        end
 
-    let # erfinv
-        arr = collect(LinRange(-1.0f0, 1.0f0, 20))
-        buffer = MtlArray(arr)
-        vec = Array(SpecialFunctions.erfinv.(buffer))
-        @test vec ≈ SpecialFunctions.erfinv.(arr)
-    end
+        let # erfinv
+            arr = collect(LinRange(-1.0f0, 1.0f0, 20))
+            buffer = MtlArray(arr)
+            vec = Array(SpecialFunctions.erfinv.(buffer))
+            @test vec ≈ SpecialFunctions.erfinv.(arr)
+        end
 
-    let # expm1
-        arr = collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100))
-        buffer = MtlArray(arr)
-        vec = Array(expm1.(buffer))
-        @test vec ≈ expm1.(arr)
+        let # expm1
+            arr = collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100))
+            buffer = MtlArray(arr)
+            vec = Array(expm1.(buffer))
+            @test vec ≈ expm1.(arr)
+        end
     end
 end
-end
 
 ############################################################################################
 

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Metal Benchmarks

Benchmark suite Current: 024b990 Previous: ca092c8 Ratio
private array/construct 24663.166666666668 ns 24829.916666666664 ns 0.99
private array/broadcast 460750 ns 458500 ns 1.00
private array/random/randn/Float32 799000 ns 798750 ns 1.00
private array/random/randn!/Float32 644458 ns 615041.5 ns 1.05
private array/random/rand!/Int64 573083 ns 563000 ns 1.02
private array/random/rand!/Float32 597167 ns 598021 ns 1.00
private array/random/rand/Int64 768125 ns 774083 ns 0.99
private array/random/rand/Float32 599250 ns 611583 ns 0.98
private array/copyto!/gpu_to_gpu 591396 ns 654250 ns 0.90
private array/copyto!/cpu_to_gpu 817042 ns 624208 ns 1.31
private array/copyto!/gpu_to_cpu 584041.5 ns 817708 ns 0.71
private array/accumulate/1d 1337813 ns 1329333 ns 1.01
private array/accumulate/2d 1390083 ns 1398375 ns 0.99
private array/iteration/findall/int 2066083.5 ns 2103583.5 ns 0.98
private array/iteration/findall/bool 1837604.5 ns 1824375 ns 1.01
private array/iteration/findfirst/int 1700958 ns 1688792 ns 1.01
private array/iteration/findfirst/bool 1668792 ns 1643000 ns 1.02
private array/iteration/scalar 3888833 ns 3772458 ns 1.03
private array/iteration/logical 3191625 ns 3187749.5 ns 1.00
private array/iteration/findmin/1d 1766208 ns 1760708 ns 1.00
private array/iteration/findmin/2d 1351583 ns 1344437.5 ns 1.01
private array/reductions/reduce/1d 1018542 ns 1031583 ns 0.99
private array/reductions/reduce/2d 663667 ns 654750 ns 1.01
private array/reductions/mapreduce/1d 1046750 ns 1033875 ns 1.01
private array/reductions/mapreduce/2d 657792 ns 659000 ns 1.00
private array/permutedims/4d 2536021 ns 2503500 ns 1.01
private array/permutedims/2d 1003562.5 ns 1028750 ns 0.98
private array/permutedims/3d 1579250 ns 1580708 ns 1.00
private array/copy 608791 ns 590270.5 ns 1.03
latency/precompile 8824701958 ns 8811389416 ns 1.00
latency/ttfp 3604612541 ns 3608628500 ns 1.00
latency/import 1232708333 ns 1231898292 ns 1.00
integration/metaldevrt 707708.5 ns 713792 ns 0.99
integration/byval/slices=1 1599020.5 ns 1617854.5 ns 0.99
integration/byval/slices=3 8769500 ns 9687812.5 ns 0.91
integration/byval/reference 1584937.5 ns 1589625 ns 1.00
integration/byval/slices=2 2688125 ns 2675542 ns 1.00
kernel/indexing 447541.5 ns 470792 ns 0.95
kernel/indexing_checked 458083 ns 463208 ns 0.99
kernel/launch 9805.333333333334 ns 9527.666666666666 ns 1.03
metal/synchronization/stream 14750 ns 15125 ns 0.98
metal/synchronization/context 15167 ns 14834 ns 1.02
shared array/construct 27840.333333333332 ns 24604.166666666668 ns 1.13
shared array/broadcast 462041 ns 461166 ns 1.00
shared array/random/randn/Float32 810250 ns 738958.5 ns 1.10
shared array/random/randn!/Float32 629750 ns 633292 ns 0.99
shared array/random/rand!/Int64 569854.5 ns 561625 ns 1.01
shared array/random/rand!/Float32 603958 ns 600416 ns 1.01
shared array/random/rand/Int64 777520.5 ns 778375 ns 1.00
shared array/random/rand/Float32 598833 ns 616000 ns 0.97
shared array/copyto!/gpu_to_gpu 79375 ns 79250 ns 1.00
shared array/copyto!/cpu_to_gpu 79666.5 ns 82084 ns 0.97
shared array/copyto!/gpu_to_cpu 79959 ns 82750 ns 0.97
shared array/accumulate/1d 1337458 ns 1335833 ns 1.00
shared array/accumulate/2d 1375541.5 ns 1388833 ns 0.99
shared array/iteration/findall/int 1819666.5 ns 1871833 ns 0.97
shared array/iteration/findall/bool 1554229.5 ns 1569500 ns 0.99
shared array/iteration/findfirst/int 1397958 ns 1396916 ns 1.00
shared array/iteration/findfirst/bool 1370833 ns 1367500 ns 1.00
shared array/iteration/scalar 158666 ns 154834 ns 1.02
shared array/iteration/logical 2981104 ns 2987020.5 ns 1.00
shared array/iteration/findmin/1d 1468459 ns 1477062.5 ns 0.99
shared array/iteration/findmin/2d 1366791.5 ns 1364708 ns 1.00
shared array/reductions/reduce/1d 723541.5 ns 731750 ns 0.99
shared array/reductions/reduce/2d 670166 ns 666250 ns 1.01
shared array/reductions/mapreduce/1d 743375 ns 736667 ns 1.01
shared array/reductions/mapreduce/2d 664542 ns 672459 ns 0.99
shared array/permutedims/4d 2543896 ns 2493333 ns 1.02
shared array/permutedims/2d 1021646 ns 1024646 ns 1.00
shared array/permutedims/3d 1576084 ns 1576667 ns 1.00
shared array/copy 244291.5 ns 244000 ns 1.00

This comment was automatically generated by workflow using github-action-benchmark.


@device_function trunc_fast(x::Float32) = ccall("extern air.fast_trunc.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.trunc(x::Float32) = ccall("extern air.trunc.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.trunc(x::Float16) = ccall("extern air.trunc.f16", llvmcall, Float16, (Float16,), x)

@static if Metal.is_macos(v"14")
Copy link
Member

@maleadt maleadt Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as on the other PR wrt. tying this to the macOS version.

If anything, we can punt on the intrinsic itself being version-specific and just ensuring the tests are. Maybe file an issue on GPUCompiler for compile-time error reporting from user code.

Copy link
Member Author

@christiangnrd christiangnrd Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. In that case if the rest is good I’ll take the nextafter commit out of this PR and rebase #529 on this PR so that we can merge the new tests and fixes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's been done so I'll merge once tests pass.

@christiangnrd christiangnrd force-pushed the intrinsics branch 2 times, most recently from ad73350 to 0a78ec7 Compare February 3, 2025 19:06
@christiangnrd christiangnrd force-pushed the intrinsics branch 2 times, most recently from 5c22476 to 404d23a Compare February 4, 2025 04:01
@christiangnrd christiangnrd merged commit b8ab3b6 into main Feb 4, 2025
6 of 7 checks passed
@christiangnrd christiangnrd deleted the intrinsics branch February 4, 2025 04:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants