Skip to content

Commit 6a2a870

Browse files
committed
also test destructure
1 parent 25ac46b commit 6a2a870

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

test/destructure.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,32 @@ end
104104
# Zygote error in (::typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple{(:x, :y, :z)
105105
# Diffractor error in perform_optic_transform
106106
end
107+
108+
@testset "using Yota" begin
109+
@test_broken Yota_gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0] # Unexpected expression: $(Expr(:static_parameter, 1))
110+
# These are all broken!
111+
@test Yota_gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0])
112+
@test Yota_gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], ZeroTangent())
113+
@test Yota_gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = ZeroTangent(), z = [0,0,0])
114+
@test Yota_gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = ZeroTangent(), z = [0,0,0])
115+
116+
g5 = Yota_gradient(m -> destructure(m)[1][3], m5)[1]
117+
@test g5.a[1].x == [0,0,1]
118+
@test g5.a[2] === ZeroTangent()
119+
120+
g6 = Yota_gradient(m -> imag(destructure(m)[1][4]), m6)[1]
121+
@test g6.a == [0,0,0]
122+
@test g6.a isa Vector{Float64}
123+
@test g6.b == [0+im]
124+
125+
g8 = Yota_gradient(m -> sum(abs2, destructure(m)[1]), m8)[1]
126+
@test g8[1].x == [2,4,6]
127+
@test g8[2].b.x == [8]
128+
@test g8[3] == [[10.0]]
129+
130+
g9 = Yota_gradient(m -> sum(sqrt, destructure(m)[1]), m9)[1]
131+
@test g9.c === ZeroTangent()
132+
end
107133
end
108134

109135
@testset "gradient of rebuild" begin
@@ -149,6 +175,36 @@ end
149175
# Not fixed by this:
150176
# Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,)
151177
end
178+
179+
@testset "using Yota" begin
180+
re1 = destructure(m1)[2]
181+
@test Yota_gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0]
182+
re2 = destructure(m2)[2]
183+
@test Yota_gradient(x -> re2(x)[1][2], rand(6))[1] == [0,1,0,0,0,0]
184+
re3 = destructure(m3)[2]
185+
@test Yota_gradient(x -> re3(x).x[3], rand(6))[1] == [0,0,1,0,0,0]
186+
@test Yota_gradient(x -> re3(x).z[1], rand(6))[1] == [0,0,0,1,0,0]
187+
188+
re4 = destructure(m4)[2]
189+
@test Yota_gradient(x -> re4(x).x[1], rand(6))[1] == [1,0,0,0,0,0]
190+
@test Yota_gradient(x -> re4(x).y[2], rand(6))[1] == [0,1,0,0,0,0]
191+
@test Yota_gradient(rand(6)) do x
192+
m = re4(x)
193+
m.x[1] + 2*m.y[2] + 3*m.z[3]
194+
end[1] == [1,2,0, 0,0,3]
195+
196+
re7 = destructure(m7)[2]
197+
@test Yota_gradient(x -> re7(x).a[2][3], rand(3))[1] == [0,0,1]
198+
@test Yota_gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0]
199+
@test Yota_gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0]
200+
201+
v8, re8 = destructure(m8)
202+
@test_broken Yota_gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0] # MethodError: no method matching zero(::Type{Any})
203+
@test_broken Yota_gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] # MethodError: no method matching !(::Expr)
204+
205+
re9 = destructure(m9)[2]
206+
@test_broken Yota_gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14] # MethodError: no method matching zero(::Type{Array})
207+
end
152208
end
153209

154210
@testset "Flux issue 1826" begin

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ function Optimisers.apply!(o::BiRule, state, x, dx, dx2)
3737
return state, dx
3838
end
3939

40+
# Make Yota's output look like Zygote's:
41+
42+
Yota_gradient(f, xs...) = Base.tail(Yota.grad(f, xs...)[2])
43+
4044
@testset verbose=true "Optimisers.jl" begin
4145
@testset verbose=true "Features" begin
4246

0 commit comments

Comments
 (0)