Skip to content

Commit 1b7307f

Browse files
author
Michael Abbott
committed
don't let broadcast allocate ∂x because it's weird
1 parent 3ffd391 commit 1b7307f

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

src/rulesets/LinearAlgebra/norm.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,11 @@ function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray)
173173
end
174174

175175
function _norm1_back(x, y, Δy)
176-
∂x = sign.(x) .* real(Δy)
176+
T = promote_type(eltype(x), real(eltype(Δy)))
177+
∂x = similar(x, T)
178+
# The reason not to let broadcast allocate ∂x is that NaN .* Diagonal(ones(3)) isa Matrix,
179+
# while pi .* Diagonal(ones(3)) isa Diagonal, hence this would be type-unstable.
180+
∂x .= sign.(x) .* real(Δy)
177181
return ∂x
178182
end
179183
function _norm1_back!(∂x, x, y, Δy)
@@ -205,7 +209,9 @@ function _norm2_forward(x, Δx, y)
205209
return ∂y
206210
end
207211
function _norm2_back(x, y, Δy)
208-
∂x = x .* (real(Δy) * pinv(y))
212+
T = typeof(one(eltype(x)) / one(real(eltype(Δy))))
213+
∂x = similar(x, T) # same comment as _norm1_back about allocation and type-stability.
214+
∂x .= x .* (real(Δy) * pinv(y))
209215
return ∂x
210216
end
211217
function _norm2_back!(∂x, x, y, Δy)

0 commit comments

Comments
 (0)