@@ -103,17 +103,38 @@ end
103
103
104
104
@device_override Base. log (x:: Float64 ) = ccall (" extern __nv_log" , llvmcall, Cdouble, (Cdouble,), x)
105
105
@device_override Base. log (x:: Float32 ) = ccall (" extern __nv_logf" , llvmcall, Cfloat, (Cfloat,), x)
106
+ @device_override function Base. log (x:: Float16 )
107
+ if compute_capability () >= sv " 8.0"
108
+ ccall (" extern __nv_hlog" , llvmcall, Float16, (Float16,), x)
109
+ else
110
+ return Float16 (log (Float32 (x)))
111
+ end
112
+ end
106
113
@device_override FastMath. log_fast (x:: Float32 ) = ccall (" extern __nv_fast_logf" , llvmcall, Cfloat, (Cfloat,), x)
107
114
108
115
@device_override Base. log10 (x:: Float64 ) = ccall (" extern __nv_log10" , llvmcall, Cdouble, (Cdouble,), x)
109
116
@device_override Base. log10 (x:: Float32 ) = ccall (" extern __nv_log10f" , llvmcall, Cfloat, (Cfloat,), x)
117
+ @device_override function Base. log10 (x:: Float16 )
118
+ if compute_capability () >= sv " 8.0"
119
+ ccall (" extern __nv_hlog10" , llvmcall, Float16, (Float16,), x)
120
+ else
121
+ return Float16 (log10 (Float32 (x)))
122
+ end
123
+ end
110
124
@device_override FastMath. log10_fast (x:: Float32 ) = ccall (" extern __nv_fast_log10f" , llvmcall, Cfloat, (Cfloat,), x)
111
125
112
126
@device_override Base. log1p (x:: Float64 ) = ccall (" extern __nv_log1p" , llvmcall, Cdouble, (Cdouble,), x)
113
127
@device_override Base. log1p (x:: Float32 ) = ccall (" extern __nv_log1pf" , llvmcall, Cfloat, (Cfloat,), x)
114
128
115
129
@device_override Base. log2 (x:: Float64 ) = ccall (" extern __nv_log2" , llvmcall, Cdouble, (Cdouble,), x)
116
130
@device_override Base. log2 (x:: Float32 ) = ccall (" extern __nv_log2f" , llvmcall, Cfloat, (Cfloat,), x)
131
+ @device_override function Base. log2 (x:: Float16 )
132
+ if compute_capability () >= sv " 8.0"
133
+ ccall (" extern __nv_hlog2" , llvmcall, Float16, (Float16,), x)
134
+ else
135
+ return Float16 (log (Float32 (x)))
136
+ end
137
+ end
117
138
@device_override FastMath. log2_fast (x:: Float32 ) = ccall (" extern __nv_fast_log2f" , llvmcall, Cfloat, (Cfloat,), x)
118
139
119
140
@device_function logb (x:: Float64 ) = ccall (" extern __nv_logb" , llvmcall, Cdouble, (Cdouble,), x)
@@ -127,16 +148,35 @@ end
127
148
128
149
@device_override Base. exp (x:: Float64 ) = ccall (" extern __nv_exp" , llvmcall, Cdouble, (Cdouble,), x)
129
150
@device_override Base. exp (x:: Float32 ) = ccall (" extern __nv_expf" , llvmcall, Cfloat, (Cfloat,), x)
151
+ @device_override function Base. exp (x:: Float16 )
152
+ if compute_capability () >= sv " 8.0"
153
+ ccall (" extern __nv_hexp" , llvmcall, Float16, (Float16,), x)
154
+ else
155
+ return Float16 (exp (Float32 (x)))
156
+ end
157
+ end
130
158
@device_override FastMath. exp_fast (x:: Float32 ) = ccall (" extern __nv_fast_expf" , llvmcall, Cfloat, (Cfloat,), x)
131
159
132
160
@device_override Base. exp2 (x:: Float64 ) = ccall (" extern __nv_exp2" , llvmcall, Cdouble, (Cdouble,), x)
133
161
@device_override Base. exp2 (x:: Float32 ) = ccall (" extern __nv_exp2f" , llvmcall, Cfloat, (Cfloat,), x)
162
+ @device_override function Base. exp2 (x:: Float16 )
163
+ if compute_capability () >= sv " 8.0"
164
+ ccall (" extern __nv_hexp2" , llvmcall, Float16, (Float16,), x)
165
+ else
166
+ return Float16 (exp2 (Float32 (x)))
167
+ end
168
+ end
134
169
@device_override FastMath. exp2_fast (x:: Union{Float32, Float64} ) = exp2 (x)
135
- # TODO : enable once PTX > 7.0 is supported
136
- # @device_override Base.exp2(x::Float16) = @asmcall("ex2.approx.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
137
170
138
171
@device_override Base. exp10 (x:: Float64 ) = ccall (" extern __nv_exp10" , llvmcall, Cdouble, (Cdouble,), x)
139
172
@device_override Base. exp10 (x:: Float32 ) = ccall (" extern __nv_exp10f" , llvmcall, Cfloat, (Cfloat,), x)
173
+ @device_override function Base. exp10 (x:: Float16 )
174
+ if compute_capability () >= sv " 8.0"
175
+ ccall (" extern __nv_hexp10" , llvmcall, Float16, (Float16,), x)
176
+ else
177
+ return Float16 (exp10 (Float32 (x)))
178
+ end
179
+ end
140
180
@device_override FastMath. exp10_fast (x:: Float32 ) = ccall (" extern __nv_fast_exp10f" , llvmcall, Cfloat, (Cfloat,), x)
141
181
142
182
@device_override Base. expm1 (x:: Float64 ) = ccall (" extern __nv_expm1" , llvmcall, Cdouble, (Cdouble,), x)
204
244
205
245
@device_override Base. isnan (x:: Float64 ) = (ccall (" extern __nv_isnand" , llvmcall, Int32, (Cdouble,), x)) != 0
206
246
@device_override Base. isnan (x:: Float32 ) = (ccall (" extern __nv_isnanf" , llvmcall, Int32, (Cfloat,), x)) != 0
247
+ @device_override function Base. isnan (x:: Float16 )
248
+ if compute_capability () >= sv " 8.0"
249
+ return (ccall (" extern __nv_hisnan" , llvmcall, Int32, (Float16,), x)) != 0
250
+ else
251
+ return isnan (Float32 (x))
252
+ end
253
+ end
207
254
208
255
@device_function nearbyint (x:: Float64 ) = ccall (" extern __nv_nearbyint" , llvmcall, Cdouble, (Cdouble,), x)
209
256
@device_function nearbyint (x:: Float32 ) = ccall (" extern __nv_nearbyintf" , llvmcall, Cfloat, (Cfloat,), x)
@@ -223,14 +270,26 @@ end
223
270
@device_override Base. abs (x:: Int32 ) = ccall (" extern __nv_abs" , llvmcall, Int32, (Int32,), x)
224
271
@device_override Base. abs (f:: Float64 ) = ccall (" extern __nv_fabs" , llvmcall, Cdouble, (Cdouble,), f)
225
272
@device_override Base. abs (f:: Float32 ) = ccall (" extern __nv_fabsf" , llvmcall, Cfloat, (Cfloat,), f)
226
- # TODO : enable once PTX > 7.0 is supported
227
- # @device_override Base.abs(x::Float16) = @asmcall("abs.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
273
+ @device_override function Base. abs (f:: Float16 )
274
+ if compute_capability () >= sv " 8.0"
275
+ ccall (" extern __nv_habs" , llvmcall, Float16, (Float16,), f)
276
+ else
277
+ return Float16 (abs (Float32 (f)))
278
+ end
279
+ end
228
280
@device_override Base. abs (x:: Int64 ) = ccall (" extern __nv_llabs" , llvmcall, Int64, (Int64,), x)
229
281
230
282
# # roots and powers
231
283
232
284
@device_override Base. sqrt (x:: Float64 ) = ccall (" extern __nv_sqrt" , llvmcall, Cdouble, (Cdouble,), x)
233
285
@device_override Base. sqrt (x:: Float32 ) = ccall (" extern __nv_sqrtf" , llvmcall, Cfloat, (Cfloat,), x)
286
+ @device_override function Base. sqrt (x:: Float16 )
287
+ if compute_capability () >= sv " 8.0"
288
+ ccall (" extern __nv_hsqrt" , llvmcall, Float16, (Float16,), x)
289
+ else
290
+ return Float16 (sqrt (Float32 (x)))
291
+ end
292
+ end
234
293
@device_override FastMath. sqrt_fast (x:: Union{Float32, Float64} ) = sqrt (x)
235
294
236
295
@device_function rsqrt (x:: Float64 ) = ccall (" extern __nv_rsqrt" , llvmcall, Cdouble, (Cdouble,), x)
295
354
# JuliaGPU/CUDA.jl#2111: fmin semantics wrt. NaN don't match Julia's
296
355
# @device_override Base.min(x::Float64, y::Float64) = ccall("extern __nv_fmin", llvmcall, Cdouble, (Cdouble, Cdouble), x, y)
297
356
# @device_override Base.min(x::Float32, y::Float32) = ccall("extern __nv_fminf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
357
+ @device_override @inline function Base. min (x:: Float16 , y:: Float16 )
358
+ if compute_capability () >= sv " 8.0"
359
+ return ccall (" extern __nv_hmin" , llvmcall, Float16, (Float16, Float16), x, y)
360
+ else
361
+ return Float16 (min (Float32 (x), Float32 (y)))
362
+ end
363
+ end
298
364
@device_override @inline function Base. min (x:: Float32 , y:: Float32 )
299
365
if @static LLVM. version () < v " 14" ? false : (compute_capability () >= sv " 8.0" )
300
366
# LLVM 14+ can do the right thing, but only on sm_80+
321
387
# JuliaGPU/CUDA.jl#2111: fmin semantics wrt. NaN don't match Julia's
322
388
# @device_override Base.max(x::Float64, y::Float64) = ccall("extern __nv_fmax", llvmcall, Cdouble, (Cdouble, Cdouble), x, y)
323
389
# @device_override Base.max(x::Float32, y::Float32) = ccall("extern __nv_fmaxf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
390
+ @device_override @inline function Base. max (x:: Float16 , y:: Float16 )
391
+ if compute_capability () >= sv " 8.0"
392
+ return ccall (" extern __nv_hmax" , llvmcall, Float16, (Float16, Float16), x, y)
393
+ else
394
+ return Float16 (max (Float32 (x), Float32 (y)))
395
+ end
396
+ end
324
397
@device_override @inline function Base. max (x:: Float32 , y:: Float32 )
325
398
if @static LLVM. version () < v " 14" ? false : (compute_capability () >= sv " 8.0" )
326
399
# LLVM 14+ can do the right thing, but only on sm_80+
0 commit comments