Skip to content

Commit 2c1bce7

Browse files
authored
Add tests for cholesky with lower storage (#632)
1 parent 8587a07 commit 2c1bce7

File tree

1 file changed

+30
-14
lines changed

1 file changed

+30
-14
lines changed

test/rulesets/LinearAlgebra/factorization.jl

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -457,13 +457,21 @@ end
457457
@testset "Symmetric" begin
458458
X = generate_well_conditioned_matrix(10)
459459
F, dX_pullback = rrule(cholesky, X, Val(false))
460-
461-
X_symmetric, sym_back = rrule(Symmetric, X, :U)
462-
C, chol_back_sym = rrule(cholesky, X_symmetric, Val(false))
463-
464-
Δ = Tangent{typeof(C)}((factors=randn(size(X))))
465-
ΔX_symmetric = chol_back_sym(Δ)[2]
466-
@test sym_back(ΔX_symmetric)[2] dX_pullback(Δ)[2]
460+
ΔU = randn(size(X))
461+
ΔF = Tangent{typeof(F)}(; factors=ΔU)
462+
463+
@testset for uplo in (:L, :U)
464+
X_symmetric, sym_back = rrule(Symmetric, X, uplo)
465+
C, chol_back_sym = rrule(cholesky, X_symmetric, Val(false))
466+
467+
ΔC = Tangent{typeof(C)}(; factors=(uplo === :U ? ΔU : ΔU'))
468+
ΔX_symmetric = chol_back_sym(ΔC)[2]
469+
if uplo === :U
470+
@test sym_back(ΔX_symmetric)[2] dX_pullback(ΔF)[2]
471+
else
472+
@test sym_back(ΔX_symmetric)[2] dX_pullback(ΔF)[2]'
473+
end
474+
end
467475
end
468476

469477
# Ensure that cotangents of cholesky(::StridedMatrix) and
@@ -472,13 +480,21 @@ end
472480
@testset "Hermitian{$T}" for T in (Float64, ComplexF64)
473481
X = generate_well_conditioned_matrix(T, 10)
474482
F, dX_pullback = rrule(cholesky, X, Val(false))
475-
476-
X_hermitian, herm_back = rrule(Hermitian, X, :U)
477-
C, chol_back_herm = rrule(cholesky, X_hermitian, Val(false))
478-
479-
Δ = Tangent{typeof(C)}((factors=randn(T, size(X))))
480-
ΔX_hermitian = chol_back_herm(Δ)[2]
481-
@test herm_back(ΔX_hermitian)[2] dX_pullback(Δ)[2]
483+
ΔU = randn(T, size(X))
484+
ΔF = Tangent{typeof(F)}(; factors=ΔU)
485+
486+
@testset for uplo in (:L, :U)
487+
X_hermitian, herm_back = rrule(Hermitian, X, uplo)
488+
C, chol_back_herm = rrule(cholesky, X_hermitian, Val(false))
489+
490+
ΔC = Tangent{typeof(C)}(; factors=(uplo === :U ? ΔU : ΔU'))
491+
ΔX_hermitian = chol_back_herm(ΔC)[2]
492+
if uplo === :U
493+
@test herm_back(ΔX_hermitian)[2] dX_pullback(ΔF)[2]
494+
else
495+
@test herm_back(ΔX_hermitian)[2] dX_pullback(ΔF)[2]'
496+
end
497+
end
482498
end
483499
@testset "check has correct default and passed to primal" begin
484500
# this will almost certainly be a non-PD matrix

0 commit comments

Comments
 (0)