Skip to content

Commit 8ed06bf

Browse files
authored
Metal: fix llvm.minimum lowering. (#602)
1 parent 28d96c1 commit 8ed06bf

File tree

1 file changed

+9 -3 lines changed

1 file changed

+9 -3 lines changed

src/metal.jl

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -945,6 +945,7 @@ function lower_llvm_intrinsics!(@nospecialize(job::CompilerJob), fun::LLVM.Funct
945945
# IEEE 754-2018 compliant maximum/minimum, propagating NaNs and treating -0 as less than +0
946946
if intr == LLVM.Intrinsic("llvm.minimum") || intr == LLVM.Intrinsic("llvm.maximum")
947947
typ = value_type(call)
948+
is_minimum = intr == LLVM.Intrinsic("llvm.minimum")
948949

949950
# XXX: LLVM C API doesn't have getPrimitiveSizeInBits
950951
jltyp = if typ == LLVM.HalfType()
@@ -959,7 +960,12 @@ function lower_llvm_intrinsics!(@nospecialize(job::CompilerJob), fun::LLVM.Funct
959960

960961
# create a function that performs the IEEE-compliant operation.
961962
# normally we'd do this inline, but LLVM.jl doesn't have BB split functionality.
962-
new_intr_fn = "air.minimum.f$(8*sizeof(jltyp))"
963+
new_intr_fn = if is_minimum
964+
"air.minimum.f$(8*sizeof(jltyp))"
965+
else
966+
"air.maximum.f$(8*sizeof(jltyp))"
967+
end
968+
963969
if haskey(functions(mod), new_intr_fn)
964970
new_intr = functions(mod)[new_intr_fn]
965971
else
@@ -1017,7 +1023,7 @@ function lower_llvm_intrinsics!(@nospecialize(job::CompilerJob), fun::LLVM.Funct
10171023
position!(builder, bb_compare_zero)
10181024
arg0_negative = icmp!(builder, LLVM.API.LLVMIntNE, arg0_sign,
10191025
LLVM.ConstantInt(typ′, 0))
1020-
val = if intr == LLVM.Intrinsic("llvm.minimum")
1026+
val = if is_minimum
10211027
select!(builder, arg0_negative, arg0, arg1)
10221028
else
10231029
select!(builder, arg0_negative, arg1, arg0)
@@ -1027,7 +1033,7 @@ function lower_llvm_intrinsics!(@nospecialize(job::CompilerJob), fun::LLVM.Funct
10271033
# finally, it's safe to use the existing minnum/maxnum intrinsics
10281034

10291035
position!(builder, bb_fallback)
1030-
fallback_intr_fn = if intr == LLVM.Intrinsic("llvm.minimum")
1036+
fallback_intr_fn = if is_minimum
10311037
"air.fmin.f$(8*sizeof(jltyp))"
10321038
else
10331039
"air.fmax.f$(8*sizeof(jltyp))"

0 commit comments

Comments (0)