Skip to content

Commit 99c58a9

Browse files
authored
Ensure pullback of exp works for immutable arrays (#381)
* Ensure exp cotangent is mutable * Increment version number * Use convert and inplaceable trait
1 parent 439f482 commit 99c58a9

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.7.52"
3+
version = "0.7.53"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/LinearAlgebra/matfun.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,10 @@ function rrule(::typeof(exp), A0::StridedMatrix{<:BlasFloat})
129129
A = copy(A0)
130130
X, intermediates = _matfun!(exp, A)
131131
function exp_pullback(ΔX)
132-
∂A = _matfun_frechet_adjoint!(exp, ΔX, A, X, intermediates)
132+
# Ensures ∂X is mutable. The outer `adjoint` is unwrapped without copy by
133+
# the default _matfun_frechet_adjoint!
134+
∂X = ChainRulesCore.is_inplaceable_destination(ΔX) ? ΔX : convert(Matrix, ΔX')'
135+
∂A = _matfun_frechet_adjoint!(exp, ∂X, A, X, intermediates)
133136
return NO_FIELDS, ∂A
134137
end
135138
return X, exp_pullback

0 commit comments

Comments
 (0)