Skip to content

Commit c6f35aa

Browse files
authored
Specialize Ref{<:Type} for GPU compatibility. (#1109)
1 parent 78379e1 commit c6f35aa

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

src/compiler/execution.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,15 @@ Adapt.adapt_storage(to::Adaptor, p::CuPtr{T}) where {T} = reinterpret(LLVMPtr{T,
123123
# Immutable `Ref{T}` subtype that stores a single value of type `T` in its field `x`.
struct CuRefValue{T} <: Ref{T}
124124
x::T
125125
end
126-
Base.getindex(r::CuRefValue) = r.x
126+
# Dereferencing (`ref[]`) returns the value stored in the `x` field.
Base.getindex(ref::CuRefValue{T}) where {T} = ref.x
127127
# Adapt the value held by a `Base.RefValue` and re-wrap the result in a `CuRefValue`.
function Adapt.adapt_structure(to::Adaptor, r::Base.RefValue)
    return CuRefValue(adapt(to, r[]))
end
128128

129+
# Broadcast sometimes passes a `Ref` wrapping a type (e.g. `Ref(ComplexF32)`), resulting in a GPU-incompatible boxed DataType.
130+
# Avoid that by using a special kind of `Ref` that carries the wrapped type as a type parameter.
131+
# Field-less `Ref{DataType}` subtype; the only information it holds is the type parameter `T`.
struct CuRefType{T} <: Ref{DataType}
end
132+
# Dereferencing (`ref[]`) returns the type parameter `T` itself.
Base.getindex(::CuRefType{T}) where {T} = T
133+
# Convert a `Base.RefValue` whose element type is a subtype of `Type` into a
# `CuRefType` parameterized by the referenced value `r[]`.
Adapt.adapt_structure(::Adaptor, r::Base.RefValue{<:Type}) = CuRefType{r[]}()
134+
129135
# Adapt a `CuArray{T,N}` by converting it to `CuDeviceArray{T,N,AS.Global}` with `Base.unsafe_convert`.
Adapt.adapt_storage(::Adaptor, xs::CuArray{T,N}) where {T,N} =
130136
Base.unsafe_convert(CuDeviceArray{T,N,AS.Global}, xs)
131137

test/broadcast.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
end
2323
@test Array(Whatever{Int}.(CuArray([1]))) == Whatever{Int}.([1])
2424
end
25+
2526
# https://github.com/JuliaGPU/CUDA.jl/issues/223
2627
@testset "Ref Broadcast" begin
2728
foobar(idx, A) = A[idx]
@@ -32,3 +33,9 @@ end
3233
@test testf(x -> log.(x), rand(3,3))
3334
@test testf((x,xs) -> log.(x.+xs), Ref(1), rand(3,3))
3435
end
36+
37+
# https://github.com/JuliaGPU/CUDA.jl/issues/261
38+
@testset "Broadcast Ref{<:Type}" begin
# Broadcasts `convert` with a type as its first argument; `A` is created with `undef`,
# so only the element type of the result is checked, not its values.
39+
A = CuArray{ComplexF64}(undef, (2,2))
40+
@test eltype(convert.(ComplexF32, A)) == ComplexF32
41+
end

0 commit comments

Comments
 (0)