Skip to content

Commit 5c1db19

Browse files
committed
Wrap and test some more Float16 intrinsics
1 parent 4d85f27 commit 5c1db19

File tree

2 files changed

+98
-10
lines changed

2 files changed

+98
-10
lines changed

src/device/intrinsics/math.jl

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,17 +103,38 @@ end
103103

104104
@device_override Base.log(x::Float64) = ccall("extern __nv_log", llvmcall, Cdouble, (Cdouble,), x)
105105
@device_override Base.log(x::Float32) = ccall("extern __nv_logf", llvmcall, Cfloat, (Cfloat,), x)
106+
# Half-precision natural log: native __nv_hlog needs sm_80; otherwise
# compute in single precision and truncate back to Float16.
@device_override function Base.log(x::Float16)
    compute_capability() >= sv"8.0" || return Float16(log(Float32(x)))
    return ccall("extern __nv_hlog", llvmcall, Float16, (Float16,), x)
end
106113
@device_override FastMath.log_fast(x::Float32) = ccall("extern __nv_fast_logf", llvmcall, Cfloat, (Cfloat,), x)
107114

108115
@device_override Base.log10(x::Float64) = ccall("extern __nv_log10", llvmcall, Cdouble, (Cdouble,), x)
109116
@device_override Base.log10(x::Float32) = ccall("extern __nv_log10f", llvmcall, Cfloat, (Cfloat,), x)
117+
# Half-precision log10: native __nv_hlog10 needs sm_80; otherwise
# compute in single precision and truncate back to Float16.
@device_override function Base.log10(x::Float16)
    compute_capability() >= sv"8.0" || return Float16(log10(Float32(x)))
    return ccall("extern __nv_hlog10", llvmcall, Float16, (Float16,), x)
end
110124
@device_override FastMath.log10_fast(x::Float32) = ccall("extern __nv_fast_log10f", llvmcall, Cfloat, (Cfloat,), x)
111125

112126
@device_override Base.log1p(x::Float64) = ccall("extern __nv_log1p", llvmcall, Cdouble, (Cdouble,), x)
113127
@device_override Base.log1p(x::Float32) = ccall("extern __nv_log1pf", llvmcall, Cfloat, (Cfloat,), x)
114128

115129
@device_override Base.log2(x::Float64) = ccall("extern __nv_log2", llvmcall, Cdouble, (Cdouble,), x)
116130
@device_override Base.log2(x::Float32) = ccall("extern __nv_log2f", llvmcall, Cfloat, (Cfloat,), x)
131+
# Half-precision log2: native __nv_hlog2 needs sm_80; otherwise compute
# in single precision and truncate back to Float16.
@device_override function Base.log2(x::Float16)
    if compute_capability() >= sv"8.0"
        return ccall("extern __nv_hlog2", llvmcall, Float16, (Float16,), x)
    else
        # BUG FIX: the fallback previously called `log` (natural log)
        # instead of `log2`, giving wrong results on pre-sm_80 devices.
        return Float16(log2(Float32(x)))
    end
end
117138
@device_override FastMath.log2_fast(x::Float32) = ccall("extern __nv_fast_log2f", llvmcall, Cfloat, (Cfloat,), x)
118139

119140
@device_function logb(x::Float64) = ccall("extern __nv_logb", llvmcall, Cdouble, (Cdouble,), x)
@@ -127,16 +148,35 @@ end
127148

128149
@device_override Base.exp(x::Float64) = ccall("extern __nv_exp", llvmcall, Cdouble, (Cdouble,), x)
129150
@device_override Base.exp(x::Float32) = ccall("extern __nv_expf", llvmcall, Cfloat, (Cfloat,), x)
151+
# Half-precision exp: native __nv_hexp needs sm_80; otherwise compute
# in single precision and truncate back to Float16.
@device_override function Base.exp(x::Float16)
    compute_capability() >= sv"8.0" || return Float16(exp(Float32(x)))
    return ccall("extern __nv_hexp", llvmcall, Float16, (Float16,), x)
end
130158
@device_override FastMath.exp_fast(x::Float32) = ccall("extern __nv_fast_expf", llvmcall, Cfloat, (Cfloat,), x)
131159

132160
@device_override Base.exp2(x::Float64) = ccall("extern __nv_exp2", llvmcall, Cdouble, (Cdouble,), x)
133161
@device_override Base.exp2(x::Float32) = ccall("extern __nv_exp2f", llvmcall, Cfloat, (Cfloat,), x)
162+
# Half-precision exp2: native __nv_hexp2 needs sm_80; otherwise compute
# in single precision and truncate back to Float16.
@device_override function Base.exp2(x::Float16)
    compute_capability() >= sv"8.0" || return Float16(exp2(Float32(x)))
    return ccall("extern __nv_hexp2", llvmcall, Float16, (Float16,), x)
