Skip to content

Commit 67c106f

Browse files
authored
Merge pull request #390 from JuliaDiff/mz/svd2
Fix Composite of SVD
2 parents 76ef95c + 9e72757 commit 67c106f

File tree

3 files changed

+24
-44
lines changed

3 files changed

+24
-44
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.7.54"
3+
version = "0.7.55"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/LinearAlgebra/factorization.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ function rrule(::typeof(svd), X::AbstractMatrix{<:Real})
207207
F = svd(X)
208208
function svd_pullback::Composite)
209209
# `getproperty` on `Composite`s ensures we have no thunks.
210-
∂X = svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V)
210+
∂X = svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.Vt')
211211
return (NO_FIELDS, ∂X)
212212
end
213213
return F, svd_pullback
@@ -221,10 +221,9 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD
221221
elseif x === :S
222222
C(S=Ȳ,)
223223
elseif x === :V
224-
C(V=Ȳ,)
224+
C(Vt=Ȳ',)
225225
elseif x === :Vt
226-
# TODO: https://github.com/JuliaDiff/ChainRules.jl/issues/106
227-
throw(ArgumentError("Vt is unsupported; use V and transpose the result"))
226+
C(Vt=Ȳ,)
228227
end
229228
return NO_FIELDS, ∂F, DoesNotExist()
230229
end

test/rulesets/LinearAlgebra/factorization.jl

Lines changed: 20 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -79,41 +79,39 @@ end
7979
end
8080
end
8181
@testset "svd" begin
82-
for n in [4, 6, 10], m in [3, 5, 10]
83-
X = randn(n, m)
84-
F, dX_pullback = rrule(svd, X)
85-
for p in [:U, :S, :V]
86-
Y, dF_pullback = rrule(getproperty, F, p)
87-
= randn(size(Y)...)
88-
89-
dself1, dF, dp = dF_pullback(Ȳ)
90-
@test dself1 === NO_FIELDS
91-
@test dp === DoesNotExist()
92-
93-
dself2, dX = dX_pullback(dF)
94-
@test dself2 === NO_FIELDS
95-
X̄_ad = unthunk(dX)
96-
X̄_fd = only(j′vp(central_fdm(5, 1), X->getproperty(svd(X), p), Ȳ, X))
97-
@test all(isapprox.(X̄_ad, X̄_fd; rtol=1e-6, atol=1e-6))
82+
for n in [4, 6, 10], m in [3, 5, 9]
83+
@testset "($n x $m) svd" begin
84+
X = randn(n, m)
85+
@show X
86+
test_rrule(svd, X; atol=1e-6, rtol=1e-6)
9887
end
99-
@testset "Vt" begin
100-
Y, dF_pullback = rrule(getproperty, F, :Vt)
101-
= randn(size(Y)...)
102-
@test_throws ArgumentError dF_pullback(Ȳ)
88+
end
89+
90+
for n in [4, 6, 10], m in [3, 5, 10]
91+
@testset "($n x $m) getproperty" begin
92+
X = randn(n, m)
93+
F = svd(X)
94+
rand_adj = adjoint(rand(reverse(size(F.V))...))
95+
96+
test_rrule(getproperty, F, :U nothing; check_inferred=false)
97+
test_rrule(getproperty, F, :S nothing; check_inferred=false)
98+
test_rrule(getproperty, F, :Vt nothing; check_inferred=false)
99+
test_rrule(getproperty, F, :V nothing; check_inferred=false, output_tangent=rand_adj)
103100
end
104101
end
105102

106103
@testset "Thunked inputs" begin
107104
X = randn(4, 3)
108105
F, dX_pullback = rrule(svd, X)
109-
for p in [:U, :S, :V]
106+
for p in [:U, :S, :V, :Vt]
110107
Y, dF_pullback = rrule(getproperty, F, p)
111108
= randn(size(Y)...)
112109

113110
_, dF_unthunked, _ = dF_pullback(Ȳ)
114111

115112
# helper to let us check how things are stored.
116-
backing_field(c, p) = getproperty(ChainRulesCore.backing(c), p)
113+
p_access = p == :V ? :Vt : p
114+
backing_field(c, p) = getproperty(ChainRulesCore.backing(c), p_access)
117115
@assert !(backing_field(dF_unthunked, p) isa AbstractThunk)
118116

119117
dF_thunked = map(f->Thunk(()->f), dF_unthunked)
@@ -126,23 +124,6 @@ end
126124
end
127125
end
128126

129-
@testset "+" begin
130-
X = [1.0 2.0; 3.0 4.0; 5.0 6.0]
131-
F, dX_pullback = rrule(svd, X)
132-
= Composite{typeof(F)}(U=zeros(3, 2), S=zeros(2), V=zeros(2, 2))
133-
for p in [:U, :S, :V]
134-
Y, dF_pullback = rrule(getproperty, F, p)
135-
= ones(size(Y)...)
136-
dself, dF, dp = dF_pullback(Ȳ)
137-
@test dself === NO_FIELDS
138-
@test dp === DoesNotExist()
139-
+= dF
140-
end
141-
@test.U ones(3, 2) atol=1e-6
142-
@test.S ones(2) atol=1e-6
143-
@test.V ones(2, 2) atol=1e-6
144-
end
145-
146127
@testset "Helper functions" begin
147128
X = randn(10, 10)
148129
Y = randn(10, 10)

0 commit comments

Comments
 (0)