Skip to content

Commit 8498212

Browse files
committed
many small upgrades
1 parent b031fc1 commit 8498212

File tree

6 files changed

+212
-82
lines changed

6 files changed

+212
-82
lines changed

src/rulesets/Base/base.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ function rrule(::Type{T}, x::Number, y::Number) where {T<:Complex}
7272
return (T(x, y), Complex_pullback)
7373
end
7474

75+
@scalar_rule complex(x) true
76+
7577
# `hypot`
7678

7779
@scalar_rule hypot(x::Real) sign(x)

src/rulesets/Base/broadcast.jl

Lines changed: 93 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
using Base.Broadcast: Broadcast, broadcasted, Broadcasted
22
const RCR = RuleConfig{>:HasReverseMode}
33

4-
rrule(::typeof(copy), bc::Broadcasted) = copy(bc), Δ -> (NoTangent(), Δ)
4+
function rrule(::typeof(copy), bc::Broadcasted)
5+
uncopy(Δ) = (NoTangent(), Δ)
6+
return copy(bc), uncopy
7+
end
58

69
# Skip AD'ing through the axis computation
710
function rrule(::typeof(Broadcast.instantiate), bc::Broadcasted)
8-
uninstantiate(Δ) = Core.tuple(NoTangent(), Δ)
11+
uninstantiate(Δ) = (NoTangent(), Δ)
912
return Broadcast.instantiate(bc), uninstantiate
1013
end
1114

12-
_print(args...) = nothing # println(join(args, " "))
15+
_print(args...) = nothing # println(join(args, " ")) #
1316

1417
#####
1518
##### Split broadcasting
@@ -69,45 +72,37 @@ end
6972

7073
# Don't run broadcasting on scalars
7174
function rrule(cfg::RCR, ::typeof(broadcasted), f::F, args::Number...) where {F}
72-
# function split_bc_rule(cfg::RCR, f::F, args::Number...) where {F}
73-
_print("split_bc_rule scalar", f)
75+
_print("split_bc_scalar", f)
7476
z, back = rrule_via_ad(cfg, f, args...)
7577
return z, dz -> (NoTangent(), back(dz)...)
7678
end
7779

78-
# using StructArrays
79-
#
80-
# function tuplecast(f::F, args...) where {F}
81-
# T = Broadcast.combine_eltypes(f, args)
82-
# if isconcretetype(T)
83-
# T <: Tuple || throw(ArgumentError("tuplecast(f, args) only works on functions returning a tuple."))
84-
# end
85-
# bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
86-
# StructArrays.components(StructArray(bc))
87-
# end
88-
8980
#####
9081
##### Fused broadcasting
9182
#####
9283

93-
# For certain cheap operations we can easily allow fused broadcast.
94-
# These all have `RuleConfig{>:HasReverseMode}` as otherwise the split rule matches first & they are not used.
95-
# They accept `Broadcasted` because they produce it; it has no eltype but is assumed to contain `Number`s.
84+
# For certain cheap operations we can easily allow fused broadcast; the forward pass may be run twice.
85+
# These all take a `RuleConfig{>:HasReverseMode}` argument only for dispatch, so that they beat the split rule above.
86+
# They accept `x::Broadcasted` because they also produce it; we can't dispatch on its eltype, but `x` is assumed to contain `Number`s.
87+
9688
const NumericOrBroadcast = Union{Number, AbstractArray{<:Number}, NTuple{<:Any,Number}, Broadcast.Broadcasted}
9789

90+
##### Arithmetic: +, -, *, ^2, /
91+
9892
function rrule(::RCR, ::typeof(broadcasted), ::typeof(+), xs::NumericOrBroadcast...)
9993
_print("plus", length(xs))
10094
function bc_plus_back(dy_raw)
10195
dy = unthunk(dy_raw)
102-
(NoTangent(), NoTangent(), map(x -> unbroadcast(x, dy), xs)...)
96+
return (NoTangent(), NoTangent(), map(x -> unbroadcast(x, dy), xs)...) # no copies, this may return dx2 === dx3
10397
end
10498
return broadcasted(+, xs...), bc_plus_back
10599
end
106100

