Skip to content

Commit a580b95

Browse files
authored
Merge pull request #281 from JuliaGPU/vc/device_function
[CUDAKernels] Avoid Cassette looking at device_functions
2 parents 1fe5326 + def0249 commit a580b95

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

lib/CUDAKernels/src/CUDAKernels.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,4 +356,29 @@ Adapt.adapt_storage(to::ConstAdaptor, a::CUDA.CuDeviceArray) = Base.Experimental
356356
# Argument conversion
357357
KernelAbstractions.argconvert(k::Kernel{CUDADevice}, arg) = CUDA.cudaconvert(arg)
358358

359+
# Cassette.jl#195
360+
# Device intrinsics are inferred in a different world (Julia 1.6) or via MethodOverlay tables (Julia 1.7).
361+
# Cassette sees neither mechanism, so overdubbing these intrinsics fails; the methods below call them directly instead.
362+
# Overdub method for `CUDA.arrayref` under `CUDACtx`: forwards `args...` unchanged
# to `CUDA.arrayref` and returns its result.
@inline function Cassette.overdub(::CUDACtx, ::typeof(CUDA.arrayref), args...)
363+
CUDA.arrayref(args...)
364+
end
365+
# Overdub method for `CUDA.arrayset` under `CUDACtx`: forwards `args...` unchanged
# to `CUDA.arrayset` and returns its result.
@inline function Cassette.overdub(::CUDACtx, ::typeof(CUDA.arrayset), args...)
366+
CUDA.arrayset(args...)
367+
end
368+
# Overdub method for `CUDA.const_arrayref` under `CUDACtx`: forwards `args...`
# unchanged to `CUDA.const_arrayref` and returns its result.
@inline function Cassette.overdub(::CUDACtx, ::typeof(CUDA.const_arrayref), args...)
369+
CUDA.const_arrayref(args...)
370+
end
371+
# Overdub method for `CUDA.logb` under `CUDACtx`: forwards `args...` unchanged
# to `CUDA.logb` and returns its result.
@inline function Cassette.overdub(::CUDACtx, ::typeof(CUDA.logb), args...)
372+
CUDA.logb(args...)
373+
end
374+
# @inline function Cassette.overdub(::CUDACtx, ::typeof(CUDA.tgamma), args...)
375+
# CUDA.tgamma(args...)
376+
# end
377+
# Overdub method for `CUDA.compute_capability` under `CUDACtx`: forwards `args...`
# unchanged to `CUDA.compute_capability` and returns its result.
@inline function Cassette.overdub(::CUDACtx, ::typeof(CUDA.compute_capability), args...)
378+
CUDA.compute_capability(args...)
379+
end
380+
# Overdub method for `CUDA.ptx_isa_version` under `CUDACtx`: forwards `args...`
# unchanged to `CUDA.ptx_isa_version` and returns its result.
@inline function Cassette.overdub(::CUDACtx, ::typeof(CUDA.ptx_isa_version), args...)
381+
CUDA.ptx_isa_version(args...)
382+
end
383+
359384
end

0 commit comments

Comments
 (0)