Skip to content

Commit 83e38d8

Browse files
committed
Revert "Stop overdubbing CUDA math functions"
This reverts commit da035c1.
1 parent da035c1 commit 83e38d8

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

lib/CUDAKernels/src/CUDAKernels.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,37 @@ KernelAbstractions.generate_overdubs(@__MODULE__, CUDACtx)
282282
# CUDA specific method rewrites
283283
###
284284

285+
# Exponentiation under `CUDACtx` is forwarded to `CUDA.pow`. One method is
# defined per supported (base, exponent) type combination.

# Float64 base, Float64 exponent.
@inline function Cassette.overdub(::CUDACtx, ::typeof(^), x::Float64, y::Float64)
    return CUDA.pow(x, y)
end

# Float32 base, Float32 exponent.
@inline function Cassette.overdub(::CUDACtx, ::typeof(^), x::Float32, y::Float32)
    return CUDA.pow(x, y)
end

# Float64 base, Int32 exponent.
@inline function Cassette.overdub(::CUDACtx, ::typeof(^), x::Float64, y::Int32)
    return CUDA.pow(x, y)
end

# Float32 base, Int32 exponent.
@inline function Cassette.overdub(::CUDACtx, ::typeof(^), x::Float32, y::Int32)
    return CUDA.pow(x, y)
end

# Either float width as base, Int64 exponent.
@inline function Cassette.overdub(
    ::CUDACtx, ::typeof(^), x::Union{Float32, Float64}, y::Int64
)
    return CUDA.pow(x, y)
end
290+
291+
# libdevice.jl
292+
# Names of unary `Base` math functions that are redirected to the same-named
# function in `CUDA`. The loop that consumes this tuple generates overdubs taking
# exactly one `Float32`/`Float64` argument, so only one-argument functions belong
# in this list.
const cudafuns = (
    :cos, :cospi, :sin, :sinpi, :tan,
    :acos, :asin, :atan,
    :cosh, :sinh, :tanh,
    :acosh, :asinh, :atanh,
    :log, :log10, :log1p, :log2,
    :exp, :exp2, :exp10, :expm1,
    # :isfinite, :isinf, :isnan, :signbit,  (left disabled, as in the original list)
    :abs,
    :sqrt, :cbrt,
    :ceil, :floor,
)

# `ldexp(x, n)` takes two arguments. Listing it in `cudafuns` only produced a
# one-argument overdub that no real call could ever match, so it gets its own
# two-argument method here instead.
# NOTE(review): assumes `CUDA.ldexp` accepts a (Float32/Float64, Int32) pair, like
# the Int32 `CUDA.pow` overdubs in this file — confirm against the CUDA.jl version
# in use.
@inline function Cassette.overdub(
    ::CUDACtx, ::typeof(Base.ldexp), x::Union{Float32, Float64}, n::Int32
)
    return CUDA.ldexp(x, n)
end
302+
# Generate one overdub per entry of `cudafuns`: under `CUDACtx`, a call to
# `Base.<name>` with a single `Float32`/`Float64` argument is forwarded to
# `CUDA.<name>`.
#
# The public `@inline` annotation is applied to the whole definition rather than
# the internal `Base.@_inline_meta` macro inside the body; the internal macro is
# not part of Julia's public API and cannot be relied on across Julia versions.
for fname in cudafuns
    @eval @inline function Cassette.overdub(
        ::CUDACtx, ::typeof(Base.$fname), x::Union{Float32, Float64}
    )
        return CUDA.$fname(x)
    end
end
308+
309+
# `sincos` under `CUDACtx`: evaluate `CUDA.sin` and `CUDA.cos` separately (sine
# first) and return them as a `(sin, cos)` tuple.
@inline function Cassette.overdub(
    ::CUDACtx, ::typeof(sincos), x::Union{Float32, Float64}
)
    sine = CUDA.sin(x)
    cosine = CUDA.cos(x)
    return (sine, cosine)
end
310+
# Complex `exp` under `CUDACtx` is forwarded to `CUDA.exp` for both complex widths.
@inline function Cassette.overdub(
    ::CUDACtx, ::typeof(exp), x::Union{ComplexF32, ComplexF64}
)
    return CUDA.exp(x)
end
311+
312+
# `SpecialFunctions` calls under `CUDACtx` are forwarded to functions in `CUDA`.

# `gamma` forwards to `CUDA.tgamma` (note the differing name).
@inline function Cassette.overdub(
    ::CUDACtx, ::typeof(SpecialFunctions.gamma), x::Union{Float32, Float64}
)
    return CUDA.tgamma(x)
end

# `erf` forwards to `CUDA.erf`.
@inline function Cassette.overdub(
    ::CUDACtx, ::typeof(SpecialFunctions.erf), x::Union{Float32, Float64}
)
    return CUDA.erf(x)
end

# `erfc` forwards to `CUDA.erfc`.
@inline function Cassette.overdub(
    ::CUDACtx, ::typeof(SpecialFunctions.erfc), x::Union{Float32, Float64}
)
    return CUDA.erfc(x)
end
315+
285316
@static if Base.isbindingresolved(CUDA, :emit_shmem) && Base.isdefined(CUDA, :emit_shmem)
286317
const emit_shmem = CUDA.emit_shmem
287318
else

0 commit comments

Comments
 (0)