Skip to content

Commit 2cc27e2

Browse files
sethaxen and oxinabox authored
Fix exp rules for some matrices (#596)
* Avoid un-balancing the needed intermediate
* Test exp for imbalanced unsquared matrix
* Increment patch number
* Apply suggestions from code review

Co-authored-by: Frames Catherine White <oxinabox@ucc.asn.au>

* Add exhaustive tests
* Eliminate truncation error
* Increment patch number
* Increment patch version number

Co-authored-by: Frames Catherine White <oxinabox@ucc.asn.au>
1 parent 7b5f4d1 commit 2cc27e2

File tree

3 files changed

+41
-2
lines changed

3 files changed

+41
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.28.3"
3+
version = "1.28.4"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/LinearAlgebra/matfun.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,9 @@ function _matfun!(::typeof(exp), A::StridedMatrix{T}) where {T<:BlasFloat}
219219
X *= X
220220
push!(Xpows, X)
221221
end
222+
else
223+
# Xpows[1] must remain balanced for computing the Fréchet derivative
224+
X = copy(X)
222225
end
223226

224227
_unbalance!(X, ilo, ihi, scale, n)
@@ -247,7 +250,7 @@ function _matfun_frechet!(
247250
∂P = copy(∂A2)
248251
∂W = C[4] * ∂P
249252
∂V = C[3] * ∂P
250-
for k in 2:(length(Apows) - 1)
253+
for k in 2:length(Apows)
251254
k2 = 2 * k
252255
P = Apows[k - 1]
253256
∂P, ∂temp = mul!(mul!(∂temp, ∂P, A2), P, ∂A2, true, true), ∂P

test/rulesets/LinearAlgebra/matfun.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,24 @@
1414
A = Float64[0 10 0 0; -1 0 0 0; 0 0 0 0; -2 0 0 0]
1515
test_frule(LinearAlgebra.exp!, A)
1616
end
17+
@testset "imbalanced A with no squaring" begin
18+
# https://github.com/JuliaDiff/ChainRules.jl/issues/595
19+
A = [
20+
-0.007623430669065629 -0.567237096385192 0.4419041897734335;
21+
2.090838913114862 -1.254084243281689 -0.04145771190198238;
22+
2.3397892123412833 -0.6650489083959324 0.6387266010923911
23+
]
24+
test_frule(LinearAlgebra.exp!, A)
25+
end
26+
@testset "exhaustive test" begin
27+
# added to ensure we never hit truncation error
28+
# https://github.com/JuliaDiff/ChainRules.jl/issues/595
29+
rng = MersenneTwister(1)
30+
for _ in 1:100
31+
A = randn(rng, 3, 3)
32+
test_frule(LinearAlgebra.exp!, A)
33+
end
34+
end
1735
@testset "hermitian A, T=$T" for T in (Float64, ComplexF64)
1836
A = Matrix(Hermitian(randn(T, n, n)))
1937
test_frule(LinearAlgebra.exp!, A)
@@ -48,6 +66,24 @@
4866
A = Float64[0 10 0 0; -1 0 0 0; 0 0 0 0; -2 0 0 0]
4967
test_rrule(exp, A; check_inferred=false)
5068
end
69+
@testset "imbalanced A with no squaring" begin
70+
# https://github.com/JuliaDiff/ChainRules.jl/issues/595
71+
A = [
72+
-0.007623430669065629 -0.567237096385192 0.4419041897734335;
73+
2.090838913114862 -1.254084243281689 -0.04145771190198238;
74+
2.3397892123412833 -0.6650489083959324 0.6387266010923911
75+
]
76+
test_rrule(LinearAlgebra.exp, A; check_inferred=false)
77+
end
78+
@testset "exhaustive test" begin
79+
# added to ensure we never hit truncation error
80+
# https://github.com/JuliaDiff/ChainRules.jl/issues/595
81+
rng = MersenneTwister(1)
82+
for _ in 1:100
83+
A = randn(rng, 3, 3)
84+
test_rrule(LinearAlgebra.exp, A; check_inferred=false)
85+
end
86+
end
5187
@testset "hermitian A, T=$T" for T in (Float64, ComplexF64)
5288
A = Matrix(Hermitian(randn(T, n, n)))
5389
test_rrule(exp, A; check_inferred=false)

0 commit comments

Comments
 (0)