@@ -282,6 +282,37 @@ KernelAbstractions.generate_overdubs(@__MODULE__, CUDACtx)
282
282
# CUDA-specific method rewrites
283
283
# ##
284
284
285
# Inside a `CUDACtx`, forward `x ^ y` to `CUDA.pow(x, y)`.
# The concrete (base, exponent) pairings are generated from a table so that the
# method bodies stay in lock-step.
for (B, E) in (
    (Float64, Float64),
    (Float32, Float32),
    (Float64, Int32),
    (Float32, Int32),
)
    @eval @inline function Cassette.overdub(::CUDACtx, ::typeof(^), x::$B, y::$E)
        return CUDA.pow(x, y)
    end
end

# `Int64` exponents are accepted for either floating-point base type.
@inline function Cassette.overdub(
    ::CUDACtx, ::typeof(^), x::Union{Float32, Float64}, y::Int64
)
    return CUDA.pow(x, y)
end
290
+
291
+ # libdevice.jl
292
# Names of the unary `Base` math functions that get a `CUDACtx` overdub.
# Entries are grouped by family, one family per line.
const cudafuns = (
    :cos, :cospi, :sin, :sinpi, :tan,            # trigonometric
    :acos, :asin, :atan,                         # inverse trigonometric
    :cosh, :sinh, :tanh,                         # hyperbolic
    :acosh, :asinh, :atanh,                      # inverse hyperbolic
    :log, :log10, :log1p, :log2,                 # logarithms
    :exp, :exp2, :exp10, :expm1, :ldexp,         # exponentials
    # Currently left out of the list: :isfinite, :isinf, :isnan, :signbit
    :abs,
    :sqrt, :cbrt,                                # roots
    :ceil, :floor,                               # rounding
)
302
# For every name in `cudafuns`, define an inlined overdub that replaces the
# unary call `Base.name(x)` with `CUDA.name(x)` for `Float32`/`Float64` arguments
# when running under a `CUDACtx`.
# NOTE(review): `cudafuns` contains `:ldexp`, and `Base.ldexp` is normally called
# with two arguments (value, exponent) — confirm the unary method generated for
# it here is intended.
for fname in cudafuns
    @eval @inline function Cassette.overdub(
        ::CUDACtx, ::typeof(Base.$fname), x::Union{Float32, Float64}
    )
        return CUDA.$fname(x)
    end
end
308
+
309
# Under a `CUDACtx`, `sincos(x)` is assembled from `CUDA.sin` and `CUDA.cos`
# (evaluated in that order) and returned as a `(sin, cos)` tuple.
@inline function Cassette.overdub(::CUDACtx, ::typeof(sincos), x::Union{Float32, Float64})
    s = CUDA.sin(x)
    c = CUDA.cos(x)
    return (s, c)
end
310
# Under a `CUDACtx`, forward `exp` of a single- or double-precision complex
# number to `CUDA.exp`.
@inline function Cassette.overdub(
    ::CUDACtx, ::typeof(exp), x::Union{ComplexF32, ComplexF64}
)
    return CUDA.exp(x)
end
311
+
312
# Under a `CUDACtx`, forward selected `SpecialFunctions` calls on
# `Float32`/`Float64` to the `CUDA` function named in the second column.
# Note that `gamma` maps to `CUDA.tgamma`; the other names are unchanged.
for (sfname, cudaname) in (
    (:gamma, :tgamma),
    (:erf, :erf),
    (:erfc, :erfc),
)
    @eval @inline function Cassette.overdub(
        ::CUDACtx, ::typeof(SpecialFunctions.$sfname), x::Union{Float32, Float64}
    )
        return CUDA.$cudaname(x)
    end
end
315
+
285
316
@static if Base. isbindingresolved (CUDA, :emit_shmem ) && Base. isdefined (CUDA, :emit_shmem )
286
317
const emit_shmem = CUDA. emit_shmem
287
318
else
0 commit comments