end
134169
@device_override FastMath.exp2_fast(x::Union{Float32, Float64}) = exp2(x)
135-
# TODO: enable once PTX > 7.0 is supported
136-
# @device_override Base.exp2(x::Float16) = @asmcall("ex2.approx.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
137170

138171
@device_override Base.exp10(x::Float64) = ccall("extern __nv_exp10", llvmcall, Cdouble, (Cdouble,), x)
139172
@device_override Base.exp10(x::Float32) = ccall("extern __nv_exp10f", llvmcall, Cfloat, (Cfloat,), x)
173+
# Half-precision exp10: native __nv_hexp10 needs sm_80; otherwise compute
# in single precision and truncate back to Float16.
@device_override function Base.exp10(x::Float16)
    compute_capability() >= sv"8.0" || return Float16(exp10(Float32(x)))
    return ccall("extern __nv_hexp10", llvmcall, Float16, (Float16,), x)
end
140180
@device_override FastMath.exp10_fast(x::Float32) = ccall("extern __nv_fast_exp10f", llvmcall, Cfloat, (Cfloat,), x)
141181

142182
@device_override Base.expm1(x::Float64) = ccall("extern __nv_expm1", llvmcall, Cdouble, (Cdouble,), x)
@@ -204,6 +244,13 @@ end
204244

205245
@device_override Base.isnan(x::Float64) = (ccall("extern __nv_isnand", llvmcall, Int32, (Cdouble,), x)) != 0
206246
@device_override Base.isnan(x::Float32) = (ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x)) != 0
247+
# Half-precision NaN test: __nv_hisnan (sm_80+) returns an Int32 flag,
# which we convert to Bool; otherwise widen and test in Float32
# (widening preserves NaN-ness).
@device_override function Base.isnan(x::Float16)
    compute_capability() >= sv"8.0" || return isnan(Float32(x))
    return ccall("extern __nv_hisnan", llvmcall, Int32, (Float16,), x) != 0
end
207254

208255
@device_function nearbyint(x::Float64) = ccall("extern __nv_nearbyint", llvmcall, Cdouble, (Cdouble,), x)
209256
@device_function nearbyint(x::Float32) = ccall("extern __nv_nearbyintf", llvmcall, Cfloat, (Cfloat,), x)
@@ -223,14 +270,26 @@ end
223270
@device_override Base.abs(x::Int32) = ccall("extern __nv_abs", llvmcall, Int32, (Int32,), x)
224271
@device_override Base.abs(f::Float64) = ccall("extern __nv_fabs", llvmcall, Cdouble, (Cdouble,), f)
225272
@device_override Base.abs(f::Float32) = ccall("extern __nv_fabsf", llvmcall, Cfloat, (Cfloat,), f)
226-
# TODO: enable once PTX > 7.0 is supported
227-
# @device_override Base.abs(x::Float16) = @asmcall("abs.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
273+
# Half-precision absolute value: native __nv_habs needs sm_80; otherwise
# compute in single precision and truncate back to Float16.
@device_override function Base.abs(f::Float16)
    compute_capability() >= sv"8.0" || return Float16(abs(Float32(f)))
    return ccall("extern __nv_habs", llvmcall, Float16, (Float16,), f)
end
228280
@device_override Base.abs(x::Int64) = ccall("extern __nv_llabs", llvmcall, Int64, (Int64,), x)
229281

230282
## roots and powers
231283

232284
@device_override Base.sqrt(x::Float64) = ccall("extern __nv_sqrt", llvmcall, Cdouble, (Cdouble,), x)
233285
@device_override Base.sqrt(x::Float32) = ccall("extern __nv_sqrtf", llvmcall, Cfloat, (Cfloat,), x)
286+
# Half-precision square root: native __nv_hsqrt needs sm_80; otherwise
# compute in single precision and truncate back to Float16.
@device_override function Base.sqrt(x::Float16)
    compute_capability() >= sv"8.0" || return Float16(sqrt(Float32(x)))
    return ccall("extern __nv_hsqrt", llvmcall, Float16, (Float16,), x)
end
234293
@device_override FastMath.sqrt_fast(x::Union{Float32, Float64}) = sqrt(x)
235294

236295
@device_function rsqrt(x::Float64) = ccall("extern __nv_rsqrt", llvmcall, Cdouble, (Cdouble,), x)
@@ -295,6 +354,13 @@ end
295354
# JuliaGPU/CUDA.jl#2111: fmin semantics wrt. NaN don't match Julia's
296355
#@device_override Base.min(x::Float64, y::Float64) = ccall("extern __nv_fmin", llvmcall, Cdouble, (Cdouble, Cdouble), x, y)
297356
#@device_override Base.min(x::Float32, y::Float32) = ccall("extern __nv_fminf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
357+
# Half-precision min: use the native __nv_hmin intrinsic on sm_80+;
# otherwise widen to Float32, take the min, and truncate back.
@device_override @inline function Base.min(x::Float16, y::Float16)
    if compute_capability() < sv"8.0"
        return Float16(min(Float32(x), Float32(y)))
    end
    return ccall("extern __nv_hmin", llvmcall, Float16, (Float16, Float16), x, y)
