|
79 | 79 | end
|
80 | 80 | end
|
81 | 81 | @testset "svd" begin
|
82 |
| - for n in [4, 6, 10], m in [3, 5, 10] |
83 |
| - X = randn(n, m) |
84 |
| - F, dX_pullback = rrule(svd, X) |
85 |
| - for p in [:U, :S, :V] |
86 |
| - Y, dF_pullback = rrule(getproperty, F, p) |
87 |
| - Ȳ = randn(size(Y)...) |
88 |
| - |
89 |
| - dself1, dF, dp = dF_pullback(Ȳ) |
90 |
| - @test dself1 === NO_FIELDS |
91 |
| - @test dp === DoesNotExist() |
92 |
| - |
93 |
| - dself2, dX = dX_pullback(dF) |
94 |
| - @test dself2 === NO_FIELDS |
95 |
| - X̄_ad = unthunk(dX) |
96 |
| - X̄_fd = only(j′vp(central_fdm(5, 1), X->getproperty(svd(X), p), Ȳ, X)) |
97 |
| - @test all(isapprox.(X̄_ad, X̄_fd; rtol=1e-6, atol=1e-6)) |
| 82 | + for n in [4, 6, 10], m in [3, 5, 9] |
| 83 | + @testset "($n x $m) svd" begin |
| 84 | + X = randn(n, m) |
| 85 | + @show X |
| 86 | + test_rrule(svd, X; atol=1e-6, rtol=1e-6) |
98 | 87 | end
|
99 |
| - @testset "Vt" begin |
100 |
| - Y, dF_pullback = rrule(getproperty, F, :Vt) |
101 |
| - Ȳ = randn(size(Y)...) |
102 |
| - @test_throws ArgumentError dF_pullback(Ȳ) |
| 88 | + end |
| 89 | + |
| 90 | + for n in [4, 6, 10], m in [3, 5, 10] |
| 91 | + @testset "($n x $m) getproperty" begin |
| 92 | + X = randn(n, m) |
| 93 | + F = svd(X) |
| 94 | + rand_adj = adjoint(rand(reverse(size(F.V))...)) |
| 95 | + |
| 96 | + test_rrule(getproperty, F, :U ⊢ nothing; check_inferred=false) |
| 97 | + test_rrule(getproperty, F, :S ⊢ nothing; check_inferred=false) |
| 98 | + test_rrule(getproperty, F, :Vt ⊢ nothing; check_inferred=false) |
| 99 | + test_rrule(getproperty, F, :V ⊢ nothing; check_inferred=false, output_tangent=rand_adj) |
103 | 100 | end
|
104 | 101 | end
|
105 | 102 |
|
106 | 103 | @testset "Thunked inputs" begin
|
107 | 104 | X = randn(4, 3)
|
108 | 105 | F, dX_pullback = rrule(svd, X)
|
109 |
| - for p in [:U, :S, :V] |
| 106 | + for p in [:U, :S, :V, :Vt] |
110 | 107 | Y, dF_pullback = rrule(getproperty, F, p)
|
111 | 108 | Ȳ = randn(size(Y)...)
|
112 | 109 |
|
113 | 110 | _, dF_unthunked, _ = dF_pullback(Ȳ)
|
114 | 111 |
|
115 | 112 | # helper to let us check how things are stored.
|
116 |
| - backing_field(c, p) = getproperty(ChainRulesCore.backing(c), p) |
| 113 | + p_access = p == :V ? :Vt : p |
| 114 | + backing_field(c, p) = getproperty(ChainRulesCore.backing(c), p_access) |
117 | 115 | @assert !(backing_field(dF_unthunked, p) isa AbstractThunk)
|
118 | 116 |
|
119 | 117 | dF_thunked = map(f->Thunk(()->f), dF_unthunked)
|
|
126 | 124 | end
|
127 | 125 | end
|
128 | 126 |
|
129 |
| - @testset "+" begin |
130 |
| - X = [1.0 2.0; 3.0 4.0; 5.0 6.0] |
131 |
| - F, dX_pullback = rrule(svd, X) |
132 |
| - X̄ = Composite{typeof(F)}(U=zeros(3, 2), S=zeros(2), V=zeros(2, 2)) |
133 |
| - for p in [:U, :S, :V] |
134 |
| - Y, dF_pullback = rrule(getproperty, F, p) |
135 |
| - Ȳ = ones(size(Y)...) |
136 |
| - dself, dF, dp = dF_pullback(Ȳ) |
137 |
| - @test dself === NO_FIELDS |
138 |
| - @test dp === DoesNotExist() |
139 |
| - X̄ += dF |
140 |
| - end |
141 |
| - @test X̄.U ≈ ones(3, 2) atol=1e-6 |
142 |
| - @test X̄.S ≈ ones(2) atol=1e-6 |
143 |
| - @test X̄.V ≈ ones(2, 2) atol=1e-6 |
144 |
| - end |
145 |
| - |
146 | 127 | @testset "Helper functions" begin
|
147 | 128 | X = randn(10, 10)
|
148 | 129 | Y = randn(10, 10)
|
|
0 commit comments