107101
function rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::NumericOrBroadcast)
108102
_print("minus 2")
109-
bc_minus_back(Δraw) = let Δ = unthunk(Δraw)
110-
(NoTangent(), NoTangent(), @thunk(unbroadcast(x, Δ)), @thunk(-unbroadcast(y, Δ)))
103+
function bc_minus_back(dz_raw)
104+
dz = unthunk(dz_raw)
105+
return (NoTangent(), NoTangent(), @thunk(unbroadcast(x, dz)), @thunk(-unbroadcast(y, dz)))
111106
end
112107
return broadcasted(-, x, y), bc_minus_back
113108
end
@@ -118,46 +113,59 @@ function rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast)
118113
return broadcasted(-, x), bc_minus_back
119114
end
120115

121-
using LinearAlgebra: dot
122-
123116
function rrule(::RCR, ::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast)
124117
_print("times")
125118
function bc_times_back(Δraw)
126119
Δ = unthunk(Δraw)
127-
(NoTangent(), NoTangent(), _back_star(x, y, Δ), _back_star(y, x, Δ))
120+
return (NoTangent(), NoTangent(), _back_star(x, y, Δ), _back_star(y, x, Δ))
128121
end
129122
return broadcasted(*, x, y), bc_times_back
130123
end
131-
_back_star(x, y, Δ) = @thunk unbroadcast(x, Δ .* conj.(y))
132-
_back_star(x::Number, y, Δ) = @thunk dot(y, Δ)
124+
_back_star(x, y, Δ) = @thunk unbroadcast(x, Δ .* conj.(y)) # this case probably isn't better than generic
125+
_back_star(x::Number, y, Δ) = @thunk LinearAlgebra.dot(y, Δ) # ... but this is why the rule exists
133126
_back_star(x::Bool, y, Δ) = NoTangent()
134127
_back_star(x::Complex{Bool}, y, Δ) = NoTangent() # e.g. for fun.(im.*x)
135128

136-
# TODO check what happens for A * B * C
129+
#=
130+
# This works, but not sure it improves any benchmarks.
131+
function rrule(cfg::RCR, ::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast, zs::NumericOrBroadcast...)
132+
_print("times", 2 + length(zs))
133+
xy, back1 = rrule(cfg, broadcasted, *, x, y)
134+
xyz, back2 = rrule(cfg, broadcasted, *, xy, zs...)
135+
function bc_times3_back(dxyz)
136+
_, _, dxy, dzs... = back2(dxyz)
137+
_, _, dx, dy = back1(dxy)
138+
return (NoTangent(), NoTangent(), dx, dy, dzs...)
139+
end
140+
xyz, bc_times3_back
141+
end
142+
=#
137143

138144
function rrule(::RCR, ::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::NumericOrBroadcast, ::Val{2})
139145
_print("square")
140146
function bc_square_back(dy_raw)
141147
dx = @thunk ProjectTo(x)(2 .* unthunk(dy_raw) .* conj.(x))
142-
(NoTangent(), NoTangent(), NoTangent(), dx, NoTangent())
148+
return (NoTangent(), NoTangent(), NoTangent(), dx, NoTangent())
143149
end
144150
return broadcasted(Base.literal_pow, ^, x, Val(2)), bc_square_back
145151
end
146152

147153
function rrule(::RCR, ::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast, y::Number)
148154
_print("divide")
149-
z = broadcast(/, x, y)
150-
function bc_divide_back(Δraw)
151-
Δ = unthunk(Δraw)
152-
dx = @thunk unbroadcast(x, Δ ./ conj.(y))
153-
dy = @thunk -dot(z, Δ) / (conj(y)) # the reason to be eager is to allow dot here
155+
# z = broadcast(/, x, y)
156+
z = broadcasted(/, x, y)
157+
function bc_divide_back(dz_raw)
158+
dz = unthunk(dz_raw)
159+
dx = @thunk unbroadcast(x, dz ./ conj.(y))
160+
# dy = @thunk -LinearAlgebra.dot(z, dz) / conj(y) # the reason to be eager is to allow dot here
161+
dy = @thunk -sum(Broadcast.instantiate(broadcasted(*, broadcasted(conj, z), dz))) / conj(y) # complete sum is fast?
154162
(NoTangent(), NoTangent(), dx, dy)
155163
end
156164
return z, bc_divide_back
157165
end
158166

