Skip to content

Commit a50b18c

Browse files
authored
narrower non_differentiable params (#2118)
1 parent 0a765b1 commit a50b18c

File tree

3 files changed

+19
-2
lines changed

3 files changed

+19
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ ProgressLogging = "0.1"
3838
Reexport = "0.2, 1.0"
3939
SpecialFunctions = "1.8.2, 2.1.2"
4040
StatsBase = "0.33"
41-
Zygote = "0.6.34"
41+
Zygote = "0.6.49"
4242
julia = "1.6"
4343

4444
[extras]

src/functor.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,10 @@ function params(m...)
8989
end
9090

9191
# Allows caching of the parameters when params is called within gradient() to fix #2040.
92-
@non_differentiable params(m...)
92+
# @non_differentiable params(m...) # https://github.com/FluxML/Flux.jl/pull/2054
93+
# That speeds up implicit use, and silently breaks explicit use.
94+
# From @macroexpand Zygote.@nograd params(m...) and https://github.com/FluxML/Zygote.jl/pull/1248
95+
Zygote._pullback(::Zygote.Context{true}, ::typeof(params), m...) = params(m), _ -> nothing
9396

9497
struct FluxCUDAAdaptor end
9598
adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)

test/utils.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,20 @@ end
270270
@test size.(Flux.params(m)) == [(2,), (1, 2)]
271271
end
272272

273+
@testset "params gradient" begin
274+
m = (x=[1,2.0], y=[3.0]);
275+
276+
# Explicit -- was broken by #2054
277+
gnew = gradient(m -> (sum(norm, Flux.params(m))), m)[1]
278+
@test gnew.x [0.4472135954999579, 0.8944271909999159]
279+
@test gnew.y [1.0]
280+
281+
# Implicit
282+
gold = gradient(() -> (sum(norm, Flux.params(m))), Flux.params(m))
283+
@test gold[m.x] [0.4472135954999579, 0.8944271909999159]
284+
@test gold[m.y] [1.0]
285+
end
286+
273287
@testset "Precision" begin
274288
m = Chain(Dense(10, 5, relu), Dense(5, 2))
275289
x64 = rand(Float64, 10)

0 commit comments

Comments
 (0)