Skip to content

Commit 649bfbb

Browse files
thomasgudjonwright and Miha Zgubic authored
Add rules for sqrt(::Diagonal) (#509)
* WIP add diag sqrt rule
* Adding test with different rand_tangent()
* michaels less allocating suggestion
* spaces and unthunk
* rand the tests
* fix CI
* prevent stochastic failures

Co-authored-by: Miha Zgubic <miha.zgubic@invenialabs.co.uk>
1 parent 26e4608 commit 649bfbb

File tree

3 files changed

+18
-1
lines changed

3 files changed

+18
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.9.0"
3+
version = "1.10.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/LinearAlgebra/structured.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,19 @@ end
3535
##### `Diagonal`
3636
#####
3737

38+
"""
    _diagview(x)

Return the diagonal entries of `x` without materialising a new matrix.

- `Diagonal`: the stored `diag` vector itself.
- any other `AbstractMatrix`: a view onto the entries selected by `diagind(x)`.
- `Tangent{<:Diagonal}`: the tangent's `diag` component.
"""
function _diagview(x::Diagonal)
    return x.diag
end

function _diagview(x::AbstractMatrix)
    return @view x[diagind(x)]
end

function _diagview(x::Tangent{<:Diagonal})
    return x.diag
end
41+
# Reverse-mode rule for `sqrt` of a `Diagonal` matrix.
#
# Returns the primal `y = sqrt(d)` together with a pullback. The pullback accepts a
# cotangent `Δ` (possibly thunked; anything `_diagview` understands, i.e. a `Diagonal`,
# a dense `AbstractMatrix`, or a `Tangent{<:Diagonal}`) and returns
# `(NoTangent(), Diagonal(...))`, where only the diagonal of `Δ` contributes.
function ChainRulesCore.rrule(::typeof(sqrt), d::Diagonal)
    y = sqrt(d)
    # Internal invariant: the pullback below reads `y.diag`.
    @assert y isa Diagonal
    function sqrt_pullback(Δ)
        Δ_diag = _diagview(unthunk(Δ))
        # d/dx sqrt(x) = 1 / (2 * sqrt(x)). The derivative is conjugated so that
        # complex-valued diagonals receive the correct cotangent; for real entries
        # `conj` is the identity, so the result is unchanged.
        return NoTangent(), Diagonal(Δ_diag ./ conj.(2 .* y.diag))
    end
    return y, sqrt_pullback
end
50+
3851
# these functions are defined outside the rrule because otherwise type inference breaks
3952
# see https://github.com/JuliaLang/julia/issues/40990
4053
# Pullback for a dense matrix cotangent: only its diagonal is passed on.
# should we emit a warning here? this shouldn't be called if project works right
function _Diagonal_pullback(ȳ::AbstractMatrix)
    return (NoTangent(), diag(ȳ))
end

test/rulesets/LinearAlgebra/structured.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@
8787
end
8888
end
8989
end
90+
@testset "sqrt(::Diagonal)" begin
    # Random diagonal with the test utility's default output tangent.
    random_diag = Diagonal(rand(3))
    test_rrule(sqrt, random_diag)

    # Fixed diagonal with a dense matrix cotangent, so the pullback has to
    # pick the diagonal out of a plain `AbstractMatrix`.
    dense_cotangent = [1.2 3.4; 1.2 4.3]
    test_rrule(sqrt, Diagonal([1.0, 2]); output_tangent=dense_cotangent)
end
9094
@testset "$f, $T" for
9195
f in (Adjoint, adjoint, Transpose, transpose),
9296
T in (Float64, ComplexF64)

0 commit comments

Comments
 (0)