Skip to content

Commit aab91c6

Browse files
authored
Don't mutate cotangent in rule for exp (#588)
* Don't mutate input cotangent * Increment patch number * Add test for not mutating if cotangent is adjoint
1 parent 8108a77 commit aab91c6

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.26"
3+
version = "1.26.1"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/LinearAlgebra/matfun.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ function rrule(::typeof(exp), A0::StridedMatrix{<:BlasFloat})
132132
# Ensures ∂X is mutable. The outer `adjoint` is unwrapped without copy by
133133
# the default _matfun_frechet_adjoint!
134134
ΔX = unthunk(X̄)
135-
∂X = ChainRulesCore.is_inplaceable_destination(ΔX) ? ΔX : convert(Matrix, ΔX')'
135+
∂X = ChainRulesCore.is_inplaceable_destination(ΔX) ? copy(ΔX) : convert(Matrix, ΔX')'
136136
∂A = _matfun_frechet_adjoint!(exp, ∂X, A, X, intermediates)
137137
return NoTangent(), ∂A
138138
end

test/rulesets/LinearAlgebra/matfun.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,15 @@
3535
Y, back = rrule(exp, A)
3636
@maybe_inferred back(rand_tangent(Y))
3737
end
38+
@testset "cotangent not mutated" begin
39+
# https://github.com/JuliaDiff/ChainRules.jl/issues/512
40+
A = [1.0 2.0; 3.0 4.0]
41+
Y, back = rrule(exp, A)
42+
ΔY′ = rand_tangent(Y)'
43+
ΔY′copy = copy(ΔY′)
44+
back(ΔY′)
45+
@test ΔY′ == ΔY′copy
46+
end
3847
@testset "imbalanced A" begin
3948
A = Float64[0 10 0 0; -1 0 0 0; 0 0 0 0; -2 0 0 0]
4049
test_rrule(exp, A; check_inferred=false)

0 commit comments

Comments
 (0)