Skip to content

Commit 4afdfdd

Browse files
tkfoxinabox
authored andcommitted
Add rrule for Diagonal * AbstractVector (#108)
* Add rrule for Diagonal * AbstractVector * Simplify rrule definitions for Diagonal * AbstractVector Co-Authored-By: Lyndon White <oxinabox@ucc.asn.au> * Use consistent indentation
1 parent 4240832 commit 4afdfdd

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

src/rulesets/LinearAlgebra/structured.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ function rrule(::typeof(diag), A::AbstractMatrix)
1919
return diag(A), diag_pullback
2020
end
2121

22+
function rrule(::typeof(*), D::Diagonal{<:Real}, V::AbstractVector{<:Real})
23+
function times_pullback(Ȳ)
24+
return (NO_FIELDS, @thunk(Diagonal(Ȳ .* V)), @thunk(D * Ȳ))
25+
end
26+
return D * V, times_pullback
27+
end
28+
2229
#####
2330
##### `Symmetric`
2431
#####

test/rulesets/LinearAlgebra/structured.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@
77
# Concrete type instead of UnionAll
88
rrule_test(typeof(D), D, (randn(rng, N), randn(rng, N)))
99
end
10+
@testset "::Diagonal * ::AbstractVector" begin
11+
rng, N = MersenneTwister(123456), 3
12+
rrule_test(
13+
*,
14+
randn(rng, N),
15+
(Diagonal(randn(rng, N)), Diagonal(randn(rng, N))),
16+
(randn(rng, N), randn(rng, N)),
17+
)
18+
end
1019
@testset "diag" begin
1120
rng, N = MersenneTwister(123456), 7
1221
rrule_test(diag, randn(rng, N), (randn(rng, N, N), randn(rng, N, N)))

0 commit comments

Comments
 (0)