@@ -457,13 +457,21 @@ end
457
457
@testset " Symmetric" begin
458
458
X = generate_well_conditioned_matrix (10 )
459
459
F, dX_pullback = rrule (cholesky, X, Val (false ))
460
-
461
- X_symmetric, sym_back = rrule (Symmetric, X, :U )
462
- C, chol_back_sym = rrule (cholesky, X_symmetric, Val (false ))
463
-
464
- Δ = Tangent {typeof(C)} ((factors= randn (size (X))))
465
- ΔX_symmetric = chol_back_sym (Δ)[2 ]
466
- @test sym_back (ΔX_symmetric)[2 ] ≈ dX_pullback (Δ)[2 ]
460
+ ΔU = randn (size (X))
461
+ ΔF = Tangent {typeof(F)} (; factors= ΔU)
462
+
463
+ @testset for uplo in (:L , :U )
464
+ X_symmetric, sym_back = rrule (Symmetric, X, uplo)
465
+ C, chol_back_sym = rrule (cholesky, X_symmetric, Val (false ))
466
+
467
+ ΔC = Tangent {typeof(C)} (; factors= (uplo === :U ? ΔU : ΔU' ))
468
+ ΔX_symmetric = chol_back_sym (ΔC)[2 ]
469
+ if uplo === :U
470
+ @test sym_back (ΔX_symmetric)[2 ] ≈ dX_pullback (ΔF)[2 ]
471
+ else
472
+ @test sym_back (ΔX_symmetric)[2 ] ≈ dX_pullback (ΔF)[2 ]'
473
+ end
474
+ end
467
475
end
468
476
469
477
# Ensure that cotangents of cholesky(::StridedMatrix) and
@@ -472,13 +480,21 @@ end
472
480
@testset " Hermitian{$T }" for T in (Float64, ComplexF64)
473
481
X = generate_well_conditioned_matrix (T, 10 )
474
482
F, dX_pullback = rrule (cholesky, X, Val (false ))
475
-
476
- X_hermitian, herm_back = rrule (Hermitian, X, :U )
477
- C, chol_back_herm = rrule (cholesky, X_hermitian, Val (false ))
478
-
479
- Δ = Tangent {typeof(C)} ((factors= randn (T, size (X))))
480
- ΔX_hermitian = chol_back_herm (Δ)[2 ]
481
- @test herm_back (ΔX_hermitian)[2 ] ≈ dX_pullback (Δ)[2 ]
483
+ ΔU = randn (T, size (X))
484
+ ΔF = Tangent {typeof(F)} (; factors= ΔU)
485
+
486
+ @testset for uplo in (:L , :U )
487
+ X_hermitian, herm_back = rrule (Hermitian, X, uplo)
488
+ C, chol_back_herm = rrule (cholesky, X_hermitian, Val (false ))
489
+
490
+ ΔC = Tangent {typeof(C)} (; factors= (uplo === :U ? ΔU : ΔU' ))
491
+ ΔX_hermitian = chol_back_herm (ΔC)[2 ]
492
+ if uplo === :U
493
+ @test herm_back (ΔX_hermitian)[2 ] ≈ dX_pullback (ΔF)[2 ]
494
+ else
495
+ @test herm_back (ΔX_hermitian)[2 ] ≈ dX_pullback (ΔF)[2 ]'
496
+ end
497
+ end
482
498
end
483
499
@testset " check has correct default and passed to primal" begin
484
500
# this will almost certainly be a non-PD matrix
0 commit comments