Skip to content

Commit 8dfe4fa

Browse files
Merge #1456
1456: Update Zygote version to 0.6 r=CarloLucibello a=CarloLucibello Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
2 parents ebd37d6 + 3da23fa commit 8dfe4fa

File tree

4 files changed

+42
-93
lines changed

4 files changed

+42
-93
lines changed

Manifest.toml

Lines changed: 21 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,9 @@ version = "0.3.3"
1414

1515
[[Adapt]]
1616
deps = ["LinearAlgebra"]
17-
git-tree-sha1 = "42c42f2221906892ceb765dbcb1a51deeffd86d7"
17+
git-tree-sha1 = "27edd95a09fd428113ca019c092e8aeca2eb1f2d"
1818
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
19-
version = "2.3.0"
20-
21-
[[ArrayInterface]]
22-
deps = ["LinearAlgebra", "Requires", "SparseArrays"]
23-
git-tree-sha1 = "3b5bd474a90bee86b50f26268bbb044bb4d9ef83"
24-
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
25-
version = "2.14.9"
26-
27-
[[ArrayLayouts]]
28-
deps = ["Compat", "FillArrays", "LinearAlgebra", "SparseArrays"]
29-
git-tree-sha1 = "a577e27915fdcb3f6b96118b56655b38e3b466f2"
30-
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
31-
version = "0.4.12"
19+
version = "3.0.0"
3220

3321
[[Artifacts]]
3422
deps = ["Pkg"]
@@ -51,16 +39,16 @@ uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
5139
version = "0.4.1"
5240

5341
[[CUDA]]
54-
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SparseArrays", "Statistics", "TimerOutputs"]
55-
git-tree-sha1 = "7663b61782b569b03fba91d330a5ed2f86cd4cb8"
42+
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SparseArrays", "Statistics", "TimerOutputs"]
43+
git-tree-sha1 = "39f6f584bec264ace76f924d1c8637c85617697e"
5644
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
57-
version = "2.3.0"
45+
version = "2.4.0"
5846

5947
[[ChainRules]]
6048
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Reexport", "Requires", "Statistics"]
61-
git-tree-sha1 = "93a956cf20a439fe6147d6fb3cda07816afea411"
49+
git-tree-sha1 = "31b28f5123afa5e5ca0c885e4051896032754578"
6250
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
63-
version = "0.7.41"
51+
version = "0.7.45"
6452

6553
[[ChainRulesCore]]
6654
deps = ["LinearAlgebra", "MuladdMacro", "SparseArrays"]
@@ -82,9 +70,9 @@ version = "0.10.9"
8270

8371
[[Colors]]
8472
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Reexport"]
85-
git-tree-sha1 = "30b2dd71d1585435c905e3228ca878867eb57e4b"
73+
git-tree-sha1 = "ac5f2213e56ed8a34a3dd2f681f4df1166b34929"
8674
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
87-
version = "0.12.5"
75+
version = "0.12.6"
8876

8977
[[CommonSubexpressions]]
9078
deps = ["MacroTools", "Test"]
@@ -139,12 +127,6 @@ version = "1.0.2"
139127
deps = ["Random", "Serialization", "Sockets"]
140128
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
141129

142-
[[DocStringExtensions]]
143-
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
144-
git-tree-sha1 = "50ddf44c53698f5e784bbebb3f4b21c5807401b1"
145-
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
146-
version = "0.8.3"
147-
148130
[[ExprTools]]
149131
git-tree-sha1 = "10407a39b87f29d47ebaca8edbc75d7c302ff93e"
150132
uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
@@ -176,47 +158,30 @@ version = "0.1.0"
176158

177159
[[GPUArrays]]
178160
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
179-
git-tree-sha1 = "2c1dd57bca7ba0b3b4bf81d9332aeb81b154ef4c"
161+
git-tree-sha1 = "f99a25fe0313121f2f9627002734c7d63b4dd3bd"
180162
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
181-
version = "6.1.2"
163+
version = "6.2.0"
182164

183165
[[GPUCompiler]]
184166
deps = ["DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Scratch", "Serialization", "TimerOutputs", "UUIDs"]
185167
git-tree-sha1 = "c853c810b52a80f9aad79ab109207889e57f41ef"
186168
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
187169
version = "0.8.3"
188170

