Skip to content

Commit 0337e2f

Browse files
authored
support rename of CUDA internals (#122)
* support rename of internals
1 parent 1f093b5 commit 0337e2f

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1414

1515
[compat]
1616
Adapt = "0.4, 1.0, 2.0"
17-
CUDA = "~1.0, ~1.1, =1.2.0"
17+
CUDA = "~1.0, ~1.1, ~1.2"
1818
Cassette = "0.3.2"
1919
LLVM = "1.5"
2020
MacroTools = "0.5"

src/backends/cuda.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,11 +295,17 @@ end
295295
# Inside a CUDACtx, forward `SpecialFunctions.erf` on Float32/Float64 to `CUDA.erf`.
@inline function Cassette.overdub(
    ::CUDACtx,
    ::typeof(SpecialFunctions.erf),
    x::Union{Float32, Float64},
)
    return CUDA.erf(x)
end
296296
# Inside a CUDACtx, forward `SpecialFunctions.erfc` on Float32/Float64 to `CUDA.erfc`.
@inline function Cassette.overdub(
    ::CUDACtx,
    ::typeof(SpecialFunctions.erfc),
    x::Union{Float32, Float64},
)
    return CUDA.erfc(x)
end
297297

298+
# Bind CUDA's internal shared-memory emitter to a single local constant.
# Depending on the CUDA.jl version this internal is exposed either as `emit_shmem`
# or as `_shmem`; the rest of this file only refers to the local `emit_shmem`.
#
# `isdefined` on its own is enough to probe for the name. The previous additional
# `Base.isbindingresolved` guard relied on a non-public Base internal and could only
# turn an existing definition into a false negative.
@static if isdefined(CUDA, :emit_shmem)
    const emit_shmem = CUDA.emit_shmem
else
    # Fallback for CUDA versions that do not define `emit_shmem`.
    const emit_shmem = CUDA._shmem
end
303+
298304
###
299305
# GPU implementation of shared memory
300306
###
301307
# Overdub of `SharedMemory` for the CUDA context: requests a buffer of `prod(Dims)`
# elements of type `T`, tagged with `Id`, from `emit_shmem`, and wraps the resulting
# pointer in a `CUDA.CuDeviceArray` of shape `Dims`.
@inline function Cassette.overdub(
    ctx::CUDACtx,
    ::typeof(SharedMemory),
    ::Type{T},
    ::Val{Dims},
    ::Val{Id},
) where {T, Dims, Id}
    nelems = prod(Dims)
    shmem_ptr = emit_shmem(Val(Id), T, Val(nelems))
    return CUDA.CuDeviceArray(Dims, shmem_ptr)
end
305311

0 commit comments

Comments
 (0)