Commit 4da339e

fix some setup bugs (#2145)
1 parent ba48ad0 commit 4da339e

File tree

3 files changed: +12 -4 lines changed

src/deprecations.jl

Lines changed: 2 additions & 2 deletions
@@ -130,8 +130,8 @@ end
 _old_to_new(rule::Optimiser) = Optimisers.OptimiserChain(map(_old_to_new, rule.os)...)
 const OptimiserChain = Optimise.Optimiser # lets you use new name with implicit params too.
 _old_to_new(rule::WeightDecay) = Optimisers.WeightDecay(rule.wd) # called gamma now
-_old_to_new(rule::ClipNorm) = Optimisers.ClipNorm(rule.thesh) # called omega, and there are more fields
-_old_to_new(rule::ClipValue) = Optimisers.ClipGrad(rule.thesh) # called delta now, and struct name differs
+_old_to_new(rule::ClipNorm) = Optimisers.ClipNorm(rule.thresh) # called omega, and there are more fields
+_old_to_new(rule::ClipValue) = Optimisers.ClipGrad(rule.thresh) # called delta now, and struct name differs
 const ClipGrad = Optimise.ClipValue
 _old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon) # RMSProp has no field centred
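
The practical effect of the typo was that translating the old clipping rules failed with a field-access error on `thesh`. Below is a short sketch of what the corrected deprecation path gives you, using the same calls as the new tests in test/train.jl; the `# true` annotations are the expected results, not captured output.

using Flux, Optimisers

# The old implicit-params clipping rules now translate onto their
# Optimisers.jl counterparts instead of erroring on the misspelled field.
opt_state = Flux.setup(Flux.ClipValue(1), Dense(2 => 3))
opt_state.weight.rule isa Optimisers.ClipGrad  # true

opt_state = Flux.setup(Flux.ClipNorm(1), Dense(2 => 3))
opt_state.weight.rule isa Optimisers.ClipNorm  # true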

src/train.jl

Lines changed: 3 additions & 2 deletions
@@ -2,7 +2,7 @@ module Train

 using LinearAlgebra
 using Optimisers: Optimisers
-using Functors: fmap
+using Functors: fmap, fmapstructure

 import ..Flux.Optimise: train!, update! # during 0.13, we add methods to the old functions

@@ -48,7 +48,8 @@ julia> opt_state # mutated by Flux.train!
 """
 function setup(rule::Optimisers.AbstractRule, model)
   state = Optimisers.setup(rule, model)
-  fmap(model, exclude = Optimisers.isnumeric) do x
+  # This check only needs foreach; using fmap caused https://github.com/FluxML/Flux.jl/issues/2144
+  fmapstructure(model, exclude = Optimisers.isnumeric) do x
     Optimisers.maywrite(x) || error("""model must be fully mutable for `train!` to work, got `x::$(typeof(x))`.
       If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::$(typeof(x))) = true`""")
   end
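
Why the swap helps: the do-block returns the `Bool` from `maywrite`, and `fmap` then tries to rebuild the model structure around those booleans, which is what tripped up layers such as `Embedding` in issue #2144 (this reading is inferred from the comment above and the linked issue). `fmapstructure` only collects the results into a plain nested structure, so nothing is reconstructed. A rough sketch of the behaviour the fix restores, mirroring the new test:

using Flux, Optimisers

# Previously this call failed inside the mutability check; with
# fmapstructure it returns an ordinary optimiser state tree.
opt_state = Flux.setup(Flux.Adam(), Embedding(3 => 1))
opt_state.weight isa Optimisers.Leaf  # true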

test/train.jl

Lines changed: 7 additions & 0 deletions
@@ -139,3 +139,10 @@ end
   @test diff1 ≈ diff3
 end

+@testset "Flux.setup bugs" begin
+  # https://github.com/FluxML/Flux.jl/issues/2144
+  @test Flux.setup(Flux.Adam(), Embedding(3 => 1)).weight isa Optimisers.Leaf
+  # Typo in 0.13.9's deprecation
+  @test Flux.setup(Flux.ClipValue(1), Dense(2 => 3)).weight.rule isa Optimisers.ClipGrad
+  @test Flux.setup(Flux.ClipNorm(1), Dense(2 => 3)).weight.rule isa Optimisers.ClipNorm
+end