189-
[[Hwloc]]
190-
deps = ["Hwloc_jll"]
191-
git-tree-sha1 = "2f32147444692235ad4ccc5e03e2d8e9a6b5d247"
192-
uuid = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
193-
version = "1.1.0"
194-
195-
[[Hwloc_jll]]
196-
deps = ["Libdl", "Pkg"]
197-
git-tree-sha1 = "d9de29482e0a9efb0639328e208d02e01554fa9b"
198-
uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8"
199-
version = "2.2.0+0"
200-
201171
[[IRTools]]
202172
deps = ["InteractiveUtils", "MacroTools", "Test"]
203173
git-tree-sha1 = "c67e7515a11f726f44083e74f218d134396d6510"
204174
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
205175
version = "0.4.2"
206176

207-
[[IfElse]]
208-
git-tree-sha1 = "28e837ff3e7a6c3cdb252ce49fb412c8eb3caeef"
209-
uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
210-
version = "0.1.0"
211-
212177
[[InteractiveUtils]]
213178
deps = ["Markdown"]
214179
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
215180

216181
[[JLLWrappers]]
217-
git-tree-sha1 = "c70593677bbf2c3ccab4f7500d0f4dacfff7b75c"
182+
git-tree-sha1 = "a431f5f2ca3f4feef3bd7a5e94b8b8d4f2f647a0"
218183
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
219-
version = "1.1.3"
184+
version = "1.2.0"
220185

221186
[[Juno]]
222187
deps = ["Base64", "Logging", "Media", "Profile"]
@@ -226,9 +191,9 @@ version = "0.8.4"
226191

227192
[[LLVM]]
228193
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
229-
git-tree-sha1 = "a2101830a761d592b113129887fda626387f68d4"
194+
git-tree-sha1 = "d0d99629d6ae4a3e211ae83d8870907bd842c811"
230195
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
231-
version = "3.5.1"
196+
version = "3.5.2"
232197

233198
[[LibGit2]]
234199
deps = ["Printf"]
@@ -244,12 +209,6 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
244209
[[Logging]]
245210
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
246211

247-
[[LoopVectorization]]
248-
deps = ["ArrayInterface", "DocStringExtensions", "IfElse", "LinearAlgebra", "OffsetArrays", "SLEEFPirates", "UnPack", "VectorizationBase"]
249-
git-tree-sha1 = "3066adba33448098ba12ac8d7dbd4835210b81f2"
250-
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
251-
version = "0.9.12"
252-
253212
[[MacroTools]]
254213
deps = ["Markdown", "Random"]
255214
git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0"
@@ -281,22 +240,16 @@ uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
281240
version = "0.2.2"
282241

