Skip to content

Commit 7970d56

Browse files
authored
Make sure broadcast and map errors are visible to the user. (#524)
1 parent e4d40ea commit 7970d56

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

src/host/broadcast.jl

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,20 @@ end
2020
# iteration (see, e.g., CUDA.jl#145)
2121
@inline function Broadcast.copy(bc::Broadcasted{<:AbstractGPUArrayStyle})
2222
ElType = Broadcast.combine_eltypes(bc.f, bc.args)
23-
if ElType == Union{}
24-
# using a Union{} eltype would fail early, during GPU array construction,
25-
# so use Nothing instead to give the error a chance to be thrown dynamically.
26-
ElType = Nothing
23+
if ElType == Union{} || !Base.allocatedinline(ElType)
24+
# a Union{} or non-isbits eltype would fail early, during GPU array construction,
25+
# so use a special marker to give the error a chance to be thrown during compilation
26+
# or even dynamically, and pick that marker up afterwards to throw an error.
27+
ElType = BrokenBroadcast{ElType}
2728
end
2829
copyto!(similar(bc, ElType), bc)
2930
end
3031

32+
struct BrokenBroadcast{T} end
33+
Base.convert(::Type{BrokenBroadcast{T}}, x) where T = BrokenBroadcast{T}()
34+
Base.convert(::Type{BrokenBroadcast{T}}, x::BrokenBroadcast{T}) where T = x
35+
Base.eltype(::Type{BrokenBroadcast{T}}) where T = T
36+
3137
@inline function Base.materialize!(::Style, dest, bc::Broadcasted) where {Style<:AbstractGPUArrayStyle}
3238
return _copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest))))
3339
end
@@ -76,9 +82,14 @@ end
7682
gpu_call(broadcast_kernel, dest, bc, config.elements_per_thread;
7783
threads=config.threads, blocks=config.blocks)
7884

85+
if eltype(dest) <: BrokenBroadcast
86+
throw(ArgumentError("Broadcast operation resulting in $(eltype(eltype(dest))) is not GPU compatible"))
87+
end
88+
7989
return dest
8090
end
8191

92+
8293
## map
8394

8495
allequal(x) = true
@@ -97,7 +108,10 @@ function Base.map(f, x::AnyGPUArray, xs::AbstractArray...)
97108

98109
# construct a broadcast to figure out the destination container
99110
ElType = Broadcast.combine_eltypes(f, xs)
100-
isbitstype(ElType) || error("Cannot map function returning non-isbits $ElType.")
111+
if ElType == Union{} || !Base.allocatedinline(ElType)
112+
# see `broadcast`
113+
ElType = BrokenBroadcast{ElType}
114+
end
101115
dest = similar(x, ElType, common_length)
102116

103117
return map!(f, dest, xs...)
@@ -138,5 +152,9 @@ function Base.map!(f, dest::AnyGPUArray, xs::AbstractArray...)
138152
gpu_call(map_kernel, dest, bc, config.elements_per_thread;
139153
threads=config.threads, blocks=config.blocks)
140154

155+
if eltype(dest) <: BrokenBroadcast
156+
throw(ArgumentError("Map operation resulting in $(eltype(eltype(dest))) is not GPU compatible"))
157+
end
158+
141159
return dest
142160
end

0 commit comments

Comments
 (0)