Skip to content

Commit 1f4a8a9

Browse files
authored
Use new cholesky pivot syntax in v1.8 (#633)
* Use new cholesky pivot syntax in v1.8 * Increment patch number * Increment minor version number
1 parent 2c1bce7 commit 1f4a8a9

File tree

3 files changed

+25
-21
lines changed

3 files changed

+25
-21
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.41.0"
3+
version = "1.42.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/LinearAlgebra/factorization.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger!
1818
const LU_RowMaximum = VERSION >= v"1.7.0-DEV.1188" ? RowMaximum : Val{true}
1919
const LU_NoPivot = VERSION >= v"1.7.0-DEV.1188" ? NoPivot : Val{false}
2020

21+
const CHOLESKY_NoPivot = VERSION >= v"1.8.0-rc1" ? Union{NoPivot, Val{false}} : Val{false}
22+
2123
function frule(
2224
(_, Ȧ), ::typeof(lu!), A::StridedMatrix, pivot::Union{LU_RowMaximum,LU_NoPivot}; kwargs...
2325
)
@@ -462,8 +464,8 @@ function _cholesky_Diagonal_pullback(ΔC, C)
462464
end
463465
return NoTangent(), Diagonal(Ādiag), NoTangent()
464466
end
465-
function rrule(::typeof(cholesky), A::Diagonal{<:Number}, ::Val{false}; check::Bool=true)
466-
C = cholesky(A, Val(false); check=check)
467+
function rrule(::typeof(cholesky), A::Diagonal{<:Number}, pivot::CHOLESKY_NoPivot; check::Bool=true)
468+
C = cholesky(A, pivot; check=check)
467469
cholesky_pullback(ȳ) = _cholesky_Diagonal_pullback(unthunk(ȳ), C)
468470
return C, cholesky_pullback
469471
end
@@ -474,10 +476,10 @@ end
474476
function rrule(
475477
::typeof(cholesky),
476478
A::LinearAlgebra.RealHermSymComplexHerm{<:Real, <:StridedMatrix},
477-
::Val{false};
479+
pivot::CHOLESKY_NoPivot;
478480
check::Bool=true,
479481
)
480-
C = cholesky(A, Val(false); check=check)
482+
C = cholesky(A, pivot; check=check)
481483
function cholesky_HermOrSym_pullback(ΔC)
482484
= _cholesky_pullback_shared_code(C, unthunk(ΔC))
483485
rmul!(Ā, one(eltype(Ā)) / 2)
@@ -489,10 +491,10 @@ end
489491
function rrule(
490492
::typeof(cholesky),
491493
A::StridedMatrix{<:Union{Real,Complex}},
492-
::Val{false};
494+
pivot::CHOLESKY_NoPivot;
493495
check::Bool=true,
494496
)
495-
C = cholesky(A, Val(false); check=check)
497+
C = cholesky(A, pivot; check=check)
496498
function cholesky_Strided_pullback(ΔC)
497499
= _cholesky_pullback_shared_code(C, unthunk(ΔC))
498500
idx = diagind(Ā)

test/rulesets/LinearAlgebra/factorization.jl

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ end
2323
const LU_ROW_MAXIMUM = VERSION >= v"1.7.0-DEV.1188" ? RowMaximum() : Val(true)
2424
const LU_NO_PIVOT = VERSION >= v"1.7.0-DEV.1188" ? NoPivot() : Val(false)
2525

26+
const CHOLESKY_NO_PIVOT = VERSION >= v"1.8.0-rc1" ? NoPivot() : Val(false)
27+
2628
# well-conditioned random n×n matrix with elements of type `T` for testing `eigen`
2729
function rand_eigen(T::Type, n::Int)
2830
# uniform distribution over `(-1, 1)` / `(-1, 1)^2`
@@ -394,7 +396,7 @@ end
394396

395397
@testset "Diagonal" begin
396398
@testset "Diagonal{<:Real}" begin
397-
test_rrule(cholesky, Diagonal([0.3, 0.2, 0.5, 0.6, 0.9]), Val(false))
399+
test_rrule(cholesky, Diagonal([0.3, 0.2, 0.5, 0.6, 0.9]), CHOLESKY_NO_PIVOT)
398400
end
399401
@testset "Diagonal{<:Complex}" begin
400402
# finite differences in general will produce matrices with non-real
@@ -403,26 +405,26 @@ end
403405
D = Diagonal([0.3 + 0im, 0.2, 0.5, 0.6, 0.9])
404406
C = cholesky(D)
405407
test_rrule(
406-
cholesky, D, Val(false);
408+
cholesky, D, CHOLESKY_NO_PIVOT;
407409
output_tangent=Tangent{typeof(C)}(factors=complex(randn(5, 5))),
408410
fkwargs=(; check=false),
409411
)
410412
end
411413
@testset "check has correct default and passed to primal" begin
412-
@test_throws Exception rrule(cholesky, Diagonal(-rand(5)), Val(false))
413-
rrule(cholesky, Diagonal(-rand(5)), Val(false); check=false)
414+
@test_throws Exception rrule(cholesky, Diagonal(-rand(5)), CHOLESKY_NO_PIVOT)
415+
rrule(cholesky, Diagonal(-rand(5)), CHOLESKY_NO_PIVOT; check=false)
414416
end
415417
@testset "failed factorization" begin
416418
A = Diagonal(vcat(rand(4), -rand(4), rand(4)))
417-
test_rrule(cholesky, A, Val(false); fkwargs=(; check=false))
419+
test_rrule(cholesky, A, CHOLESKY_NO_PIVOT; fkwargs=(; check=false))
418420
end
419421
end
420422

421423
@testset "StridedMatrix" begin
422424
@testset "Matrix{$T}" for T in (Float64, ComplexF64)
423425
X = generate_well_conditioned_matrix(T, 10)
424426
V = generate_well_conditioned_matrix(T, 10)
425-
F, dX_pullback = rrule(cholesky, X, Val(false))
427+
F, dX_pullback = rrule(cholesky, X, CHOLESKY_NO_PIVOT)
426428
@testset "uplo=$p, cotangent eltype=$T" for p in [:U, :L], S in unique([T, complex(T)])
427429
Y, dF_pullback = rrule(getproperty, F, p)
428430
= randn(S, size(Y))
@@ -447,22 +449,22 @@ end
447449
@testset "check has correct default and passed to primal" begin
448450
# this will almost certainly be a non-PD matrix
449451
X = Matrix(Symmetric(randn(10, 10)))
450-
@test_throws Exception rrule(cholesky, X, Val(false))
451-
rrule(cholesky, X, Val(false); check=false) # just check it doesn't throw
452+
@test_throws Exception rrule(cholesky, X, CHOLESKY_NO_PIVOT)
453+
rrule(cholesky, X, CHOLESKY_NO_PIVOT; check=false) # just check it doesn't throw
452454
end
453455
end
454456

455457
# Ensure that cotangents of cholesky(::StridedMatrix) and
456458
# (cholesky ∘ Symmetric)(::StridedMatrix) are equal.
457459
@testset "Symmetric" begin
458460
X = generate_well_conditioned_matrix(10)
459-
F, dX_pullback = rrule(cholesky, X, Val(false))
461+
F, dX_pullback = rrule(cholesky, X, CHOLESKY_NO_PIVOT)
460462
ΔU = randn(size(X))
461463
ΔF = Tangent{typeof(F)}(; factors=ΔU)
462464

463465
@testset for uplo in (:L, :U)
464466
X_symmetric, sym_back = rrule(Symmetric, X, uplo)
465-
C, chol_back_sym = rrule(cholesky, X_symmetric, Val(false))
467+
C, chol_back_sym = rrule(cholesky, X_symmetric, CHOLESKY_NO_PIVOT)
466468

467469
ΔC = Tangent{typeof(C)}(; factors=(uplo === :U ? ΔU : ΔU'))
468470
ΔX_symmetric = chol_back_sym(ΔC)[2]
@@ -479,13 +481,13 @@ end
479481
@testset "Hermitian" begin
480482
@testset "Hermitian{$T}" for T in (Float64, ComplexF64)
481483
X = generate_well_conditioned_matrix(T, 10)
482-
F, dX_pullback = rrule(cholesky, X, Val(false))
484+
F, dX_pullback = rrule(cholesky, X, CHOLESKY_NO_PIVOT)
483485
ΔU = randn(T, size(X))
484486
ΔF = Tangent{typeof(F)}(; factors=ΔU)
485487

486488
@testset for uplo in (:L, :U)
487489
X_hermitian, herm_back = rrule(Hermitian, X, uplo)
488-
C, chol_back_herm = rrule(cholesky, X_hermitian, Val(false))
490+
C, chol_back_herm = rrule(cholesky, X_hermitian, CHOLESKY_NO_PIVOT)
489491

490492
ΔC = Tangent{typeof(C)}(; factors=(uplo === :U ? ΔU : ΔU'))
491493
ΔX_hermitian = chol_back_herm(ΔC)[2]
@@ -499,8 +501,8 @@ end
499501
@testset "check has correct default and passed to primal" begin
500502
# this will almost certainly be a non-PD matrix
501503
X = Hermitian(randn(10, 10))
502-
@test_throws Exception rrule(cholesky, X, Val(false))
503-
rrule(cholesky, X, Val(false); check=false)
504+
@test_throws Exception rrule(cholesky, X, CHOLESKY_NO_PIVOT)
505+
rrule(cholesky, X, CHOLESKY_NO_PIVOT; check=false)
504506
end
505507
end
506508

0 commit comments

Comments
 (0)