159167
# For the same functions, send accidental broadcasting over numbers directly to `rrule`.
160-
# Could perhaps move all to @scalar_rule?
168+
# (Could perhaps move all to @scalar_rule?)
161169

162170
function _prepend_zero((y, back))
163171
extra_back(dy) = (NoTangent(), back(dy)...)
@@ -172,33 +180,74 @@ rrule(::RCR, ::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::
172180
rrule(Base.literal_pow, ^, x, Val(2)) |> _prepend_zero
173181
rrule(::RCR, ::typeof(broadcasted), ::typeof(/), x::Number, y::Number) = rrule(/, x, y) |> _prepend_zero
174182

175-
# A few more cheap functions
183+
##### Identity, number types
176184

177185
rrule(::RCR, ::typeof(broadcasted), ::typeof(identity), x::NumericOrBroadcast) = rrule(identity, x) |> _prepend_zero
178186
rrule(::RCR, ::typeof(broadcasted), ::typeof(identity), x::Number) = rrule(identity, x) |> _prepend_zero # ambiguity
179187

180-
function rrule(::RCR, ::typeof(broadcasted), ::typeof(conj), x::NumericOrBroadcast)
181-
bc_conj_back(dx) = (NoTangent(), NoTangent(), conj(unthunk(dx)))
182-
return broadcasted(conj, x), bc_conj_back
188+
function rrule(::RCR, ::typeof(broadcasted), ::Type{T}, x::NumericOrBroadcast) where {T<:Number}
189+
_print("bc type", T)
190+
bc_type_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz))))
191+
return broadcasted(T, x), bc_type_back
192+
end
193+
rrule(::RCR, ::typeof(broadcasted), ::Type{T}, x::Number) where {T<:Number} = rrule(T, x) |> _prepend_zero
194+
195+
function rrule(::RCR, ::typeof(broadcasted), ::typeof(float), x::NumericOrBroadcast)
196+
_print("bc float")
197+
bc_float_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz))))
198+
return broadcasted(float, x), bc_float_back
183199
end
184-
rrule(::RCR, ::typeof(broadcasted), ::typeof(conj), x::Number) = rrule(conj, x) |> _prepend_zero
185-
rrule(::RCR, ::typeof(broadcasted), ::typeof(conj), x::AbstractArray{<:Real}) = rrule(identity, x) |> _prepend_zero
200+
rrule(::RCR, ::typeof(broadcasted), ::typeof(float), x::Number) = rrule(float, x) |> _prepend_zero
186201

187-
# TODO real, imag
202+
##### Complex: conj, real, imag
203+
204+
for conj in [:conj, :adjoint] # identical as we know eltype <: Number
205+
@eval begin
206+
function rrule(::RCR, ::typeof(broadcasted), ::typeof($conj), x::NumericOrBroadcast)
207+
bc_conj_back(dx) = (NoTangent(), NoTangent(), conj(unthunk(dx)))
208+
return broadcasted($conj, x), bc_conj_back
209+
end
210+
rrule(::RCR, ::typeof(broadcasted), ::typeof($conj), x::Number) = rrule($conj, x) |> _prepend_zero
211+
rrule(::RCR, ::typeof(broadcasted), ::typeof($conj), x::AbstractArray{<:Real}) = rrule(identity, x) |> _prepend_zero
212+
# This `AbstractArray{<:Real}` rule won't catch `conj.(x .+ 1)`, since the lazy `.+` rule returns a `Broadcasted` rather than an array.
213+
# Could upgrade to infer eltype of the `Broadcasted`?
214+
end
215+
end
216+
217+
function rrule(::RCR, ::typeof(broadcasted), ::typeof(real), x::NumericOrBroadcast)
218+
_print("real")
219+
bc_real_back(dz) = (NoTangent(), NoTangent(), @thunk(real(unthunk(dz))))
220+
return broadcasted(real, x), bc_real_back
221+
end
222+
rrule(::RCR, ::typeof(broadcasted), ::typeof(real), x::Number) = rrule(real, x) |> _prepend_zero
223+
rrule(::RCR, ::typeof(broadcasted), ::typeof(real), x::AbstractArray{<:Real}) = rrule(identity, x) |> _prepend_zero
224+
225+
function rrule(::RCR, ::typeof(broadcasted), ::typeof(imag), x::NumericOrBroadcast)
226+
_print("imag")
227+
bc_imag_back(dz) = (NoTangent(), NoTangent(), @thunk(im .* real.(unthunk(dz))))
228+
return broadcasted(imag, x), bc_imag_back
229+
end
230+
rrule(::RCR, ::typeof(broadcasted), ::typeof(imag), x::Number) = rrule(imag, x) |> _prepend_zero
231+
function rrule(::RCR, ::typeof(broadcasted), ::typeof(imag), x::AbstractArray{<:Real})
232+
_print("imag(real)")
233+
bc_imag_back_2(dz) = (NoTangent(), NoTangent(), ZeroTangent())
234+
return broadcasted(imag, x), bc_imag_back_2
235+
end
188236

