Skip to content

Commit 4e875e7

Browse files
committed
adding atomics/x tests
1 parent d52a6f3 commit 4e875e7

File tree

8 files changed

+932
-0
lines changed

8 files changed

+932
-0
lines changed

lib/CUDAKernels/Manifest.toml

Lines changed: 340 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,340 @@
1+
# This file is machine-generated - editing it directly is not advised
2+
3+
julia_version = "1.7.1"
4+
manifest_format = "2.0"
5+
6+
[[deps.AbstractFFTs]]
7+
deps = ["ChainRulesCore", "LinearAlgebra"]
8+
git-tree-sha1 = "6f1d9bc1c08f9f4a8fa92e3ea3cb50153a1b40d4"
9+
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
10+
version = "1.1.0"
11+
12+
[[deps.Adapt]]
13+
deps = ["LinearAlgebra"]
14+
git-tree-sha1 = "af92965fb30777147966f58acb05da51c5616b5f"
15+
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
16+
version = "3.3.3"
17+
18+
[[deps.ArgTools]]
19+
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
20+
21+
[[deps.Artifacts]]
22+
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
23+
24+
[[deps.Atomix]]
25+
deps = ["UnsafeAtomics"]
26+
git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be"
27+
uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
28+
version = "0.1.0"
29+
30+
[[deps.BFloat16s]]
31+
deps = ["LinearAlgebra", "Printf", "Random", "Test"]
32+
git-tree-sha1 = "a598ecb0d717092b5539dbbe890c98bac842b072"
33+
uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
34+
version = "0.2.0"
35+
36+
[[deps.Base64]]
37+
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
38+
39+
[[deps.CEnum]]
40+
git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90"
41+
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
42+
version = "0.4.2"
43+
44+
[[deps.CUDA]]
45+
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
46+
git-tree-sha1 = "19fb33957a5f85efb3cc10e70cf4dd4e30174ac9"
47+
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
48+
version = "3.10.0"
49+
50+
[[deps.ChainRulesCore]]
51+
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
52+
git-tree-sha1 = "9489214b993cd42d17f44c36e359bf6a7c919abf"
53+
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
54+
version = "1.15.0"
55+
56+
[[deps.ChangesOfVariables]]
57+
deps = ["ChainRulesCore", "LinearAlgebra", "Test"]
58+
git-tree-sha1 = "1e315e3f4b0b7ce40feded39c73049692126cf53"
59+
uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
60+
version = "0.1.3"
61+
62+
[[deps.Compat]]
63+
deps = ["Dates", "LinearAlgebra", "UUIDs"]
64+
git-tree-sha1 = "924cdca592bc16f14d2f7006754a621735280b74"
65+
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
66+
version = "4.1.0"
67+
68+
[[deps.CompilerSupportLibraries_jll]]
69+
deps = ["Artifacts", "Libdl"]
70+
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
71+
72+
[[deps.Dates]]
73+
deps = ["Printf"]
74+
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
75+
76+
[[deps.DocStringExtensions]]
77+
deps = ["LibGit2"]
78+
git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b"
79+
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
80+
version = "0.8.6"
81+
82+
[[deps.Downloads]]
83+
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
84+
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
85+
86+
[[deps.ExprTools]]
87+
git-tree-sha1 = "56559bbef6ca5ea0c0818fa5c90320398a6fbf8d"
88+
uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
89+
version = "0.1.8"
90+
91+
[[deps.GPUArrays]]
92+
deps = ["Adapt", "LLVM", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"]
93+
git-tree-sha1 = "c783e8883028bf26fb05ed4022c450ef44edd875"
94+
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
95+
version = "8.3.2"
96+
97+
[[deps.GPUCompiler]]
98+
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"]
99+
git-tree-sha1 = "d8c5999631e1dc18d767883f621639c838f8e632"
100+
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
101+
version = "0.15.2"
102+
103+
[[deps.InteractiveUtils]]
104+
deps = ["Markdown"]
105+
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
106+
107+
[[deps.InverseFunctions]]
108+
deps = ["Test"]
109+
git-tree-sha1 = "336cc738f03e069ef2cac55a104eb823455dca75"
110+
uuid = "3587e190-3f89-42d0-90ee-14403ec27112"
111+
version = "0.1.4"
112+
113+
[[deps.IrrationalConstants]]
114+
git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151"
115+
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
116+
version = "0.1.1"
117+
118+
[[deps.JLLWrappers]]
119+
deps = ["Preferences"]
120+
git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1"
121+
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
122+
version = "1.4.1"
123+
124+
[[deps.KernelAbstractions]]
125+
deps = ["Adapt", "InteractiveUtils", "LinearAlgebra", "MacroTools", "SparseArrays", "StaticArrays", "UUIDs"]
126+
git-tree-sha1 = "883ea9474c2a091dc6a698b525f47a651ae133b9"
127+
uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
128+
version = "0.8.0"
129+
130+
[[deps.LLVM]]
131+
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
132+
git-tree-sha1 = "8c0b65f65ac27cf293c13089df78081b93790fa7"
133+
repo-rev = "master"
134+
repo-url = "https://github.com/maleadt/LLVM.jl.git"
135+
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
136+
version = "4.12.0"
137+
138+
[[deps.LLVMExtra_jll]]
139+
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"]
140+
git-tree-sha1 = "771bfe376249626d3ca12bcd58ba243d3f961576"
141+
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
142+
version = "0.0.16+0"
143+
144+
[[deps.LazyArtifacts]]
145+
deps = ["Artifacts", "Pkg"]
146+
uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
147+
148+
[[deps.LibCURL]]
149+
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
150+
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
151+
152+
[[deps.LibCURL_jll]]
153+
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
154+
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
155+
156+
[[deps.LibGit2]]
157+
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
158+
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
159+
160+
[[deps.LibSSH2_jll]]
161+
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
162+
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
163+
164+
[[deps.Libdl]]
165+
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
166+
167+
[[deps.LinearAlgebra]]
168+
deps = ["Libdl", "libblastrampoline_jll"]
169+
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
170+
171+
[[deps.LogExpFunctions]]
172+
deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"]
173+
git-tree-sha1 = "09e4b894ce6a976c354a69041a04748180d43637"
174+
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
175+
version = "0.3.15"
176+
177+
[[deps.Logging]]
178+
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
179+
180+
[[deps.MacroTools]]
181+
deps = ["Markdown", "Random"]
182+
git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf"
183+
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
184+
version = "0.5.9"
185+
186+
[[deps.Markdown]]
187+
deps = ["Base64"]
188+
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
189+
190+
[[deps.MbedTLS_jll]]
191+
deps = ["Artifacts", "Libdl"]
192+
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
193+
194+
[[deps.MozillaCACerts_jll]]
195+
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
196+
197+
[[deps.NetworkOptions]]
198+
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
199+
200+
[[deps.OpenBLAS_jll]]
201+
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
202+
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
203+
204+
[[deps.OpenLibm_jll]]
205+
deps = ["Artifacts", "Libdl"]
206+
uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
207+
208+
[[deps.OpenSpecFun_jll]]
209+
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
210+
git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1"
211+
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
212+
version = "0.5.5+0"
213+
214+
[[deps.Pkg]]
215+
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
216+
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
217+
218+
[[deps.Preferences]]
219+
deps = ["TOML"]
220+
git-tree-sha1 = "47e5f437cc0e7ef2ce8406ce1e7e24d44915f88d"
221+
uuid = "21216c6a-2e73-6563-6e65-726566657250"
222+
version = "1.3.0"
223+
224+
[[deps.Printf]]
225+
deps = ["Unicode"]
226+
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
227+
228+
[[deps.REPL]]
229+
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
230+
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
231+
232+
[[deps.Random]]
233+
deps = ["SHA", "Serialization"]
234+
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
235+
236+
[[deps.Random123]]
237+
deps = ["Random", "RandomNumbers"]
238+
git-tree-sha1 = "afeacaecf4ed1649555a19cb2cad3c141bbc9474"
239+
uuid = "74087812-796a-5b5d-8853-05524746bad3"
240+
version = "1.5.0"
241+
242+
[[deps.RandomNumbers]]
243+
deps = ["Random", "Requires"]
244+
git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111"
245+
uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143"
246+
version = "1.5.3"
247+
248+
[[deps.Reexport]]
249+
git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b"
250+
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
251+
version = "1.2.2"
252+
253+
[[deps.Requires]]
254+
deps = ["UUIDs"]
255+
git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7"
256+
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
257+
version = "1.3.0"
258+
259+
[[deps.SHA]]
260+
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
261+
262+
[[deps.Serialization]]
263+
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
264+
265+
[[deps.Sockets]]
266+
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
267+
268+
[[deps.SparseArrays]]
269+
deps = ["LinearAlgebra", "Random"]
270+
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
271+
272+
[[deps.SpecialFunctions]]
273+
deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
274+
git-tree-sha1 = "bc40f042cfcc56230f781d92db71f0e21496dffd"
275+
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
276+
version = "2.1.5"
277+
278+
[[deps.StaticArrays]]
279+
deps = ["LinearAlgebra", "Random", "Statistics"]
280+
git-tree-sha1 = "cd56bf18ed715e8b09f06ef8c6b781e6cdc49911"
281+
uuid = "90137ffa-7385-5640-81b9-e52037218182"
282+
version = "1.4.4"
283+
284+
[[deps.Statistics]]
285+
deps = ["LinearAlgebra", "SparseArrays"]
286+
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
287+
288+
[[deps.TOML]]
289+
deps = ["Dates"]
290+
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
291+
292+
[[deps.Tar]]
293+
deps = ["ArgTools", "SHA"]
294+
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
295+
296+
[[deps.Test]]
297+
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
298+
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
299+
300+
[[deps.TimerOutputs]]
301+
deps = ["ExprTools", "Printf"]
302+
git-tree-sha1 = "7638550aaea1c9a1e86817a231ef0faa9aca79bd"
303+
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
304+
version = "0.5.19"
305+
306+
[[deps.UUIDs]]
307+
deps = ["Random", "SHA"]
308+
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
309+
310+
[[deps.Unicode]]
311+
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
312+
313+
[[deps.UnsafeAtomics]]
314+
git-tree-sha1 = "2615625381ad7ea1a5b686e0c58676118c9a746a"
315+
uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
316+
version = "0.2.0"
317+
318+
[[deps.UnsafeAtomicsLLVM]]
319+
deps = ["LLVM", "UnsafeAtomics"]
320+
git-tree-sha1 = "c63963c866ecf67ad34d70674cc8f1252f96d16c"
321+
repo-rev = "main"
322+
repo-url = "https://github.com/JuliaConcurrent/UnsafeAtomicsLLVM.jl"
323+
uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
324+
version = "0.1.0-DEV"
325+
326+
[[deps.Zlib_jll]]
327+
deps = ["Libdl"]
328+
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
329+
330+
[[deps.libblastrampoline_jll]]
331+
deps = ["Artifacts", "Libdl", "OpenBLAS_jll"]
332+
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
333+
334+
[[deps.nghttp2_jll]]
335+
deps = ["Artifacts", "Libdl"]
336+
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
337+
338+
[[deps.p7zip_jll]]
339+
deps = ["Artifacts", "Libdl"]
340+
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"

lib/CUDAKernels/src/CUDAKernels.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,7 @@ else
359359
end
360360

361361
import KernelAbstractions: ConstAdaptor, SharedMemory, Scratchpad, __synchronize, __size
362+
import KernelAbstractions: atomic_add!, atomic_and!, atomic_cas!, atomic_dec!, atomic_inc!, atomic_max!, atomic_min!, atomic_op!, atomic_or!, atomic_sub!, atomic_xchg!, atomic_xor!
362363

363364
###
364365
# GPU implementation of shared memory
@@ -395,4 +396,29 @@ Adapt.adapt_storage(to::ConstAdaptor, a::CUDA.CuDeviceArray) = Base.Experimental
395396
# Argument conversion
396397
# Kernels targeting `CUDADevice` convert each argument by delegating to
# `CUDA.cudaconvert`.
function KernelAbstractions.argconvert(k::Kernel{CUDADevice}, arg)
    return CUDA.cudaconvert(arg)
end
397398

399+
400+
###
401+
# GPU implementation of atomics
402+
###
403+
404+
# Table pairing each KernelAbstractions atomic function with the CUDA function
# it is forwarded to in the device overrides generated below.
afxs = Dict(
    atomic_add! => CUDA.atomic_add!,
    atomic_and! => CUDA.atomic_and!,
    atomic_cas! => CUDA.atomic_cas!,
    atomic_dec! => CUDA.atomic_dec!,
    atomic_inc! => CUDA.atomic_inc!,
    atomic_max! => CUDA.atomic_max!,
    atomic_min! => CUDA.atomic_min!,
    atomic_op! => CUDA.atomic_op!,
    atomic_or! => CUDA.atomic_or!,
    atomic_sub! => CUDA.atomic_sub!,
    atomic_xchg! => CUDA.atomic_xchg!,
    atomic_xor! => CUDA.atomic_xor!
)

# Generate one device override per table entry.
#
# The definition has to go through `@eval` with interpolation: written directly
# in the loop body, `function afx(args...)` targets the literal name `afx`
# instead of the function bound to the loop variable, so none of the
# KernelAbstractions atomics would be overridden.
for (afx, cfx) in afxs
    # Method definitions need a name, not a function object, in the signature.
    # NOTE(review): assumes each key is bound in KernelAbstractions under the
    # name reported by `nameof` (true for the functions imported above).
    kname = nameof(afx)
    @eval @device_override @inline function KernelAbstractions.$kname(args...)
        return $cfx(args...)
    end
end
398424
end

src/KernelAbstractions.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,10 @@ include("extras/extras.jl")
496496

497497
include("reflection.jl")
498498

499+
# Atomics
500+
501+
include("atomics.jl")
502+
499503
# CPU backend
500504

501505
include("cpu.jl")

0 commit comments

Comments
 (0)