Skip to content

Commit 460f005

Browse files
Merge #1708
1708: Add GPU Adaptor r=CarloLucibello a=DhairyaLGandhi Updates return types for CUDA/ CPU movement to hold structure information wherever possible. Co-authored-by: Dhairya Gandhi <dhairya@juliacomputing.com>
2 parents d12ebfa + 03e48d0 commit 460f005

File tree

5 files changed

+130
-114
lines changed

5 files changed

+130
-114
lines changed

Manifest.toml

Lines changed: 68 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
2222

2323
[[ArrayInterface]]
2424
deps = ["IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"]
25-
git-tree-sha1 = "a71d224f61475b93c9e196e83c17c6ac4dedacfa"
25+
git-tree-sha1 = "019303a0f26d6012f35ecdfa4618551d145fb9f2"
2626
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
27-
version = "3.1.18"
27+
version = "3.1.31"
2828

2929
[[Artifacts]]
3030
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
@@ -44,22 +44,22 @@ uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
4444
version = "0.4.1"
4545

4646
[[CUDA]]
47-
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
48-
git-tree-sha1 = "5e696e37e51b01ae07bd9f700afe6cbd55250bce"
47+
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
48+
git-tree-sha1 = "335b3d2373733919b4972a51215a6840c7a33828"
4949
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
50-
version = "3.3.4"
50+
version = "3.4.2"
5151

5252
[[ChainRules]]
5353
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Statistics"]
54-
git-tree-sha1 = "0ff24ac6ea4f03d9ed5c90505c1e96273bf5f96d"
54+
git-tree-sha1 = "d88340ab502af66cfffc821e70ae72f7dbdce645"
5555
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
56-
version = "0.8.23"
56+
version = "1.11.5"
5757

5858
[[ChainRulesCore]]
5959
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
60-
git-tree-sha1 = "f53ca8d41e4753c41cdafa6ec5f7ce914b34be54"
60+
git-tree-sha1 = "30ee06de5ff870b45c78f529a6b093b3323256a3"
6161
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
62-
version = "0.10.13"
62+
version = "1.3.1"
6363

6464
[[CodecZlib]]
6565
deps = ["TranscodingStreams", "Zlib_jll"]
@@ -87,24 +87,24 @@ version = "0.3.0"
8787

8888
[[Compat]]
8989
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
90-
git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941"
90+
git-tree-sha1 = "6071cb87be6a444ac75fdbf51b8e7273808ce62f"
9191
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
92-
version = "3.31.0"
92+
version = "3.35.0"
9393

9494
[[CompilerSupportLibraries_jll]]
9595
deps = ["Artifacts", "Libdl"]
9696
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
9797

9898
[[DataAPI]]
99-
git-tree-sha1 = "ee400abb2298bd13bfc3df1c412ed228061a2385"
99+
git-tree-sha1 = "bec2532f8adb82005476c141ec23e921fc20971b"
100100
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
101-
version = "1.7.0"
101+
version = "1.8.0"
102102

103103
[[DataStructures]]
104104
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
105-
git-tree-sha1 = "4437b64df1e0adccc3e5d1adbc3ac741095e4677"
105+
git-tree-sha1 = "7d9d316f04214f7efdbb6398d545446e246eff02"
106106
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
107-
version = "0.18.9"
107+
version = "0.18.10"
108108

109109
[[Dates]]
110110
deps = ["Printf"]
@@ -122,9 +122,9 @@ version = "1.0.3"
122122

123123
[[DiffRules]]
124124
deps = ["NaNMath", "Random", "SpecialFunctions"]
125-
git-tree-sha1 = "214c3fcac57755cfda163d91c58893a8723f93e9"
125+
git-tree-sha1 = "7220bc21c33e990c14f4a9a319b1d242ebc5b269"
126126
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
127-
version = "1.0.2"
127+
version = "1.3.1"
128128

129129
[[Distributed]]
130130
deps = ["Random", "Serialization", "Sockets"]
@@ -146,10 +146,10 @@ uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
146146
version = "0.1.6"
147147

148148
[[FillArrays]]
149-
deps = ["LinearAlgebra", "Random", "SparseArrays"]
150-
git-tree-sha1 = "25b9cc23ba3303de0ad2eac03f840de9104c9253"
149+
deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"]
150+
git-tree-sha1 = "caf289224e622f518c9dbfe832cdafa17d7c80a6"
151151
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
152-
version = "0.12.0"
152+
version = "0.12.4"
153153

154154
[[FixedPointNumbers]]
155155
deps = ["Statistics"]
@@ -159,27 +159,26 @@ version = "0.8.4"
159159

