Skip to content

Commit 8587a07

Browse files
authored
Add rules for division by Cholesky (#631)
* Add rules for division by Cholesky * Add tests for new rules * Increment minor version number * Increment minor version number
1 parent 36f3ce7 commit 8587a07

File tree

3 files changed

+60
-1
lines changed

3 files changed

+60
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.40.0"
3+
version = "1.41.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/LinearAlgebra/factorization.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,3 +578,38 @@ function _x_divide_conj_y(x, y)
578578
z = x / conj(y)
579579
return iszero(x) ? zero(z) : z
580580
end
581+
582+
# these rules exists because the primals mutates using `ldiv!` and `rdiv!`
583+
function rrule(::typeof(\), A::Cholesky, B::AbstractVecOrMat{<:Union{Real,Complex}})
584+
U, getproperty_back = rrule(getproperty, A, :U)
585+
Z = U' \ B
586+
Y = U \ Z
587+
project_B = ProjectTo(B)
588+
function ldiv_Cholesky_AbsVecOrMat_pullback(ΔY)
589+
∂Z = U' \ ΔY
590+
∂B = U \ ∂Z
591+
∂A = Thunk() do
592+
_, Ā = getproperty_back(-add!!(∂Z * Y', Z * ∂B'))
593+
return
594+
end
595+
return NoTangent(), ∂A, project_B(∂B)
596+
end
597+
return Y, ldiv_Cholesky_AbsVecOrMat_pullback
598+
end
599+
600+
function rrule(::typeof(/), B::AbstractMatrix{<:Union{Real,Complex}}, A::Cholesky)
601+
U, getproperty_back = rrule(getproperty, A, :U)
602+
Z = B / U
603+
Y = Z / U'
604+
project_B = ProjectTo(B)
605+
function rdiv_AbstractMatrix_Cholesky_pullback(ΔY)
606+
∂Z = ΔY / U
607+
∂B = ∂Z / U'
608+
∂A = Thunk() do
609+
_, Ā = getproperty_back(-add!!(∂Z' * Y, Z' * ∂B))
610+
return
611+
end
612+
return NoTangent(), project_B(∂B), ∂A
613+
end
614+
return Y, rdiv_AbstractMatrix_Cholesky_pullback
615+
end

test/rulesets/LinearAlgebra/factorization.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,5 +521,29 @@ end
521521
@test ΔX.factors isa Diagonal && all(iszero, ΔX.factors)
522522
end
523523
end
524+
525+
@testset "\\(::Cholesky, ::AbstractVecOrMat)" begin
526+
n = 10
527+
for T in (Float64, ComplexF64), sz in (n, (n, 5))
528+
A = generate_well_conditioned_matrix(T, n)
529+
C = cholesky(A)
530+
B = randn(T, sz)
531+
# because the rule calls the rrule for getproperty, its rrule is not
532+
# completely type-inferrable
533+
test_rrule(\, C, B; check_inferred=false)
534+
end
535+
end
536+
537+
@testset "/(::AbstractMatrix, ::Cholesky)" begin
538+
n = 10
539+
for T in (Float64, ComplexF64)
540+
A = generate_well_conditioned_matrix(T, n)
541+
C = cholesky(A)
542+
B = randn(T, 5, n)
543+
# because the rule calls the rrule for getproperty, its rrule is not
544+
# completely type-inferrable
545+
test_rrule(/, B, C; check_inferred=false)
546+
end
547+
end
524548
end
525549
end

0 commit comments

Comments
 (0)