|
104 | 104 | # Zygote error in (::typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple{(:x, :y, :z)
|
105 | 105 | # Diffractor error in perform_optic_transform
|
106 | 106 | end
|
| 107 | + |
| 108 | + @testset "using Yota" begin |
| 109 | + @test_broken Yota_gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0] # Unexpected expression: $(Expr(:static_parameter, 1)) |
| 110 | + # These are all broken! |
| 111 | + @test Yota_gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0]) |
| 112 | + @test Yota_gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], ZeroTangent()) |
| 113 | + @test Yota_gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = ZeroTangent(), z = [0,0,0]) |
| 114 | + @test Yota_gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = ZeroTangent(), z = [0,0,0]) |
| 115 | + |
| 116 | + g5 = Yota_gradient(m -> destructure(m)[1][3], m5)[1] |
| 117 | + @test g5.a[1].x == [0,0,1] |
| 118 | + @test g5.a[2] === ZeroTangent() |
| 119 | + |
| 120 | + g6 = Yota_gradient(m -> imag(destructure(m)[1][4]), m6)[1] |
| 121 | + @test g6.a == [0,0,0] |
| 122 | + @test g6.a isa Vector{Float64} |
| 123 | + @test g6.b == [0+im] |
| 124 | + |
| 125 | + g8 = Yota_gradient(m -> sum(abs2, destructure(m)[1]), m8)[1] |
| 126 | + @test g8[1].x == [2,4,6] |
| 127 | + @test g8[2].b.x == [8] |
| 128 | + @test g8[3] == [[10.0]] |
| 129 | + |
| 130 | + g9 = Yota_gradient(m -> sum(sqrt, destructure(m)[1]), m9)[1] |
| 131 | + @test g9.c === ZeroTangent() |
| 132 | + end |
107 | 133 | end
|
108 | 134 |
|
109 | 135 | @testset "gradient of rebuild" begin
|
|
149 | 175 | # Not fixed by this:
|
150 | 176 | # Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,)
|
151 | 177 | end
|
| 178 | + |
| 179 | + @testset "using Yota" begin |
| 180 | + re1 = destructure(m1)[2] |
| 181 | + @test Yota_gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0] |
| 182 | + re2 = destructure(m2)[2] |
| 183 | + @test Yota_gradient(x -> re2(x)[1][2], rand(6))[1] == [0,1,0,0,0,0] |
| 184 | + re3 = destructure(m3)[2] |
| 185 | + @test Yota_gradient(x -> re3(x).x[3], rand(6))[1] == [0,0,1,0,0,0] |
| 186 | + @test Yota_gradient(x -> re3(x).z[1], rand(6))[1] == [0,0,0,1,0,0] |
| 187 | + |
| 188 | + re4 = destructure(m4)[2] |
| 189 | + @test Yota_gradient(x -> re4(x).x[1], rand(6))[1] == [1,0,0,0,0,0] |
| 190 | + @test Yota_gradient(x -> re4(x).y[2], rand(6))[1] == [0,1,0,0,0,0] |
| 191 | + @test Yota_gradient(rand(6)) do x |
| 192 | + m = re4(x) |
| 193 | + m.x[1] + 2*m.y[2] + 3*m.z[3] |
| 194 | + end[1] == [1,2,0, 0,0,3] |
| 195 | + |
| 196 | + re7 = destructure(m7)[2] |
| 197 | + @test Yota_gradient(x -> re7(x).a[2][3], rand(3))[1] == [0,0,1] |
| 198 | + @test Yota_gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0] |
| 199 | + @test Yota_gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0] |
| 200 | + |
| 201 | + v8, re8 = destructure(m8) |
| 202 | + @test_broken Yota_gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0] # MethodError: no method matching zero(::Type{Any}) |
| 203 | + @test_broken Yota_gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] # MethodError: no method matching !(::Expr) |
| 204 | + |
| 205 | + re9 = destructure(m9)[2] |
| 206 | + @test_broken Yota_gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14] # MethodError: no method matching zero(::Type{Array}) |
| 207 | + end |
152 | 208 | end
|
153 | 209 |
|
154 | 210 | @testset "Flux issue 1826" begin
|
|
0 commit comments