160160
[[ForwardDiff]]
161161
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
162-
git-tree-sha1 = "e2af66012e08966366a43251e1fd421522908be6"
162+
git-tree-sha1 = "b5e930ac60b613ef3406da6d4f42c35d8dc51419"
163163
uuid = "f6369f11-7733-5829-9624-2563aa707210"
164-
version = "0.10.18"
164+
version = "0.10.19"
165165

166166
[[Functors]]
167-
deps = ["MacroTools"]
168-
git-tree-sha1 = "4cd9e70bf8fce05114598b663ad79dfe9ae432b3"
167+
git-tree-sha1 = "e2727f02325451f6b24445cd83bfa9aaac19cbe7"
169168
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
170-
version = "0.2.3"
169+
version = "0.2.5"
171170

172171
[[GPUArrays]]
173-
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"]
174-
git-tree-sha1 = "ececbf05f8904c92814bdbd0aafd5540b0bf2e9a"
172+
deps = ["Adapt", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"]
173+
git-tree-sha1 = "8fac1cf7d6ce0f2249c7acaf25d22e1e85c4a07f"
175174
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
176-
version = "7.0.1"
175+
version = "8.0.2"
177176

178177
[[GPUCompiler]]
179-
deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"]
180-
git-tree-sha1 = "e8a09182a4440489e2e3dedff5ad3f6bbe555396"
178+
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"]
179+
git-tree-sha1 = "4ed2616d5e656c8716736b64da86755467f26cf5"
181180
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
182-
version = "0.12.5"
181+
version = "0.12.9"
183182

184183
[[IRTools]]
185184
deps = ["InteractiveUtils", "MacroTools", "Test"]
@@ -196,6 +195,11 @@ version = "0.1.0"
196195
deps = ["Markdown"]
197196
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
198197

198+
[[IrrationalConstants]]
199+
git-tree-sha1 = "f76424439413893a832026ca355fe273e93bce94"
200+
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
201+
version = "0.1.0"
202+
199203
[[JLLWrappers]]
200204
deps = ["Preferences"]
201205
git-tree-sha1 = "642a199af8b68253517b80bd3bfd17eb4e84df6e"
@@ -210,15 +214,15 @@ version = "0.8.4"
210214

211215
[[LLVM]]
212216
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
213-
git-tree-sha1 = "733abcbdc67337bb6aaf873c6bebbe1e6440a5df"
217+
git-tree-sha1 = "29174613a9fa0424f5aef1a9dbd234acff7ce1f2"
214218
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
215-
version = "4.1.1"
219+
version = "4.5.0"
216220

217221
[[LLVMExtra_jll]]
218222
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
219-
git-tree-sha1 = "b36c0677a0549c7d1dc8719899a4133abbfacf7d"
223+
git-tree-sha1 = "9c360e5ce980b88bb31a7b086dbb19469008154b"
220224
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
221-
version = "0.0.6+0"
225+
version = "0.0.10+0"
222226

223227
[[LazyArtifacts]]
224228
deps = ["Artifacts", "Pkg"]
@@ -248,19 +252,19 @@ deps = ["Libdl"]
248252
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
249253

250254
[[LogExpFunctions]]
251-
deps = ["DocStringExtensions", "LinearAlgebra"]
252-
git-tree-sha1 = "7bd5f6565d80b6bf753738d2bc40a5dfea072070"
255+
deps = ["ChainRulesCore", "DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
256+
git-tree-sha1 = "86197a8ecb06e222d66797b0c2d2f0cc7b69e42b"
253257
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
254-
version = "0.2.5"
258+
version = "0.3.2"
255259

256260
[[Logging]]
257261
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
258262

259263
[[MacroTools]]
260264
deps = ["Markdown", "Random"]
261-
git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0"
265+
git-tree-sha1 = "0fb723cd8c45858c22169b2e42269e53271a6df7"
262266
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
263-
version = "0.5.6"
267+
version = "0.5.7"
264268

265269
[[Markdown]]
266270
deps = ["Base64"]
@@ -278,9 +282,9 @@ version = "0.5.0"
278282

279283
[[Missings]]
280284
deps = ["DataAPI"]
281-
git-tree-sha1 = "4ea90bd5d3985ae1f9a908bd4500ae88921c5ce7"
285+
git-tree-sha1 = "2ca267b08821e86c5ef4376cffed98a46c2cb205"
282286
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
283-
version = "1.0.0"
287+
version = "1.0.1"
284288

285289
[[Mmap]]
286290
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
@@ -290,15 +294,15 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
290294

291295
[[NNlib]]
292296
deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
293-
git-tree-sha1 = "3de64e776a467311c907f5a767ee8a022a8a2f76"
297+
git-tree-sha1 = "5203a4532ad28c44f82c76634ad621d7c90abcbd"
294298
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
295-
version = "0.7.25"
299+
version = "0.7.29"
296300

297301
[[NNlibCUDA]]
298302
deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"]
299-
git-tree-sha1 = "a7de026dc0ff9f47551a16ad9a710da66881b953"
303+
git-tree-sha1 = "04490d5e7570c038b1cb0f5c3627597181cc15a9"
300304
uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d"
301-
version = "0.1.7"
305+
version = "0.1.9"
302306

303307
[[NaNMath]]
304308
git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb"
@@ -353,14 +357,14 @@ version = "1.4.2"
353357

354358
[[RandomNumbers]]
355359
deps = ["Random", "Requires"]
356-
git-tree-sha1 = "441e6fc35597524ada7f85e13df1f4e10137d16f"
360+
git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111"
357361
uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143"
358-
version = "1.4.0"
362+
version = "1.5.3"
359363

360364
[[Reexport]]
361-
git-tree-sha1 = "5f6c21241f0f655da3952fd60aa18477cf96c220"
365+
git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b"
362366
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
363-
version = "1.1.0"
367+
version = "1.2.2"
364368

365369
[[Requires]]
366370
deps = ["UUIDs"]
@@ -393,21 +397,21 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
393397

394398
[[SpecialFunctions]]
395399
deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"]
396-
git-tree-sha1 = "a50550fa3164a8c46747e62063b4d774ac1bcf49"
400+
git-tree-sha1 = "a322a9493e49c5f3a10b50df3aedaf1cdb3244b7"
397401
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
398-
version = "1.5.1"
402+
version = "1.6.1"
399403

400404
[[Static]]
401405
deps = ["IfElse"]
402-
git-tree-sha1 = "62701892d172a2fa41a1f829f66d2b0db94a9a63"
406+
git-tree-sha1 = "854b024a4a81b05c0792a4b45293b85db228bd27"
403407
uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
404-
version = "0.3.0"
408+
version = "0.3.1"
405409

406410
[[StaticArrays]]
407411
deps = ["LinearAlgebra", "Random", "Statistics"]
408-
git-tree-sha1 = "1b9a0f17ee0adde9e538227de093467348992397"
412+
git-tree-sha1 = "3240808c6d463ac46f1c1cd7638375cd22abbccb"
409413
uuid = "90137ffa-7385-5640-81b9-e52037218182"
410-
version = "1.2.7"
414+
version = "1.2.12"
411415

412416
[[Statistics]]
413417
deps = ["LinearAlgebra", "SparseArrays"]
@@ -420,9 +424,9 @@ version = "1.0.0"
420424

421425
[[StatsBase]]
422426
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
423-
git-tree-sha1 = "2f6792d523d7448bbe2fec99eca9218f06cc746d"
427+
git-tree-sha1 = "8cbbc098554648c84f79a463c9ff0fd277144b6c"
424428
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
425-
version = "0.33.8"
429+
version = "0.33.10"
426430

427431
[[TOML]]
428432
deps = ["Dates"]
@@ -444,9 +448,9 @@ version = "0.5.12"
444448

445449
[[TranscodingStreams]]
446450
deps = ["Random", "Test"]
447-
git-tree-sha1 = "7c53c35547de1c5b9d46a4797cf6d8253807108c"
451+
git-tree-sha1 = "216b95ea110b5972db65aa90f88d8d89dcb8851c"
448452
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
449-
version = "0.9.5"
453+
version = "0.9.6"
450454

451455
[[UUIDs]]
452456
deps = ["Random", "SHA"]
@@ -467,9 +471,11 @@ uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
467471

468472
[[Zygote]]
469473
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
470-
git-tree-sha1 = "8b634fdb4c3c63f2ceaa2559a008da4f405af6b3"
474+
git-tree-sha1 = "aa382565aaf48b5c8873be393f3b9e45dc7f703c"
475+
repo-rev = "master"
476+
repo-url = "https://github.com/FluxML/Zygote.jl.git"
471477
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
472-
version = "0.6.17"
478+
version = "0.6.21"
473479

474480
[[ZygoteRules]]
475481
deps = ["MacroTools"]

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2121
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2222
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2323
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
24+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2425
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2526
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2627
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

docs/src/models/basics.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ julia> f(x) = 3x^2 + 2x + 1;
1212
julia> df(x) = gradient(f, x)[1]; # df/dx = 6x + 2
1313
1414
julia> df(2)
15-
14
15+
14.0
1616
1717
julia> d2f(x) = gradient(df, x)[1]; # d²f/dx² = 6
1818
1919
julia> d2f(2)
20-
6
20+
6.0
2121
```
2222

2323
When a function has many parameters, we can get gradients of each one at the same time:

0 commit comments

Comments
 (0)