Skip to content

Commit 1496f4e

Browse files
Merge #249
249: Stop overdubbing CUDA math functions r=vchuravy a=ali-ramadhan I think that with CUDA.jl v3 there is no longer any need to overdub/rewrite CUDA-specific math functions. If anything, `CUDA.pow` no longer exists. Co-authored-by: Ali Ramadhan <ali.hh.ramadhan@gmail.com> Co-authored-by: ali.hh.ramadhan@gmail.com <ali.hh.ramadhan@gmail.com>
2 parents fdb7415 + 22ee38e commit 1496f4e

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

lib/CUDAKernels/src/CUDAKernels.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,6 @@ KernelAbstractions.generate_overdubs(@__MODULE__, CUDACtx)
282282
# CUDA specific method rewrites
283283
###
284284

285-
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float64, y::Float64) = CUDA.pow(x, y)
286-
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float32, y::Float32) = CUDA.pow(x, y)
287-
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float64, y::Int32) = CUDA.pow(x, y)
288-
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float32, y::Int32) = CUDA.pow(x, y)
289-
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Union{Float32, Float64}, y::Int64) = CUDA.pow(x, y)
290-
291285
# libdevice.jl
292286
const cudafuns = (:cos, :cospi, :sin, :sinpi, :tan,
293287
:acos, :asin, :atan,
@@ -302,7 +296,7 @@ const cudafuns = (:cos, :cospi, :sin, :sinpi, :tan,
302296
for f in cudafuns
303297
@eval function Cassette.overdub(ctx::CUDACtx, ::typeof(Base.$f), x::Union{Float32, Float64})
304298
@Base._inline_meta
305-
return CUDA.$f(x)
299+
return Base.$f(x)
306300
end
307301
end
308302

test/compiler.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ end
1212
A[1] = 2^11
1313
end
1414

15+
@kernel function square(A, B)
16+
A[1] = B[1]^2
17+
end
18+
1519
@kernel function checked(A, a, b)
1620
A[1] = Base.Checked.checked_add(a, b)
1721
end
@@ -28,6 +32,11 @@ function compiler_testsuite()
2832
@test !any(stmt->(stmt isa Expr) && stmt.head == :invoke, CI.code)
2933
end
3034

35+
let (CI, rt) = @ka_code_typed square(CPU())(zeros(1), zeros(1), ndrange=1)
36+
# test that there is no invoke of overdub
37+
@test !any(stmt->(stmt isa Expr) && stmt.head == :invoke, CI.code)
38+
end
39+
3140
if VERSION >= v"1.5"
3241
let (CI, rt) = @ka_code_typed checked(CPU())(zeros(Int,1), 1, 2, ndrange=1)
3342
# test that there is no invoke of overdub

0 commit comments

Comments
 (0)