Skip to content

Commit 1fe64dc

Browse files
add another ComponentArray test
1 parent 6bae1e7 commit 1fe64dc

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1010
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
1111
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
12+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
1213
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
1314
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1415
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"

test/optimise.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,4 +181,12 @@ end
181181
@test w.b wold.b
182182
@test w.c.b wold.c.b .- 0.1
183183
@test w.c.a wold.c.a
184+
185+
w = ComponentArrays.ComponentArray(a=1.0, b=[2, 1, 4], c=(a=2, b=[1, 2]))
186+
wold = deepcopy(w)
187+
θ = Flux.params([w])
188+
gs = gradient(() -> sum(w), θ)
189+
opt = Descent(0.1)
190+
Flux.update!(opt, θ, gs)
191+
@test w wold .- 0.1
184192
end

0 commit comments

Comments
 (0)