end
298364
@device_override @inline function Base.min(x::Float32, y::Float32)
299365
if @static LLVM.version() < v"14" ? false : (compute_capability() >= sv"8.0")
300366
# LLVM 14+ can do the right thing, but only on sm_80+
@@ -321,6 +387,13 @@ end
321387
# JuliaGPU/CUDA.jl#2111: fmin semantics wrt. NaN don't match Julia's
322388
#@device_override Base.max(x::Float64, y::Float64) = ccall("extern __nv_fmax", llvmcall, Cdouble, (Cdouble, Cdouble), x, y)
323389
#@device_override Base.max(x::Float32, y::Float32) = ccall("extern __nv_fmaxf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
390+
# Half-precision max: use the native __nv_hmax intrinsic on sm_80+;
# otherwise widen to Float32, take the max, and truncate back.
@device_override @inline function Base.max(x::Float16, y::Float16)
    if compute_capability() < sv"8.0"
        return Float16(max(Float32(x), Float32(y)))
    end
    return ccall("extern __nv_hmax", llvmcall, Float16, (Float16, Float16), x, y)
end
324397
@device_override @inline function Base.max(x::Float32, y::Float32)
325398
if @static LLVM.version() < v"14" ? false : (compute_capability() >= sv"8.0")
326399
# LLVM 14+ can do the right thing, but only on sm_80+

test/core/device/intrinsics/math.jl

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ using SpecialFunctions
22

33
@testset "math" begin
44
@testset "log10" begin
    # exercise the half-, single- and double-precision overloads
    for T in [Float16, Float32, Float64]
        input = T[100]
        @test testf(a -> log10.(a), input)
    end
end
79

810
@testset "pow" begin
@@ -12,15 +14,22 @@ using SpecialFunctions
1214
@test testf((x,y)->x.^y, rand(Float32, 1), -rand(range, 1))
1315
end
1416
end
17+
18+
@testset "min/max" begin
    # Use same-typed operands so each T-specific method (notably the new
    # Float16 intrinsic wrappers) is actually exercised; mixing a Float32
    # operand with a Float16 one promotes the broadcast to Float32 and
    # bypasses the Float16 override entirely.
    for T in (Float16, Float32, Float64)
        @test testf((x,y)->max.(x, y), rand(T, 1), rand(T, 1))
        @test testf((x,y)->min.(x, y), rand(T, 1), rand(T, 1))
    end
end
1524

1625
@testset "isinf" begin
    # non-finite values across precisions; the NaN inputs must report false
    for val in [Inf32, Inf, NaN16, NaN32, NaN]
        @test testf(x -> isinf.(x), [val])
    end
end
2130

2231
@testset "isnan" begin
    # non-finite values across precisions; the Inf inputs must report false
    for val in [Inf32, Inf, NaN16, NaN32, NaN]
        @test testf(x -> isnan.(x), [val])
    end
end
@@ -33,7 +42,6 @@ using SpecialFunctions
3342
end
3443
end
3544
end
36-
3745
for op in (expm1,)
3846
@testset "$op" begin
3947
# FIXME: add expm1(::Float16) to Base
@@ -50,7 +58,6 @@ using SpecialFunctions
5058
@test testf(x->op.(x), rand(T, 1))
5159
@test testf(x->op.(x), -rand(T, 1))
5260
end
53-
5461
end
5562
end
5663
@testset "mod and rem" begin
@@ -97,6 +104,14 @@ using SpecialFunctions
97104
# JuliaGPU/CUDA.jl#1085: exp uses Base.sincos performing a global CPU load
98105
@test testf(x->exp.(x), [1e7im])
99106
end
107+
108+
for op in (exp, abs, abs2, log, exp10, log10)
    @testset "Real - $op" begin
        # each op on a random positive input, across all float precisions
        for T in (Float16, Float32, Float64)
            input = rand(T, 1)
            @test testf(x -> op.(x), input)
        end
    end
end
100115

101116
@testset "fastmath" begin
102117
# libdevice provides some fast math functions
@@ -150,7 +165,7 @@ using SpecialFunctions
150165
end
151166

152167
@testset "JuliaGPU/CUDA.jl#2111: min/max should return NaN" begin
153-
for T in [Float32, Float64]
168+
for T in [Float16, Float32, Float64]
154169
AT = CuArray{T}
155170
@test isequal(Array(min.(AT([NaN]), AT([Inf]))), [NaN])
156171
@test isequal(minimum(AT([NaN])), NaN)

0 commit comments

Comments
 (0)