Skip to content

Commit 721d89b

Browse files
authored
Fix various type-instabilities (#329)
* Resolve instability in ifelse * Resolve instability in broadcasting over structured matrices * Generate DNE returns type-stably * Move Zero check out of pullback * Get unionall type-stably * Revert "Move Zero check out of pullback" This reverts commit d41ef75. * Remove misplaced extern * Rename and document _unionall_typeof * Increment patch version number
1 parent abfe271 commit 721d89b

File tree

9 files changed

+43
-17
lines changed

9 files changed

+43
-17
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.7.39"
3+
version = "0.7.40"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/array.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Int...)
1414
A_dims = size(A)
1515
function reshape_pullback(Ȳ)
1616
∂A = reshape(Ȳ, A_dims)
17-
return (NO_FIELDS, ∂A, fill(DoesNotExist(), length(dims))...)
17+
∂dims = broadcast(_ -> DoesNotExist(), dims)
18+
return (NO_FIELDS, ∂A, ∂dims...)
1819
end
1920
return reshape(A, dims...), reshape_pullback
2021
end

src/rulesets/Base/fastmath_able.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ let
6666
## abs
6767
function frule((_, Δx), ::typeof(abs), x::Union{Real, Complex})
6868
Ω = abs(x)
69-
signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω)
7069
# `ifelse` is applied only to denominator to ensure type-stability.
70+
signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω)
7171
return Ω, _realconjtimes(signx, Δx)
7272
end
7373

@@ -108,7 +108,8 @@ let
108108
function frule((_, Δx), ::typeof(angle), x)
109109
Ω = angle(x)
110110
# `ifelse` is applied only to denominator to ensure type-stability.
111-
∂Ω = _imagconjtimes(x, Δx) / ifelse(iszero(x), one(x), abs2(x))
111+
n = ifelse(iszero(x), one(real(x)), abs2(x))
112+
∂Ω = _imagconjtimes(x, Δx) / n
112113
return Ω, ∂Ω
113114
end
114115

@@ -127,8 +128,9 @@ let
127128
function angle_pullback(ΔΩ)
128129
x, y = reim(z)
129130
Δu, Δv = reim(ΔΩ)
130-
return (NO_FIELDS, (-y + im*x)*Δu/ifelse(iszero(z), one(z), abs2(z)))
131131
# `ifelse` is applied only to denominator to ensure type-stability.
132+
n = ifelse(iszero(z), one(real(z)), abs2(z))
133+
return (NO_FIELDS, (-y + im*x)*Δu/n)
132134
end
133135
return angle(z), angle_pullback
134136
end
@@ -185,14 +187,14 @@ let
185187
# `sign`
186188

187189
function frule((_, Δx), ::typeof(sign), x)
188-
n = ifelse(iszero(x), one(x), abs(x))
190+
n = ifelse(iszero(x), one(real(x)), abs(x))
189191
Ω = x isa Real ? sign(x) : x / n
190192
∂Ω = Ω * (_imagconjtimes(Ω, Δx) / n) * im
191193
return Ω, ∂Ω
192194
end
193195

194196
function rrule(::typeof(sign), x)
195-
n = ifelse(iszero(x), one(x), abs(x))
197+
n = ifelse(iszero(x), one(real(x)), abs(x))
196198
Ω = x isa Real ? sign(x) : x / n
197199
function sign_pullback(ΔΩ)
198200
∂x = Ω * (_imagconjtimes(Ω, ΔΩ) / n) * im

src/rulesets/Base/indexing.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ function rrule(::typeof(getindex), x::Array{<:Number}, inds...)
2020
@thunk(getindex_add!(zero(x))),
2121
getindex_add!
2222
)
23-
return (NO_FIELDS, x̄, (DoesNotExist() for _ in inds)...)
23+
īnds = broadcast(_ -> DoesNotExist(), inds)
24+
return (NO_FIELDS, x̄, īnds...)
2425
end
2526

2627
return y, getindex_pullback

