Skip to content

Commit da035c1

Browse files
authored
Stop overdubbing CUDA math functions
1 parent fdb7415 commit da035c1

File tree

1 file changed

+0
-31
lines changed

1 file changed

+0
-31
lines changed

lib/CUDAKernels/src/CUDAKernels.jl

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -282,37 +282,6 @@ KernelAbstractions.generate_overdubs(@__MODULE__, CUDACtx)
282282
# CUDA specific method rewrites
283283
###
284284

285-
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float64, y::Float64) = CUDA.pow(x, y)
286-
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float32, y::Float32) = CUDA.pow(x, y)
287-
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float64, y::Int32) = CUDA.pow(x, y)
288-
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float32, y::Int32) = CUDA.pow(x, y)
289-
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Union{Float32, Float64}, y::Int64) = CUDA.pow(x, y)
290-
291-
# libdevice.jl
292-
const cudafuns = (:cos, :cospi, :sin, :sinpi, :tan,
293-
:acos, :asin, :atan,
294-
:cosh, :sinh, :tanh,
295-
:acosh, :asinh, :atanh,
296-
:log, :log10, :log1p, :log2,
297-
:exp, :exp2, :exp10, :expm1, :ldexp,
298-
# :isfinite, :isinf, :isnan, :signbit,
299-
:abs,
300-
:sqrt, :cbrt,
301-
:ceil, :floor,)
302-
for f in cudafuns
303-
@eval function Cassette.overdub(ctx::CUDACtx, ::typeof(Base.$f), x::Union{Float32, Float64})
304-
@Base._inline_meta
305-
return CUDA.$f(x)
306-
end
307-
end
308-
309-
@inline Cassette.overdub(::CUDACtx, ::typeof(sincos), x::Union{Float32, Float64}) = (CUDA.sin(x), CUDA.cos(x))
310-
@inline Cassette.overdub(::CUDACtx, ::typeof(exp), x::Union{ComplexF32, ComplexF64}) = CUDA.exp(x)
311-
312-
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.gamma), x::Union{Float32, Float64}) = CUDA.tgamma(x)
313-
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erf), x::Union{Float32, Float64}) = CUDA.erf(x)
314-
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erfc), x::Union{Float32, Float64}) = CUDA.erfc(x)
315-
316285
@static if Base.isbindingresolved(CUDA, :emit_shmem) && Base.isdefined(CUDA, :emit_shmem)
317286
const emit_shmem = CUDA.emit_shmem
318287
else

0 commit comments

Comments
 (0)