Skip to content

Commit 15969e9

Browse files
authored
Allow broadcast errors to surface later. (#491)
We are currently checking for `isconcretetype`, which is too limiting, as e.g. CUDA.jl supports isbits-union arrays. Checking for `allocatedinline` would be better, but let's just do away with the check entirely and have the array constructor fail in the presence of unsupported element types. This is better anyway, as some back-ends may not support isbits-unions. We do however still check for Union{}, as that isn't allocated inline so would fail array construction. By using Nothing there instead, we give the GPU kernel (which is expected to throw an error) the chance to execute and report an exeception dynamically. This also makes it possible to trace a broadcast invocation under, e.g., Cthulhu.
1 parent b04d64a commit 15969e9

File tree

11 files changed

+191
-224
lines changed

11 files changed

+191
-224
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
strategy:
1515
fail-fast: false
1616
matrix:
17-
version: ['1.6', '1.7', '1.8', '1.9', 'nightly']
17+
version: ['1.8', '1.9', '1.10.0-beta2', 'nightly']
1818
os: [ubuntu-latest, macOS-latest, windows-latest]
1919
arch: [x64]
2020
steps:
@@ -50,7 +50,7 @@ jobs:
5050
- uses: actions/checkout@v3
5151
- uses: julia-actions/setup-julia@latest
5252
with:
53-
version: '1.6'
53+
version: '1.8'
5454
- name: Install dependencies
5555
run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
5656
- name: Build and deploy

Manifest.toml

Lines changed: 81 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,178 +1,213 @@
11
# This file is machine-generated - editing it directly is not advised
22

3-
[[Adapt]]
3+
manifest_format = "2.0"
4+
5+
[[deps.Adapt]]
46
deps = ["LinearAlgebra", "Requires"]
57
git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24"
68
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
79
version = "3.6.2"
810

9-
[[ArgTools]]
11+
[[deps.ArgTools]]
1012
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
13+
version = "1.1.1"
1114

12-
[[Artifacts]]
15+
[[deps.Artifacts]]
1316
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
1417

15-
[[Base64]]
18+
[[deps.Base64]]
1619
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1720

18-
[[CEnum]]
21+
[[deps.CEnum]]
1922
git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90"
2023
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
2124
version = "0.4.2"
2225

23-
[[Dates]]
26+
[[deps.CompilerSupportLibraries_jll]]
27+
deps = ["Artifacts", "Libdl"]
28+
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
29+
version = "1.0.1+0"
30+
31+
[[deps.Dates]]
2432
deps = ["Printf"]
2533
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
2634

27-
[[Downloads]]
28-
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
35+
[[deps.Downloads]]
36+
deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
2937
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
38+
version = "1.6.0"
39+
40+
[[deps.FileWatching]]
41+
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
3042

31-
[[GPUArraysCore]]
43+
[[deps.GPUArraysCore]]
3244
deps = ["Adapt"]
3345
path = "lib/GPUArraysCore"
3446
uuid = "46192b85-c4d5-4398-a991-12ede77f4527"
3547
version = "0.1.5"
3648

37-
[[InteractiveUtils]]
49+
[[deps.InteractiveUtils]]
3850
deps = ["Markdown"]
3951
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
4052

41-
[[JLLWrappers]]
53+
[[deps.JLLWrappers]]
4254
deps = ["Artifacts", "Preferences"]
4355
git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca"
4456
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
4557
version = "1.5.0"
4658

47-
[[LLVM]]
59+
[[deps.LLVM]]
4860
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
4961
git-tree-sha1 = "8695a49bfe05a2dc0feeefd06b4ca6361a018729"
5062
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
5163
version = "6.1.0"
5264

53-
[[LLVMExtra_jll]]
65+
[[deps.LLVMExtra_jll]]
5466
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
5567
git-tree-sha1 = "c35203c1e1002747da220ffc3c0762ce7754b08c"
5668
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
5769
version = "0.0.23+0"
5870

59-
[[LazyArtifacts]]
71+
[[deps.LazyArtifacts]]
6072
deps = ["Artifacts", "Pkg"]
6173
uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
6274

63-
[[LibCURL]]
75+
[[deps.LibCURL]]
6476
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
6577
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
78+
version = "0.6.3"
6679

67-
[[LibCURL_jll]]
80+
[[deps.LibCURL_jll]]
6881
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
6982
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
83+
version = "7.84.0+0"
7084

71-
[[LibGit2]]
85+
[[deps.LibGit2]]
7286
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
7387
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
7488

75-
[[LibSSH2_jll]]
89+
[[deps.LibSSH2_jll]]
7690
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
7791
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
92+
version = "1.10.2+0"
7893

79-
[[Libdl]]
94+
[[deps.Libdl]]
8095
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
8196

82-
[[LinearAlgebra]]
83-
deps = ["Libdl"]
97+
[[deps.LinearAlgebra]]
98+
deps = ["Libdl", "libblastrampoline_jll"]
8499
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
85100

86-
[[Logging]]
101+
[[deps.Logging]]
87102
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
88103

89-
[[Markdown]]
104+
[[deps.Markdown]]
90105
deps = ["Base64"]
91106
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
92107

93-
[[MbedTLS_jll]]
108+
[[deps.MbedTLS_jll]]
94109
deps = ["Artifacts", "Libdl"]
95110
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
111+
version = "2.28.0+0"
96112

97-
[[MozillaCACerts_jll]]
113+
[[deps.MozillaCACerts_jll]]
98114
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
115+
version = "2022.2.1"
99116

100-
[[NetworkOptions]]
117+
[[deps.NetworkOptions]]
101118
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
119+
version = "1.2.0"
120+
121+
[[deps.OpenBLAS_jll]]
122+
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
123+
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
124+
version = "0.3.20+0"
102125

103-
[[Pkg]]
126+
[[deps.Pkg]]
104127
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
105128
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
129+
version = "1.8.0"
106130

107-
[[Preferences]]
131+
[[deps.Preferences]]
108132
deps = ["TOML"]
109133
git-tree-sha1 = "7eb1686b4f04b82f96ed7a4ea5890a4f0c7a09f1"
110134
uuid = "21216c6a-2e73-6563-6e65-726566657250"
111135
version = "1.4.0"
112136

113-
[[Printf]]
137+
[[deps.Printf]]
114138
deps = ["Unicode"]
115139
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
116140

117-
[[REPL]]
141+
[[deps.REPL]]
118142
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
119143
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
120144

121-
[[Random]]
122-
deps = ["Serialization"]
145+
[[deps.Random]]
146+
deps = ["SHA", "Serialization"]
123147
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
124148

125-
[[Reexport]]
149+
[[deps.Reexport]]
126150
git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b"
127151
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
128152
version = "1.2.2"
129153

130-
[[Requires]]
154+
[[deps.Requires]]
131155
deps = ["UUIDs"]
132156
git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7"
133157
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
134158
version = "1.3.0"
135159

136-
[[SHA]]
160+
[[deps.SHA]]
137161
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
162+
version = "0.7.0"
138163

139-
[[Serialization]]
164+
[[deps.Serialization]]
140165
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
141166

142-
[[Sockets]]
167+
[[deps.Sockets]]
143168
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
144169

145-
[[SparseArrays]]
170+
[[deps.SparseArrays]]
146171
deps = ["LinearAlgebra", "Random"]
147172
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
148173

149-
[[Statistics]]
174+
[[deps.Statistics]]
150175
deps = ["LinearAlgebra", "SparseArrays"]
151176
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
152177

153-
[[TOML]]
178+
[[deps.TOML]]
154179
deps = ["Dates"]
155180
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
181+
version = "1.0.0"
156182

157-
[[Tar]]
183+
[[deps.Tar]]
158184
deps = ["ArgTools", "SHA"]
159185
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
186+
version = "1.10.1"
160187

161-
[[UUIDs]]
188+
[[deps.UUIDs]]
162189
deps = ["Random", "SHA"]
163190
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
164191

165-
[[Unicode]]
192+
[[deps.Unicode]]
166193
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
167194

168-
[[Zlib_jll]]
195+
[[deps.Zlib_jll]]
169196
deps = ["Libdl"]
170197
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
198+
version = "1.2.12+3"
199+
200+
[[deps.libblastrampoline_jll]]
201+
deps = ["Artifacts", "Libdl", "OpenBLAS_jll"]
202+
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
203+
version = "5.1.1+0"
171204

172-
[[nghttp2_jll]]
205+
[[deps.nghttp2_jll]]
173206
deps = ["Artifacts", "Libdl"]
174207
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
208+
version = "1.48.0+0"
175209

176-
[[p7zip_jll]]
210+
[[deps.p7zip_jll]]
177211
deps = ["Artifacts", "Libdl"]
178212
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
213+
version = "17.4.0+0"

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@ Adapt = "2.0, 3.0"
1818
GPUArraysCore = "= 0.1.5"
1919
LLVM = "3.9, 4, 5, 6"
2020
Reexport = "1"
21-
julia = "1.6"
21+
julia = "1.8"

lib/JLArrays/src/JLArrays.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ struct JLArray{T, N} <: AbstractGPUArray{T, N}
144144
dims::Dims{N}
145145

146146
function JLArray{T,N}(data::Array{T, N}, dims::Dims{N}) where {T,N}
147-
@assert isbitstype(T) "JLArray only supports bits types"
147+
isbitstype(T) || error("JLArray only supports bits types")
148+
# when supporting isbits-union types, use `Base.allocatedinline` here.
148149
new(data, dims)
149150
end
150151
end

src/host/broadcast.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@ backend(::Type{Base.RefValue{AT}}) where {AT<:AbstractGPUArray} = backend(AT)
2121
# but make sure we don't dispatch to the optimized copy method that directly indexes
2222
function Broadcast.copy(bc::Broadcasted{<:AbstractGPUArrayStyle{0}})
2323
ElType = Broadcast.combine_eltypes(bc.f, bc.args)
24-
isbitstype(ElType) || error("Cannot broadcast function returning non-isbits $ElType.")
24+
if ElType == Union{}
25+
# using a Union{} eltype would fail early, during GPU array construction,
26+
# so use Nothing instead to give the error a chance to be thrown dynamically.
27+
ElType = Nothing
28+
end
2529
dest = copyto!(similar(bc, ElType), bc)
2630
return @allowscalar dest[CartesianIndex()] # 0D broadcast needs to unwrap results
2731
end
@@ -30,9 +34,10 @@ end
3034
# iteration (see, e.g., CUDA.jl#145)
3135
@inline function Broadcast.copy(bc::Broadcasted{<:AbstractGPUArrayStyle})
3236
ElType = Broadcast.combine_eltypes(bc.f, bc.args)
33-
if !Base.isconcretetype(ElType)
34-
error("""GPU broadcast resulted in non-concrete element type $ElType.
35-
This probably means that the function you are broadcasting contains an error or type instability.""")
37+
if ElType == Union{}
38+
# using a Union{} eltype would fail early, during GPU array construction,
39+
# so use Nothing instead to give the error a chance to be thrown dynamically.
40+
ElType = Nothing
3641
end
3742
copyto!(similar(bc, ElType), bc)
3843
end

src/host/indexing.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -156,20 +156,6 @@ function findminmax(binop, xs::AnyGPUArray; init, dims)
156156
return t1
157157
end
158158

159-
@static if VERSION < v"1.7.0-DEV.119"
160-
# before JuliaLang/julia#35316, isless/isgreated did not order NaNs last
161-
function reduction(t1::Tuple{<:AbstractFloat,<:Any}, t2::Tuple{<:AbstractFloat,<:Any})
162-
(x, i), (y, j) = t1, t2
163-
164-
isnan(x) && return t1
165-
isnan(y) && return t2
166-
167-
binop(x, y) && return t2
168-
x == y && return (x, min(i, j))
169-
return t1
170-
end
171-
end
172-
173159
if dims == Colon()
174160
res = mapreduce(tuple, reduction, xs, indices; init = (init, dummy_index))
175161

0 commit comments

Comments
 (0)