Skip to content

Commit 6bae1e7

Browse files
add test for ComponentArrays
1 parent 28bd53e commit 6bae1e7

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

Manifest.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,12 @@ version = "3.30.0"
9595
deps = ["Artifacts", "Libdl"]
9696
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
9797

98+
[[ComponentArrays]]
99+
deps = ["ArrayInterface", "LinearAlgebra", "Requires"]
100+
git-tree-sha1 = "76495e7a7e47abc3771d70c782d5f6e66f114d36"
101+
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
102+
version = "0.10.5"
103+
98104
[[DataAPI]]
99105
git-tree-sha1 = "dfb3b7e89e395be1e25c2ad6d7690dc29cc53b1d"
100106
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,12 @@ Zygote = "0.6"
4646
julia = "1.6"
4747

4848
[extras]
49+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
4950
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
5051
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
5152
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
5253
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5354
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5455

5556
[targets]
56-
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays"]
57+
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]

test/optimise.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Flux.Optimise
22
using Flux.Optimise: runall
33
using Flux: Params, gradient
4-
import FillArrays
4+
import FillArrays, ComponentArrays
55
using Test
66
using Random
77

@@ -143,7 +143,7 @@ end
143143
@test norm(w̄_norm) <= 1
144144
end
145145

146-
@testset "handle Fills from Zygote" begin
146+
@testset "update!: handle Fills from Zygote" begin
147147
w = randn(10,10)
148148
wold = copy(w)
149149
g = FillArrays.Ones(size(w))
@@ -169,3 +169,16 @@ end
169169
Flux.update!(opt, θ, gs)
170170
@test w wold .- 0.1
171171
end
172+
173+
@testset "update!: handle ComponentArrays" begin
174+
w = ComponentArrays.ComponentArray(a=1.0, b=[2, 1, 4], c=(a=2, b=[1, 2]))
175+
wold = deepcopy(w)
176+
θ = Flux.params([w])
177+
gs = gradient(() -> sum(w.a) + sum(w.c.b), θ)
178+
opt = Descent(0.1)
179+
Flux.update!(opt, θ, gs)
180+
@test w.a wold.a .- 0.1
181+
@test w.b wold.b
182+
@test w.c.b wold.c.b .- 0.1
183+
@test w.c.a wold.c.a
184+
end

0 commit comments

Comments
 (0)