Skip to content

Commit bb6e7fa

Browse files
committed
Working code
1 parent f4b4a58 commit bb6e7fa

File tree

2 files changed

+14
-16
lines changed

2 files changed

+14
-16
lines changed

src/device/intrinsics/math.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,14 +298,14 @@ end
298298
if metal_version() >= sv"3.1" # macOS 14+
299299
ccall("extern air.nextafter.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
300300
else
301-
reinterpret(Float32, reinterpret(UInt32, x) + sign(y-x))
301+
nextfloat(x, unsafe_trunc(Int32, sign(y - x)))
302302
end
303303
end
304304
@device_function function nextafter(x::Float16, y::Float16)
305305
if metal_version() >= sv"3.1" # macOS 14+
306306
ccall("extern air.nextafter.f16", llvmcall, Float16, (Float16, Float16), x, y)
307307
else
308-
reinterpret(Float16, reinterpret(UInt16, x) + sign(y-x))
308+
nextfloat(x, unsafe_trunc(Int16, sign(y - x)))
309309
end
310310
end
311311

test/device/intrinsics.jl

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -355,21 +355,19 @@ end
355355

356356

357357
let # nextafter
358-
if Metal.is_macos(v"14")
359-
N = 4
360-
function nextafter_test(X, y)
361-
idx = thread_position_in_grid_1d()
362-
X[idx] = Metal.nextafter(X[idx], y)
363-
return nothing
364-
end
365-
arr = rand(T, N)
366-
buffer = MtlArray(arr)
367-
Metal.@sync @metal threads = N nextafter_test(buffer, typemax(T))
368-
@test Array(buffer) == nextfloat.(arr)
369-
370-
Metal.@sync @metal threads = N nextafter_test(buffer, typemin(T))
371-
@test Array(buffer) == arr
358+
N = 4
359+
function nextafter_test(X, y)
360+
idx = thread_position_in_grid_1d()
361+
X[idx] = Metal.nextafter(X[idx], y)
362+
return nothing
372363
end
364+
arr = rand(T, N)
365+
buffer = MtlArray(arr)
366+
Metal.@sync @metal threads = N nextafter_test(buffer, typemax(T))
367+
@test Array(buffer) == nextfloat.(arr)
368+
369+
Metal.@sync @metal threads = N nextafter_test(buffer, typemin(T))
370+
@test Array(buffer) == arr
373371
end
374372
end
375373
end

0 commit comments

Comments
 (0)