Skip to content

Commit 3b29953

Browse files
committed
rrule and tests for spdiagm
1 parent a2005d1 commit 3b29953

File tree

2 files changed

+73
-10
lines changed

2 files changed

+73
-10
lines changed

src/rulesets/SparseArrays/sparsematrix.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,33 @@ function rrule(::typeof(det), x::SparseMatrixCSC)
137137
end
138138
return Ω, det_pullback
139139
end
140+
141+
142+
function rrule(::typeof(spdiagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...)
143+
144+
function spdiagm_pullback(ȳ)
145+
return (NoTangent(), NoTangent(), NoTangent(), _diagm_back.(kv, Ref(ȳ))...)
146+
end
147+
return spdiagm(m, n, kv...), spdiagm_pullback
148+
end
149+
150+
function rrule(::typeof(spdiagm), kv::Pair{<:Integer,<:AbstractVector}...)
151+
function spdiagm_pullback(ȳ)
152+
return (NoTangent(), _diagm_back.(kv, Ref(ȳ))...)
153+
end
154+
return spdiagm(kv...), spdiagm_pullback
155+
end
156+
157+
function rrule(::typeof(spdiagm), v::AbstractVector)
158+
function spdiagm_pullback(ȳ)
159+
return (NoTangent(), diag(unthunk(ȳ)))
160+
end
161+
return spdiagm(v), spdiagm_pullback
162+
end
163+
164+
165+
function _diagm_back(p, ȳ)
166+
k, v = p
167+
d = diag(unthunk(ȳ), k)[1:length(v)] # handle if diagonal was smaller than matrix
168+
return Tangent{typeof(p)}(second = d)
169+
end

test/rulesets/SparseArrays/sparsematrix.jl

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,49 @@ end
1818
test_rrule(SparseVector{Float32}, Float32.(v), rtol=1e-4)
1919
end
2020

21+
# copied over from test/rulesets/LinearAlgebra/structured
2122
@testset "spdiagm" begin
22-
@test 1 == 1
23-
m = 5
24-
n = 4
25-
v1 = ones(m)
26-
v2 = ones(n)
27-
test_rrule(spdiagm, m, n, 0 => v2)
28-
29-
# test_rrule(spdiagm, 0 => v1)
30-
# test_rrule(spdiagm, v1)
23+
@testset "without size" begin
24+
M, N = 7, 9
25+
s = (8, 8)
26+
a, ā = randn(M), randn(M)
27+
b, b̄ = randn(M), randn(M)
28+
c, c̄ = randn(M - 1), randn(M - 1)
29+
= randn(s)
30+
ps = (0 => a, 1 => b, 0 => c)
31+
y, back = rrule(spdiagm, ps...)
32+
@test y == spdiagm(ps...)
33+
∂self, ∂pa, ∂pb, ∂pc = back(ȳ)
34+
@test ∂self === NoTangent()
35+
∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(0 => a, 1 => b, 0 => c), ȳ, a, b, c)
36+
for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd))
37+
∂px = unthunk(∂px)
38+
@test ∂px isa Tangent{typeof(p)}
39+
@test ∂px.first isa AbstractZero
40+
@test ∂px.second ∂x_fd
41+
end
42+
end
43+
@testset "with size" begin
44+
M, N = 7, 9
45+
a, ā = randn(M), randn(M)
46+
b, b̄ = randn(M), randn(M)
47+
c, c̄ = randn(M - 1), randn(M - 1)
48+
= randn(M, N)
49+
ps = (0 => a, 1 => b, 0 => c)
50+
y, back = rrule(spdiagm, M, N, ps...)
51+
@test y == spdiagm(M, N, ps...)
52+
∂self, ∂M, ∂N, ∂pa, ∂pb, ∂pc = back(ȳ)
53+
@test ∂self === NoTangent()
54+
@test ∂M === NoTangent()
55+
@test ∂N === NoTangent()
56+
∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(M, N, 0 => a, 1 => b, 0 => c), ȳ, a, b, c)
57+
for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd))
58+
∂px = unthunk(∂px)
59+
@test ∂px isa Tangent{typeof(p)}
60+
@test ∂px.first isa AbstractZero
61+
@test ∂px.second ∂x_fd
62+
end
63+
end
3164
end
3265

3366
@testset "findnz" begin
@@ -54,4 +87,4 @@ end
5487
test_rrule(logabsdet, A)
5588
test_rrule(logdet, A)
5689
test_rrule(det, A)
57-
end
90+
end

0 commit comments

Comments
 (0)