@@ -2,6 +2,7 @@ using Flux.Optimise
2
2
using Flux. Optimise: runall
3
3
using Flux: Params, gradient
4
4
import FillArrays, ComponentArrays
5
+ import Optimisers
5
6
using Test
6
7
using Random
7
8
@@ -167,21 +168,19 @@ end
167
168
@testset " update!: handle ComponentArrays" begin
168
169
w = ComponentArrays. ComponentArray (a= 1.0 , b= [2 , 1 , 4 ], c= (a= 2 , b= [1 , 2 ]))
169
170
wold = deepcopy (w)
170
- θ = Flux. params ([w])
171
- gs = gradient (() -> sum (w. a) + sum (w. c. b), θ)
172
- opt = Descent (0.1 )
173
- Flux. update! (opt, θ, gs)
174
- @test w. a ≈ wold. a .- 0.1
171
+ opt_state = Optimisers. setup (Optimisers. Descent (0.1 ), w)
172
+ gs = gradient (w -> w. a + sum (w. c. b), w)[1 ]
173
+ Flux. update! (opt_state, w, gs)
174
+ @test w. a ≈ wold. a - 0.1
175
175
@test w. b ≈ wold. b
176
176
@test w. c. b ≈ wold. c. b .- 0.1
177
177
@test w. c. a ≈ wold. c. a
178
178
179
179
w = ComponentArrays. ComponentArray (a= 1.0 , b= [2 , 1 , 4 ], c= (a= 2 , b= [1 , 2 ]))
180
180
wold = deepcopy (w)
181
- θ = Flux. params ([w])
182
- gs = gradient (() -> sum (w), θ)
183
- opt = Descent (0.1 )
184
- Flux. update! (opt, θ, gs)
181
+ opt_state = Optimisers. setup (Optimisers. Descent (0.1 ), w)
182
+ gs = gradient (w -> sum (w), w)[1 ]
183
+ Flux. update! (opt_state, w, gs)
185
184
@test w ≈ wold .- 0.1
186
185
end
187
186
0 commit comments