Skip to content

Commit 8ab56db

Browse files
author
Michael Abbott
committed
a type-stability fix via parent & re-wrap
1 parent df02563 commit 8ab56db

File tree

2 files changed

+54
-11
lines changed

2 files changed

+54
-11
lines changed

src/rulesets/LinearAlgebra/norm.jl

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -108,15 +108,22 @@ end
108108

109109
function _normp_back_x(x, p, y, Δy)
110110
c = real(Δy) / y
111-
T = promote_type(eltype(x), typeof(c))
112-
∂x = similar(x, T) # same comment as _norm1_back about allocation and type-stability.
113-
map!(∂x, x) do xi
111+
∂x = map(x) do xi
114112
a = norm(xi)
115113
∂xi = xi * ((a / y)^(p - 2) * c)
116114
return ifelse(isfinite(∂xi), ∂xi, zero(∂xi))
117115
end
118116
return ∂x
119117
end
118+
function _normp_back_x(x::WithSomeZeros, p, y, Δy) # Diagonal, UpperTriangular, etc.
119+
c = real(Δy) / y
120+
∂x_data = map(parent(x)) do xi
121+
a = norm(xi)
122+
∂xi = xi * ((a / y)^(p - 2) * c)
123+
return ifelse(isfinite(∂xi), ∂xi, zero(∂xi))
124+
end
125+
return withsomezeros_rewrap(x, ∂x_data)
126+
end
120127

121128
function _normp_back_p(x, p, y, Δy)
122129
y > 0 && isfinite(y) && !iszero(p) || return zero(real(Δy)) * zero(y) / one(p)
@@ -175,13 +182,13 @@ function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray)
175182
end
176183

177184
function _norm1_back(x, y, Δy)
178-
T = promote_type(eltype(x), real(eltype(Δy)))
179-
∂x = similar(x, T)
180-
# The reason not to let broadcast allocate ∂x is that NaN .* Diagonal(ones(3)) isa Matrix,
181-
# while pi .* Diagonal(ones(3)) isa Diagonal, hence this would be type-unstable.
182-
∂x .= sign.(x) .* real(Δy)
185+
∂x = sign.(x) .* real(Δy)
183186
return ∂x
184187
end
188+
function _norm1_back(x::WithSomeZeros, y, Δy)
189+
∂x_data = sign.(parent(x)) .* real(Δy)
190+
return withsomezeros_rewrap(x, ∂x_data)
191+
end
185192
function _norm1_back!(∂x, x, y, Δy)
186193
∂x .+= sign.(x) .* real(Δy)
187194
return ∂x
@@ -211,11 +218,14 @@ function _norm2_forward(x, Δx, y)
211218
return ∂y
212219
end
213220
function _norm2_back(x, y, Δy)
214-
T = typeof(one(eltype(x)) / one(real(eltype(Δy))))
215-
∂x = similar(x, T) # same comment as _norm1_back about allocation and type-stability.
216-
∂x .= x .* (real(Δy) * pinv(y))
221+
∂x = x .* (real(Δy) * pinv(y))
217222
return ∂x
218223
end
224+
function _norm2_back(x::WithSomeZeros, y, Δy)
225+
T = typeof(one(eltype(x)) / one(real(eltype(Δy))))
226+
∂x_data = parent(x) .* (real(Δy) * pinv(y))
227+
return withsomezeros_rewrap(x, ∂x_data)
228+
end
219229
function _norm2_back!(∂x, x, y, Δy)
220230
∂x .+= x .* (real(Δy) * pinv(y))
221231
return ∂x # must return after mutating

src/rulesets/LinearAlgebra/utils.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,36 @@ Symmetric
4343
````
4444
"""
4545
_unionall_wrapper(::Type{T}) where {T} = T.name.wrapper
46+
47+
"""
48+
WithSomeZeros{T}
49+
50+
This is a union of LinearAlgebra types, all of which are partly structral zeros,
51+
with a simple backing array given by `parent(x)`. All have methods of `_rewrap`
52+
to re-create.
53+
54+
This exists to solve a type instability, as broadcasting for instance
55+
`λ .* Diagonal(rand(3))` gives a dense matrix when `x==Inf`.
56+
But `withsomezeros_rewrap(x, λ .* parent(x))` is type-stable.
57+
"""
58+
WithSomeZeros{T} = Union{
59+
Diagonal{T},
60+
UpperTriangular{T},
61+
UnitUpperTriangular{T},
62+
UpperHessenberg{T},
63+
LowerTriangular{T},
64+
UnitLowerTriangular{T},
65+
}
66+
for S in [
67+
:Diagonal,
68+
:UpperTriangular,
69+
:UnitUpperTriangular,
70+
:UpperHessenberg,
71+
:LowerTriangular,
72+
:UnitLowerTriangular,
73+
]
74+
@eval withsomezeros_rewrap(::$S, x) = $S(x)
75+
end
76+
77+
# Bidiagonal, Tridiagonal have more complicated storage.
78+
# AdjOrTransUpperOrUnitUpperTriangular would need adjoint(parent(parent()))

0 commit comments

Comments
 (0)