src/rulesets/LinearAlgebra/blas.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ function rrule(::typeof(BLAS.dot), n, X, incx, Y, incy)
2222
∂X = Zero()
2323
∂Y = Zero()
2424
else
25-
ΔΩ = extern(ΔΩ)
2625
∂X = @thunk scal!(n, ΔΩ, blascopy!(n, Y, incy, _zeros(X), incx), incx)
2726
∂Y = @thunk scal!(n, ΔΩ, blascopy!(n, X, incx, _zeros(Y), incy), incy)
2827
end

src/rulesets/LinearAlgebra/norm.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ end
111111

112112
function _normp_back_x(x, p, y, Δy)
113113
c = real(Δy) / y
114-
∂x = broadcast(x) do xi
114+
∂x = similar(x)
115+
broadcast!(∂x, x) do xi
115116
a = norm(xi)
116117
∂xi = xi * ((a / y)^(p - 2) * c)
117118
return ifelse(isfinite(∂xi), ∂xi, zero(∂xi))
@@ -181,7 +182,11 @@ function rrule(
181182
return y, norm1_pullback
182183
end
183184

184-
_norm1_back(x, y, Δy) = sign.(x) .* real(Δy)
185+
function _norm1_back(x, y, Δy)
186+
∂x = similar(x)
187+
∂x .= sign.(x) .* real(Δy)
188+
return ∂x
189+
end
185190

186191
#####
187192
##### `norm2`
@@ -206,7 +211,11 @@ function _norm2_forward(x, Δx, y)
206211
∂y = real(dot(x, Δx)) * pinv(y)
207212
return ∂y
208213
end
209-
_norm2_back(x, y, Δy) = x .* (real(Δy) * pinv(y))
214+
function _norm2_back(x, y, Δy)
215+
∂x = similar(x)
216+
∂x .= x .* (real(Δy) * pinv(y))
217+
return ∂x
218+
end
210219

211220
#####
212221
##### `normalize`

src/rulesets/LinearAlgebra/structured.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@ const SquareMatrix{T} = Union{Diagonal{T}, AbstractTriangular{T}}
88
function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatrix{<:Real}
99
Y = A / B
1010
function slash_pullback(Ȳ)
11-
S = T.name.wrapper
1211
∂A = @thunk/ B'
13-
∂B = @thunk S(-Y' * (Ȳ / B'))
12+
∂B = @thunk _unionall_wrapper(T)(-Y' * (Ȳ / B'))
1413
return (NO_FIELDS, ∂A, ∂B)
1514
end
1615
return Y, slash_pullback
@@ -19,8 +18,7 @@ end
1918
function rrule(::typeof(\), A::T, B::AbstractVecOrMat{<:Real}) where T<:SquareMatrix{<:Real}
2019
Y = A \ B
2120
function backslash_pullback(Ȳ)
22-
S = T.name.wrapper
23-
∂A = @thunk S(-(A' \ Ȳ) * Y')
21+
∂A = @thunk _unionall_wrapper(T)(-(A' \ Ȳ) * Y')
2422
∂B = @thunk A' \
2523
return NO_FIELDS, ∂A, ∂B
2624
end

src/rulesets/LinearAlgebra/symmetric.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ end
99
function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
1010
Ω = T(A, uplo)
1111
function HermOrSym_pullback(ΔΩ)
12-
return (NO_FIELDS, _symherm_back(T, ΔΩ, Ω.uplo), DoesNotExist())
12+
return (NO_FIELDS, _symherm_back(typeof(Ω), ΔΩ, Ω.uplo), DoesNotExist())
1313
end
1414
return Ω, HermOrSym_pullback
1515
end

src/rulesets/LinearAlgebra/utils.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,19 @@ function _eyesubx!(X::AbstractMatrix)
2727
end
2828

2929
_extract_imag(x) = complex(0, imag(x))
30+
31+
"""
32+
_unionall_wrapper(T::Type) -> UnionAll
33+
34+
Return the most general `UnionAll` type union associated with the concrete type `T`.
35+
36+
# Example
37+
```julia
38+
julia> _unionall_wrapper(typeof(Diagonal(1:3)))
39+
Diagonal
40+
41+
julia> _unionall_wrapper(typeof(Symmetric(randn(3, 3))))
42+
Symmetric
43+
````
44+
"""
45+
_unionall_wrapper(::Type{T}) where {T} = T.name.wrapper

0 commit comments

Comments
 (0)