283242
[[NNlib]]
284-
deps = ["Compat", "Libdl", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
285-
git-tree-sha1 = "2b7c3213ed4f2eed686f9f531f85d3ea2f75286f"
243+
deps = ["ChainRulesCore", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
244+
git-tree-sha1 = "13fd29731c7f609cb82a3a544c5538584d22c153"
286245
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
287-
version = "0.7.9"
246+
version = "0.7.11"
288247

289248
[[NaNMath]]
290249
git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb"
291250
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
292251
version = "0.3.5"
293252

294-
[[OffsetArrays]]
295-
deps = ["Adapt"]
296-
git-tree-sha1 = "b0cc1c42b63e30b759f4e1cf045ad8a51069d6cc"
297-
uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
298-
version = "1.4.2"
299-
300253
[[OpenSpecFun_jll]]
301254
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
302255
git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3"
@@ -343,12 +296,6 @@ version = "1.1.2"
343296
[[SHA]]
344297
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
345298

346-
[[SLEEFPirates]]
347-
deps = ["IfElse", "Libdl", "VectorizationBase"]
348-
git-tree-sha1 = "d82dffab8f9e50d5110c5650c25fdf9e00dec316"
349-
uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa"
350-
version = "0.6.1"
351-
352299
[[Scratch]]
353300
deps = ["Dates"]
354301
git-tree-sha1 = "ad4b278adb62d185bbcb6864dc24959ab0627bf6"
@@ -417,20 +364,9 @@ version = "0.9.5"
417364
deps = ["Random", "SHA"]
418365
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
419366

420-
[[UnPack]]
421-
git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b"
422-
uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
423-
version = "1.0.2"
424-
425367
[[Unicode]]
426368
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
427369

428-
[[VectorizationBase]]
429-
deps = ["ArrayInterface", "Hwloc", "IfElse", "Libdl", "LinearAlgebra"]
430-
git-tree-sha1 = "9c3cf92a81ec2d85f87939ed27707a6600c936e7"
431-
uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
432-
version = "0.14.9"
433-
434370
[[ZipFile]]
435371
deps = ["Libdl", "Printf", "Zlib_jll"]
436372
git-tree-sha1 = "c3a5637e27e914a7a445b8d0ad063d701931e9f7"
@@ -444,10 +380,10 @@ uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
444380
version = "1.2.11+18"
445381

446382
[[Zygote]]
447-
deps = ["AbstractFFTs", "ArrayLayouts", "ChainRules", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "LoopVectorization", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
448-
git-tree-sha1 = "18f758f28ca2c236e449be64e366e201965129a7"
383+
deps = ["AbstractFFTs", "ChainRules", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
384+
git-tree-sha1 = "d88a7e71fc2eef9510187b1c7d4af7a5177633d0"
449385
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
450-
version = "0.5.17"
386+
version = "0.6.0"
451387

452388
[[ZygoteRules]]
453389
deps = ["MacroTools"]

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,18 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2727

2828
[compat]
2929
AbstractTrees = "0.3"
30-
Adapt = "2.0"
30+
Adapt = "2.0, 3.0"
3131
CUDA = "2.1"
3232
CodecZlib = "0.7"
3333
Colors = "0.12"
3434
Functors = "0.1"
3535
Juno = "0.8"
3636
MacroTools = "0.5"
37-
NNlib = "0.7"
37+
NNlib = "0.7.10"
3838
Reexport = "0.2, 1.0"
3939
StatsBase = "0.33"
4040
ZipFile = "0.9"
41-
Zygote = "0.5"
41+
Zygote = "0.6"
4242
julia = "1.5"
4343

4444
[extras]

src/layers/recurrent.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ reset!(m) = foreach(reset!, functor(m)[1])
5656
# TODO remove in v0.13
5757
function Base.getproperty(m::Recur, sym::Symbol)
5858
if sym === :init
59-
@warn "Recur field :init has been deprecated. To access initial state weights, use m::Recur.cell.state0 instead."
59+
Zygote.ignore() do
60+
@warn "Recur field :init has been deprecated. To access initial state weights, use m::Recur.cell.state0 instead."
61+
end
6062
return getfield(m.cell, :state0)
6163
else
6264
return getfield(m, sym)
@@ -104,7 +106,9 @@ Recur(m::RNNCell) = Recur(m, m.state0)
104106
# TODO remove in v0.13
105107
function Base.getproperty(m::RNNCell, sym::Symbol)
106108
if sym === :h
107-
@warn "RNNCell field :h has been deprecated. Use m::RNNCell.state0 instead."
109+
Zygote.ignore() do
110+
@warn "RNNCell field :h has been deprecated. Use m::RNNCell.state0 instead."
111+
end
108112
return getfield(m, :state0)
109113
else
110114
return getfield(m, sym)
@@ -163,11 +167,15 @@ Recur(m::LSTMCell) = Recur(m, m.state0)
163167
# TODO remove in v0.13
164168
function Base.getproperty(m::LSTMCell, sym::Symbol)
165169
if sym === :h
166-
@warn "LSTMCell field :h has been deprecated. Use m::LSTMCell.state0[1] instead."
170+
Zygote.ignore() do
171+
@warn "LSTMCell field :h has been deprecated. Use m::LSTMCell.state0[1] instead."
172+
end
167173
return getfield(m, :state0)[1]
168174
elseif sym === :c
175+
Zygote.ignore() do
169176
@warn "LSTMCell field :c has been deprecated. Use m::LSTMCell.state0[2] instead."
170-
return getfield(m, :state0)[2]
177+
end
178+
return getfield(m, :state0)[2]
171179
else
172180
return getfield(m, sym)
173181
end
@@ -215,7 +223,9 @@ Recur(m::GRUCell) = Recur(m, m.state0)
215223
# TODO remove in v0.13
216224
function Base.getproperty(m::GRUCell, sym::Symbol)
217225
if sym === :h
218-
@warn "GRUCell field :h has been deprecated. Use m::GRUCell.state0 instead."
226+
Zygote.ignore() do
227+
@warn "GRUCell field :h has been deprecated. Use m::GRUCell.state0 instead."
228+
end
219229
return getfield(m, :state0)
220230
else
221231
return getfield(m, sym)

src/outputsize.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module NilNumber
22

33
using NNlib
4+
import Random
45

56
"""
67
Nil <: Number
@@ -40,6 +41,8 @@ Base.typemax(::Type{Nil}) = nil
4041

4142
Base.promote_rule(x::Type{Nil}, y::Type{<:Number}) = Nil
4243

44+
Random.rand(rng::Random.AbstractRNG, ::Random.SamplerType{Nil}) = nil
45+
4346
end # module
4447

4548
using .NilNumber: Nil, nil

0 commit comments

Comments
 (0)