189237
#####
190238
##### Shape fixing
191239
#####
192240

193-
# Reverse mode broadcasting uses `unbroadcast` to reduce to correct shape:
241+
# When sizes disagree, the broadcasting gradient uses `unbroadcast` to reduce `dx` to the correct shape.
242+
# It's sometimes a little wasteful to allocate a too-large `dx`, but it is difficult to make this more efficient.
194243

195244
function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx)
196245
N = ndims(dx)
197246
if length(x) == length(dx)
198247
ProjectTo(x)(dx) # handles trivial reshapes, offsets, structured matrices, row vectors
199248
else
200249
dims = ntuple(d -> get(size(x), d, 1) == 1 ? d : N+1, N) # hack to get type-stable `dims`
201-
ProjectTo(x)(sum(dx; dims)) # ideally this sum might be thunked?
250+
ProjectTo(x)(sum(dx; dims))
202251
end
203252
end
204253
unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx

src/tuplecast.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,21 @@
44
55
For a function `f` which returns a tuple, this is `== unzip(broadcast(f, args...))`,
66
but performed using `StructArrays` for efficiency.
7+
8+
# Examples
9+
```
10+
julia> using ChainRules: tuplecast, unzip
11+
12+
julia> tuplecast(x -> (x,2x), 1:3)
13+
([1, 2, 3], [2, 4, 6])
14+
15+
julia> mats = @btime tuplecast((x,y) -> (x+y, x-y), 1:1000, transpose(1:1000)); # 2 arrays, each 7.63 MiB
16+
min 1.776 ms, mean 20.421 ms (4 allocations, 15.26 MiB)
17+
18+
julia> mats == @btime unzip(broadcast((x,y) -> (x+y, x-y), 1:1000, transpose(1:1000))) # intermediate matrix of tuples
19+
min 2.660 ms, mean 40.007 ms (6 allocations, 30.52 MiB)
20+
true
21+
```
722
"""
823
function tuplecast(f::F, args...) where {F}
924
T = Broadcast.combine_eltypes(f, args)
@@ -67,6 +82,21 @@ end
6782
6883
Converts an array of tuples into a tuple of arrays.
6984
Eager. Will work by `reinterpret` when possible.
85+
86+
```jldoctest
87+
julia> ChainRules.unzip([(1,2), (3,4), (5,6)]) # makes two new Arrays:
88+
([1, 3, 5], [2, 4, 6])
89+
90+
julia> typeof(ans)
91+
Tuple{Vector{Int64}, Vector{Int64}}
92+
93+
julia> ChainRules.unzip([(1,nothing) (3,nothing) (5,nothing)]) # this can reinterpret:
94+
([1 3 5], [nothing nothing nothing])
95+
96+
julia> ans[1]
97+
1×3 reinterpret(Int64, ::Matrix{Tuple{Int64, Nothing}}):
98+
1 3 5
99+
```
70100
"""
71101
function unzip(xs::AbstractArray)
72102
x1 = first(xs)

test/rulesets/Base/base.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
for x in (-4.1, 6.4, 0.0, 0.0 + 0.0im, 0.5 + 0.25im)
7878
test_scalar(real, x)
7979
test_scalar(imag, x)
80+
test_scalar(complex, x)
8081
test_scalar(hypot, x)
8182
test_scalar(adjoint, x)
8283
end

0 commit comments

Comments
 (0)