Skip to content

Commit caa1cee

Browse files
fix component arrays test (#2419)
* fix component arrays test * import Optimisers
1 parent 66d2e7c commit caa1cee

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

test/optimise.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using Flux.Optimise
22
using Flux.Optimise: runall
33
using Flux: Params, gradient
44
import FillArrays, ComponentArrays
5+
import Optimisers
56
using Test
67
using Random
78

@@ -167,21 +168,19 @@ end
167168
@testset "update!: handle ComponentArrays" begin
168169
w = ComponentArrays.ComponentArray(a=1.0, b=[2, 1, 4], c=(a=2, b=[1, 2]))
169170
wold = deepcopy(w)
170-
θ = Flux.params([w])
171-
gs = gradient(() -> sum(w.a) + sum(w.c.b), θ)
172-
opt = Descent(0.1)
173-
Flux.update!(opt, θ, gs)
174-
@test w.a wold.a .- 0.1
171+
opt_state = Optimisers.setup(Optimisers.Descent(0.1), w)
172+
gs = gradient(w -> w.a + sum(w.c.b), w)[1]
173+
Flux.update!(opt_state, w, gs)
174+
@test w.a wold.a - 0.1
175175
@test w.b wold.b
176176
@test w.c.b wold.c.b .- 0.1
177177
@test w.c.a wold.c.a
178178

179179
w = ComponentArrays.ComponentArray(a=1.0, b=[2, 1, 4], c=(a=2, b=[1, 2]))
180180
wold = deepcopy(w)
181-
θ = Flux.params([w])
182-
gs = gradient(() -> sum(w), θ)
183-
opt = Descent(0.1)
184-
Flux.update!(opt, θ, gs)
181+
opt_state = Optimisers.setup(Optimisers.Descent(0.1), w)
182+
gs = gradient(w -> sum(w), w)[1]
183+
Flux.update!(opt_state, w, gs)
185184
@test w wold .- 0.1
186185
end
187186

0 commit comments

Comments
 (0)