Skip to content

Commit aa5abed

Browse files
authored
Merge pull request #740 from ElOceanografo/spdiag
Add rrule for spdiagm
2 parents e3b8bf5 + 7a6d648 commit aa5abed

File tree

3 files changed

+70
-2
lines changed

3 files changed

+70
-2
lines changed

src/rulesets/LinearAlgebra/structured.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ end
9696

9797
function _diagm_back(p, ȳ)
9898
k, v = p
99-
d = diag(unthunk(ȳ), k)[1:length(v)] # handle if diagonal was smaller than matrix
99+
d = diag(unthunk(ȳ), k)[eachindex(v)] # handle if diagonal was smaller than matrix
100100
return Tangent{typeof(p)}(second = d)
101101
end
102102

src/rulesets/SparseArrays/sparsematrix.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,26 @@ function rrule(::typeof(det), x::SparseMatrixCSC)
137137
end
138138
return Ω, det_pullback
139139
end
140+
141+
142+
function rrule(::typeof(spdiagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...)
143+
144+
function spdiagm_pullback(ȳ)
145+
return (NoTangent(), NoTangent(), NoTangent(), _diagm_back.(kv, Ref(ȳ))...)
146+
end
147+
return spdiagm(m, n, kv...), spdiagm_pullback
148+
end
149+
150+
function rrule(::typeof(spdiagm), kv::Pair{<:Integer,<:AbstractVector}...)
151+
function spdiagm_pullback(ȳ)
152+
return (NoTangent(), _diagm_back.(kv, Ref(ȳ))...)
153+
end
154+
return spdiagm(kv...), spdiagm_pullback
155+
end
156+
157+
function rrule(::typeof(spdiagm), v::AbstractVector)
158+
function spdiagm_pullback(ȳ)
159+
return (NoTangent(), diag(unthunk(ȳ)))
160+
end
161+
return spdiagm(v), spdiagm_pullback
162+
end

test/rulesets/SparseArrays/sparsematrix.jl

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,51 @@ end
1818
test_rrule(SparseVector{Float32}, Float32.(v), rtol=1e-4)
1919
end
2020

21+
# copied over from test/rulesets/LinearAlgebra/structured
22+
@testset "spdiagm" begin
23+
@testset "without size" begin
24+
M, N = 7, 9
25+
s = (8, 8)
26+
a = randn(M)
27+
b = randn(M)
28+
c = randn(M - 1)
29+
= randn(s)
30+
ps = (0 => a, 1 => b, 0 => c)
31+
y, back = rrule(spdiagm, ps...)
32+
@test y == spdiagm(ps...)
33+
∂self, ∂pa, ∂pb, ∂pc = back(ȳ)
34+
@test ∂self === NoTangent()
35+
∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(0 => a, 1 => b, 0 => c), ȳ, a, b, c)
36+
for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd))
37+
∂px = unthunk(∂px)
38+
@test ∂px isa Tangent{typeof(p)}
39+
@test ∂px.first isa AbstractZero
40+
@test ∂px.second ∂x_fd
41+
end
42+
end
43+
@testset "with size" begin
44+
M, N = 7, 9
45+
a = randn(M)
46+
b = randn(M)
47+
c = randn(M - 1)
48+
= randn(M, N)
49+
ps = (0 => a, 1 => b, 0 => c)
50+
y, back = rrule(spdiagm, M, N, ps...)
51+
@test y == spdiagm(M, N, ps...)
52+
∂self, ∂M, ∂N, ∂pa, ∂pb, ∂pc = back(ȳ)
53+
@test ∂self === NoTangent()
54+
@test ∂M === NoTangent()
55+
@test ∂N === NoTangent()
56+
∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(M, N, 0 => a, 1 => b, 0 => c), ȳ, a, b, c)
57+
for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd))
58+
∂px = unthunk(∂px)
59+
@test ∂px isa Tangent{typeof(p)}
60+
@test ∂px.first isa AbstractZero
61+
@test ∂px.second ∂x_fd
62+
end
63+
end
64+
end
65+
2166
@testset "findnz" begin
2267
A = sprand(5, 5, 0.5)
2368
dA = similar(A)
@@ -42,4 +87,4 @@ end
4287
test_rrule(logabsdet, A)
4388
test_rrule(logdet, A)
4489
test_rrule(det, A)
45-
end
90+
end

0 commit comments

Comments
 (0)