Skip to content

Commit a152355

Browse files
authored
Fix dynamic dispatch issues (#2235)
1 parent 30694a8 commit a152355

File tree

3 files changed

+29
-8
lines changed

3 files changed

+29
-8
lines changed

lib/cudadrv/execution.jl

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,35 @@ being slightly faster.
127127
"""
128128
cudacall
129129

130-
# FIXME: can we make this infer properly?
131-
cudacall(f, types::Tuple, args...; kwargs...) =
132-
cudacall(f, Base.to_tuple_type(types), args...; kwargs...)
130+
cudacall(f::F, types::Tuple, args::Vararg{Any,N}; kwargs...) where {N,F} =
131+
cudacall(f, _to_tuple_type(types), args...; kwargs...)
132+
133+
function cudacall(f::F, types::Type{T}, args::Vararg{Any,N}; kwargs...) where {T,N,F}
134+
convert_arguments(
135+
((pointers::Vararg{Any,M},) where {M}) -> launch(f, pointers...; kwargs...),
136+
types,
137+
args...
138+
)
139+
end
133140

134-
function cudacall(f, types::Type, args...; kwargs...)
135-
convert_arguments(types, args...) do pointers...
136-
launch(f, pointers...; kwargs...)
141+
# From `julia/base/reflection.jl`, adjusted to add specialization on `t`.
142+
function _to_tuple_type(t)
143+
if isa(t, Tuple) || isa(t, AbstractArray) || isa(t, SimpleVector)
144+
t = Tuple{t...}
145+
end
146+
if isa(t, Type) && t <: Tuple
147+
for p in (Base.unwrap_unionall(t)::DataType).parameters
148+
if isa(p, Core.TypeofVararg)
149+
p = Base.unwrapva(p)
150+
end
151+
if !(isa(p, Type) || isa(p, TypeVar))
152+
error("argument tuple type must contain only types")
153+
end
154+
end
155+
else
156+
error("expected tuple type")
137157
end
158+
t
138159
end
139160

140161

lib/cudadrv/state.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ function context!(ctx::CuContext)
162162
return old_ctx
163163
end
164164

165-
@inline function context!(f::Function, ctx::CuContext; skip_destroyed::Bool=false)
165+
@inline function context!(f::F, ctx::CuContext; skip_destroyed::Bool=false) where {F<:Function}
166166
# @inline so that the kwarg method is inlined too and we can const-prop skip_destroyed
167167
if isvalid(ctx)
168168
old_ctx = context!(ctx)::Union{CuContext,Nothing}

src/compiler/execution.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ end
385385
# cache of kernel instances
386386
const _kernel_instances = Dict{Any, Any}()
387387

388-
function (kernel::HostKernel)(args...; threads::CuDim=1, blocks::CuDim=1, kwargs...)
388+
function (kernel::HostKernel)(args::Vararg{Any,N}; threads::CuDim=1, blocks::CuDim=1, kwargs...) where {N}
389389
call(kernel, map(cudaconvert, args)...; threads, blocks, kwargs...)
390390
end
391391

0 commit comments

Comments
 (0)