Skip to content

Commit af1b933

Browse files
Test and fix CUDA method replacement (#253)
Co-authored-by: Ali Ramadhan <ali.hh.ramadhan@gmail.com>
1 parent 1496f4e commit af1b933

File tree

4 files changed

+56
-13
lines changed

4 files changed

+56
-13
lines changed

lib/CUDAKernels/src/CUDAKernels.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,12 @@ KernelAbstractions.generate_overdubs(@__MODULE__, CUDACtx)
282282
# CUDA specific method rewrites
283283
###
284284

285+
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float64, y::Float64) = ^(x, y)
286+
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float32, y::Float32) = ^(x, y)
287+
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float64, y::Int32) = ^(x, y)
288+
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float32, y::Int32) = ^(x, y)
289+
@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Union{Float32, Float64}, y::Int64) = ^(x, y)
290+
285291
# libdevice.jl
286292
const cudafuns = (:cos, :cospi, :sin, :sinpi, :tan,
287293
:acos, :asin, :atan,
@@ -300,12 +306,12 @@ for f in cudafuns
300306
end
301307
end
302308

303-
@inline Cassette.overdub(::CUDACtx, ::typeof(sincos), x::Union{Float32, Float64}) = (CUDA.sin(x), CUDA.cos(x))
304-
@inline Cassette.overdub(::CUDACtx, ::typeof(exp), x::Union{ComplexF32, ComplexF64}) = CUDA.exp(x)
309+
@inline Cassette.overdub(::CUDACtx, ::typeof(sincos), x::Union{Float32, Float64}) = (Base.sin(x), Base.cos(x))
310+
@inline Cassette.overdub(::CUDACtx, ::typeof(exp), x::Union{ComplexF32, ComplexF64}) = Base.exp(x)
305311

306312
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.gamma), x::Union{Float32, Float64}) = CUDA.tgamma(x)
307-
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erf), x::Union{Float32, Float64}) = CUDA.erf(x)
308-
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erfc), x::Union{Float32, Float64}) = CUDA.erfc(x)
313+
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erf), x::Union{Float32, Float64}) = SpecialFunctions.erf(x)
314+
@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erfc), x::Union{Float32, Float64}) = SpecialFunctions.erfc(x)
309315

310316
@static if Base.isbindingresolved(CUDA, :emit_shmem) && Base.isdefined(CUDA, :emit_shmem)
311317
const emit_shmem = CUDA.emit_shmem

src/compiler.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ function generate_overdubs(mod, Ctx)
4848

4949
@inline Cassette.overdub(::$Ctx, ::typeof(Base.literal_pow), f::F, x, p) where F = Base.literal_pow(f, x, p)
5050

51+
@inline Cassette.overdub(::$Ctx, ::typeof(Base.throw_boundserror), args...) = Base.throw_boundserror(args...)
52+
@inline Cassette.overdub(::$Ctx, ::typeof(Base.Math.throw_exp_domainerror), args...) = Base.Math.throw_exp_domainerror(args...)
53+
5154
function Cassette.overdub(::$Ctx, ::typeof(:), start::T, step::T, stop::T) where T<:Union{Float16,Float32,Float64}
5255
lf = (stop-start)/step
5356
if lf < 0

test/compiler.jl

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,31 +16,65 @@ end
1616
A[1] = B[1]^2
1717
end
1818

19+
@kernel function pow(A, B)
20+
A[1] = A[1]^B[1]
21+
end
22+
1923
@kernel function checked(A, a, b)
2024
A[1] = Base.Checked.checked_add(a, b)
2125
end
2226

23-
function compiler_testsuite()
27+
function check_for_overdub(stmt)
28+
if stmt isa Expr
29+
if stmt.head == :invoke
30+
mi = first(stmt.args)::Core.MethodInstance
31+
if mi.def.name === :overdub
32+
@show stmt
33+
return true
34+
end
35+
end
36+
end
37+
return false
38+
end
39+
40+
function compiler_testsuite(backend, ArrayT)
2441
kernel = index(CPU(), DynamicSize(), DynamicSize())
2542
iterspace = NDRange{1, StaticSize{(128,)}, StaticSize{(8,)}}();
2643
ctx = KernelAbstractions.mkcontext(kernel, 1, nothing, iterspace, Val(KernelAbstractions.NoDynamicCheck()))
2744
CTX = KernelAbstractions.cassette(kernel)
2845
@test KernelAbstractions.Cassette.overdub(CTX, KernelAbstractions.__index_Global_NTuple, ctx, CartesianIndex(1)) == (1,)
2946

30-
let (CI, rt) = @ka_code_typed literal_pow(CPU())(zeros(Int,1), ndrange=1)
47+
A = ArrayT{Int}(undef, 1)
48+
let (CI, rt) = @ka_code_typed literal_pow(backend())(A, ndrange=1)
49+
# test that there is no invoke of overdub
50+
@test !any(check_for_overdub, CI.code)
51+
end
52+
53+
A = ArrayT{Float64}(undef, 1)
54+
let (CI, rt) = @ka_code_typed square(backend())(A, A, ndrange=1)
55+
# test that there is no invoke of overdub
56+
@test !any(check_for_overdub, CI.code)
57+
end
58+
59+
A = ArrayT{Float64}(undef, 1)
60+
B = ArrayT{Float64}(undef, 1)
61+
let (CI, rt) = @ka_code_typed pow(backend())(A, B, ndrange=1)
3162
# test that there is no invoke of overdub
32-
@test !any(stmt->(stmt isa Expr) && stmt.head == :invoke, CI.code)
63+
@test !any(check_for_overdub, CI.code)
3364
end
3465

35-
let (CI, rt) = @ka_code_typed square(CPU())(zeros(1), zeros(1), ndrange=1)
66+
A = ArrayT{Float64}(undef, 1)
67+
B = ArrayT{Int32}(undef, 1)
68+
let (CI, rt) = @ka_code_typed pow(backend())(A, B, ndrange=1)
3669
# test that there is no invoke of overdub
37-
@test !any(stmt->(stmt isa Expr) && stmt.head == :invoke, CI.code)
70+
@test !any(check_for_overdub, CI.code)
3871
end
3972

4073
if VERSION >= v"1.5"
41-
let (CI, rt) = @ka_code_typed checked(CPU())(zeros(Int,1), 1, 2, ndrange=1)
74+
A = ArrayT{Int}(undef, 1)
75+
let (CI, rt) = @ka_code_typed checked(backend())(A, 1, 2, ndrange=1)
4276
# test that there is no invoke of overdub
43-
@test !any(stmt->(stmt isa Expr) && stmt.head == :invoke, CI.code)
77+
@test !any(check_for_overdub, CI.code)
4478
end
4579
end
4680
end

test/testsuite.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ function testsuite(backend, backend_str, backend_mod, AT, DAT)
5454
end
5555
end
5656

57-
if backend == CPU
57+
if backend_str != "ROCM"
5858
@testset "Compiler" begin
59-
compiler_testsuite()
59+
compiler_testsuite(backend, AT)
6060
end
6161
end
6262

0 commit comments

Comments
 (0)