Skip to content

Commit b031fc1

Browse files
committed
broadcasting, adapted from Diffractor PR68
1 parent d0bcfc5 commit b031fc1

File tree

8 files changed

+530
-5
lines changed

8 files changed

+530
-5
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
1212
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1313
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
14+
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1415

1516
[compat]
1617
ChainRulesCore = "1.12"

src/ChainRules.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using Random
1010
using RealDot: realdot
1111
using SparseArrays
1212
using Statistics
13+
using StructArrays
1314

1415
# Basically everything this package does is overloading these, so we make an exception
1516
# to the normal rule of only overload via `ChainRulesCore.rrule`.
@@ -21,6 +22,9 @@ using ChainRulesCore: derivatives_given_output
2122
# numbers that we know commute under multiplication
2223
const CommutativeMulNumber = Union{Real,Complex}
2324

25+
# StructArrays
26+
include("tuplecast.jl")
27+
2428
include("rulesets/Core/core.jl")
2529

2630
include("rulesets/Base/utils.jl")
@@ -33,6 +37,7 @@ include("rulesets/Base/arraymath.jl")
3337
include("rulesets/Base/indexing.jl")
3438
include("rulesets/Base/sort.jl")
3539
include("rulesets/Base/mapreduce.jl")
40+
include("rulesets/Base/broadcast.jl")
3641

3742
include("rulesets/Statistics/statistics.jl")
3843

src/rulesets/Base/broadcast.jl

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
using Base.Broadcast: Broadcast, broadcasted, Broadcasted
2+
const RCR = RuleConfig{>:HasReverseMode}
3+
4+
# Materialising a lazy `Broadcasted` with `copy`: the incoming cotangent is
# handed back unchanged as the tangent for `bc`.
function rrule(::typeof(copy), bc::Broadcasted)
    copy_pullback(Δ) = (NoTangent(), Δ)
    return copy(bc), copy_pullback
end
5+
6+
# Skip AD'ing through the axis computation
7+
# `Broadcast.instantiate` is treated as transparent for differentiation: the
# pullback returns the cotangent for `bc` untouched.
function rrule(::typeof(Broadcast.instantiate), bc::Broadcasted)
    instantiated = Broadcast.instantiate(bc)
    function uninstantiate(Δ)
        return Core.tuple(NoTangent(), Δ)
    end
    return instantiated, uninstantiate
end
11+
12+
# Debugging hook, a no-op by default. Replace the body with
# `println(join(args, " "))` to trace which broadcasting rule is selected.
function _print(args...)
    return nothing
end
13+
14+
#####
15+
##### Split broadcasting
16+
#####
17+
18+
# Generic ("split") reverse rule for `broadcasted(f, args...)`.
#
# The strategy is chosen from `T`, the element type that
# `Broadcast.combine_eltypes` reports for `f.(args...)`:
#   1. `T === Bool`: the output is treated as non-differentiable and every
#      returned tangent is `ZeroTangent()`.
#   2. `T <: Number` and the inferred return type of `derivatives_given_output`
#      is concrete: broadcast once, then compute gradients from the arguments
#      and the result.
#   3. Otherwise: call `rrule_via_ad` per element, store the pullbacks, and
#      apply them in the backward pass.
# Returns `(y, pullback)`; the pullback yields one tangent each for
# `broadcasted`, `f`, and every element of `args`.
function rrule(cfg::RCR, ::typeof(broadcasted), f::F, args::Vararg{Any,N}) where {F,N}
    # = split_bc_rule(cfg, f, args...)
    # function split_bc_rule(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N}
    T = Broadcast.combine_eltypes(f, args)
    # Inferred return type of the derivative helper; its concreteness gates path 2.
    TΔ = Core.Compiler._return_type(
        derivatives_given_output, Tuple{T, F, map(eltype, args)...}
    )
    if T === Bool
        # 1: Trivial case: non-differentiable output, e.g. `x .> 0`
        _print("split_bc_rule 1 ", f)
        # +2 accounts for the `broadcasted` and `f` slots.
        back_1(_) = ntuple(Returns(ZeroTangent()), length(args)+2)
        return f.(args...), back_1
    elseif T <: Number && isconcretetype(TΔ)
        # 2: Fast path: just broadcast, and use arguments & result to find derivatives.
        _print("split_bc_rule 2", f, N)
        ys = f.(args...)
        function back_2_one(dys)  # For f.(x) we do not need StructArrays / unzip at all
            delta = broadcast(unthunk(dys), ys, args...) do dy, y, a
                das = only(derivatives_given_output(y, f, a))
                dy * conj(only(das))  # possibly this * should be made nan-safe.
            end
            (NoTangent(), NoTangent(), ProjectTo(only(args))(delta))
        end
        back_2_one(z::AbstractZero) = (NoTangent(), NoTangent(), z)
        function back_2_many(dys)
            deltas = tuplecast(unthunk(dys), ys, args...) do dy, y, as...
                das = only(derivatives_given_output(y, f, as...))
                map(da -> dy * conj(da), das)
            end
            dargs = map(unbroadcast, args, deltas)  # ideally sum in unbroadcast could be part of tuplecast?
            (NoTangent(), NoTangent(), dargs...)
        end
        back_2_many(z::AbstractZero) = (NoTangent(), NoTangent(), map(Returns(z), args)...)
        return ys, N==1 ? back_2_one : back_2_many
    else
        _print("split_bc_rule 3", f, N)
        # 3: Slow path: collect all the pullbacks & apply them later.
        # (Since broadcast makes no guarantee about order of calls, and un-fusing
        # can change the number of calls, don't bother to try to reverse the iteration.)
        ys3, backs = tuplecast(args...) do a...
            rrule_via_ad(cfg, f, a...)
        end
        function back_3(dys)
            deltas = tuplecast(backs, unthunk(dys)) do back, dy  # could be map, sizes match
                map(unthunk, back(dy))
            end
            # `first(deltas)` holds the tangents for `f`; the tail holds those for `args`.
            dargs = map(unbroadcast, args, Base.tail(deltas))
            (NoTangent(), ProjectTo(f)(sum(first(deltas))), dargs...)
        end
        back_3(z::AbstractZero) = (NoTangent(), NoTangent(), map(Returns(z), args)...)
        return ys3, back_3
    end
end
69+
70+
# Don't run broadcasting on scalars
71+
# When every argument is a `Number`, differentiate the plain call `f(args...)`
# through `rrule_via_ad` and prepend one `NoTangent()` for `broadcasted`.
function rrule(cfg::RCR, ::typeof(broadcasted), f::F, args::Number...) where {F}
    # function split_bc_rule(cfg::RCR, f::F, args::Number...) where {F}
    _print("split_bc_rule scalar", f)
    z, back = rrule_via_ad(cfg, f, args...)
    scalar_bc_back(dz) = (NoTangent(), back(dz)...)
    return z, scalar_bc_back
end
77+
78+
# using StructArrays
79+
#
80+
# function tuplecast(f::F, args...) where {F}
81+
# T = Broadcast.combine_eltypes(f, args)
82+
# if isconcretetype(T)
83+
# T <: Tuple || throw(ArgumentError("tuplecast(f, args) only works on functions returning a tuple."))
84+
# end
85+
# bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
86+
# StructArrays.components(StructArray(bc))
87+
# end
88+
89+
#####
90+
##### Fused broadcasting
91+
#####
92+
93+
# For certain cheap operations we can easily allow fused broadcast.
94+
# These all have `RuleConfig{>:HasReverseMode}` as otherwise the split rule matches first & they are not used.
95+
# They accept `Broadcasted` because they produce it; it has no eltype but is assumed to contain `Number`s.
96+
const NumericOrBroadcast = Union{Number, AbstractArray{<:Number}, NTuple{<:Any,Number}, Broadcast.Broadcasted}
97+
98+
# Fused `+`: the result stays a lazy `broadcasted(+, xs...)`, and each argument
# receives the output cotangent reduced to its own shape by `unbroadcast`.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(+), xs::NumericOrBroadcast...)
    _print("plus", length(xs))
    function bc_plus_back(dy_raw)
        dy = unthunk(dy_raw)
        dxs = map(x -> unbroadcast(x, dy), xs)
        return (NoTangent(), NoTangent(), dxs...)
    end
    lazy_sum = broadcasted(+, xs...)
    return lazy_sum, bc_plus_back
end
106+
107+
# Fused binary `-`: `x` receives the cotangent, `y` its negation; both are
# thunked and reduced to the argument's shape by `unbroadcast`.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::NumericOrBroadcast)
    _print("minus 2")
    function bc_minus_back(Δraw)
        Δ = unthunk(Δraw)
        dx = @thunk(unbroadcast(x, Δ))
        dy = @thunk(-unbroadcast(y, Δ))
        return (NoTangent(), NoTangent(), dx, dy)
    end
    return broadcasted(-, x, y), bc_minus_back
end
114+
115+
# Fused unary `-`: the tangent for `x` is the (thunked) negated cotangent.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast)
    _print("minus 1")
    function bc_minus_back(dy)
        dx = @thunk(-unthunk(dy))
        return (NoTangent(), NoTangent(), dx)
    end
    return broadcasted(-, x), bc_minus_back
end
120+
121+
using LinearAlgebra: dot
122+
123+
# Fused binary `*`: the result stays lazy; each argument's tangent is built by
# `_back_star` from the other argument and the cotangent.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast)
    _print("times")
    function bc_times_back(Δraw)
        Δ = unthunk(Δraw)
        dx = _back_star(x, y, Δ)
        dy = _back_star(y, x, Δ)
        return (NoTangent(), NoTangent(), dx, dy)
    end
    lazy_product = broadcasted(*, x, y)
    return lazy_product, bc_times_back
end
131+
# Tangent for one factor `x` of `x .* y`, given the other factor `y` and the
# cotangent `Δ`.
# Generic case: elementwise product with `conj.(y)`, reduced to the shape of `x`.
function _back_star(x, y, Δ)
    return @thunk(unbroadcast(x, Δ .* conj.(y)))
end
# A `Number` factor: a single `dot(y, Δ)`.
function _back_star(x::Number, y, Δ)
    return @thunk(dot(y, Δ))
end
# `Bool` and `Complex{Bool}` factors get no tangent.
function _back_star(x::Bool, y, Δ)
    return NoTangent()
end
function _back_star(x::Complex{Bool}, y, Δ)  # e.g. for fun.(im.*x)
    return NoTangent()
end
135+
136+
# TODO check what happens for A * B * C
137+
138+
# Fused `x .^ 2` (lowered to `literal_pow` with `Val(2)`): the tangent for `x`
# is `2 .* dy .* conj.(x)`, projected onto `x`. The pullback returns five
# entries; all but the `x` slot are `NoTangent()`.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::NumericOrBroadcast, ::Val{2})
    _print("square")
    function bc_square_back(dy_raw)
        dx = @thunk(ProjectTo(x)(2 .* unthunk(dy_raw) .* conj.(x)))
        return (NoTangent(), NoTangent(), NoTangent(), dx, NoTangent())
    end
    lazy_square = broadcasted(Base.literal_pow, ^, x, Val(2))
    return lazy_square, bc_square_back
end
146+
147+
# Division of a broadcastable `x` by a scalar `y`. Unlike the other fused rules
# the quotient `z` is computed eagerly with `broadcast`, because the tangent for
# `y` needs `dot(z, Δ)`.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast, y::Number)
    _print("divide")
    z = broadcast(/, x, y)
    function bc_divide_back(Δraw)
        Δ = unthunk(Δraw)
        dx = @thunk(unbroadcast(x, Δ ./ conj.(y)))
        dy = @thunk(-dot(z, Δ) / (conj(y)))  # the reason to be eager is to allow dot here
        return (NoTangent(), NoTangent(), dx, dy)
    end
    return z, bc_divide_back
end
158+
159+
# For the same functions, send accidental broadcasting over numbers directly to `rrule`.
160+
# Could perhaps move all to @scalar_rule?
161+
162+
# Adapt a `(y, back)` pair to a call that has one extra leading argument: the
# returned pullback prepends a `NoTangent()` to whatever `back` returns.
function _prepend_zero(y_and_back)
    y, back = y_and_back
    function extra_back(dy)
        return (NoTangent(), back(dy)...)
    end
    return y, extra_back
end
166+
167+
# Broadcasting these operations over plain numbers forwards to the ordinary
# scalar `rrule`, with one extra `NoTangent()` prepended for `broadcasted`.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(+), args::Number...)
    return _prepend_zero(rrule(+, args...))
end
function rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::Number, y::Number)
    return _prepend_zero(rrule(-, x, y))
end
function rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::Number)
    return _prepend_zero(rrule(-, x))
end
function rrule(::RCR, ::typeof(broadcasted), ::typeof(*), x::Number, y::Number)
    return _prepend_zero(rrule(*, x, y))
end
function rrule(::RCR, ::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{2})
    return _prepend_zero(rrule(Base.literal_pow, ^, x, Val(2)))
end
function rrule(::RCR, ::typeof(broadcasted), ::typeof(/), x::Number, y::Number)
    return _prepend_zero(rrule(/, x, y))
end
174+
175+
# A few more cheap functions
176+
177+
# `identity.(x)` reuses the rule for `identity(x)`, with one extra
# `NoTangent()` prepended for `broadcasted`.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(identity), x::NumericOrBroadcast)
    return _prepend_zero(rrule(identity, x))
end
# Same rule for `Number`; defined separately to resolve a method ambiguity.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(identity), x::Number)  # ambiguity
    return _prepend_zero(rrule(identity, x))
end
179+
180+
# Fused `conj.(x)`: the result stays lazy and the pullback conjugates the
# cotangent.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(conj), x::NumericOrBroadcast)
    function bc_conj_back(dx)
        return (NoTangent(), NoTangent(), conj(unthunk(dx)))
    end
    return broadcasted(conj, x), bc_conj_back
end
# A scalar forwards to the plain `conj` rule.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(conj), x::Number)
    return _prepend_zero(rrule(conj, x))
end
# A real array reuses the `identity` rule.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(conj), x::AbstractArray{<:Real})
    return _prepend_zero(rrule(identity, x))
end
186+
187+
# TODO real, imag
188+
189+
#####
190+
##### Shape fixing
191+
#####
192+
193+
# Reverse mode broadcasting uses `unbroadcast` to reduce to correct shape:
194+
195+
# Reduce the cotangent `dx` to the shape of the broadcast argument `x`.
# Equal lengths: only a projection is needed. Otherwise `dx` is summed over
# every dimension in which `x` has size 1 (or no dimension at all) first.
function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx)
    N = ndims(dx)
    if length(x) == length(dx)
        # handles trivial reshapes, offsets, structured matrices, row vectors
        return ProjectTo(x)(dx)
    end
    # hack to get type-stable `dims`: dimensions to keep are mapped to `N+1`
    dims = ntuple(d -> get(size(x), d, 1) == 1 ? d : N+1, N)
    return ProjectTo(x)(sum(dx; dims))  # ideally this sum might be thunked?
end
# A zero cotangent is passed through unchanged.
function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero)
    return dx
end
205+
206+
# One-element tuple: sum the whole cotangent and wrap it in a `Tangent{T}`.
function unbroadcast(x::T, dx) where {T<:Tuple{Any}}
    return ProjectTo(x)(Tangent{T}(sum(dx)))
end
# Longer tuples: equal lengths use `dx` as is; otherwise sum over dimensions
# `2:ndims(dx)`. The result is converted to an `NTuple` and projected.
function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N}
    val = length(x) == length(dx) ? dx : sum(dx; dims=2:ndims(dx))
    return ProjectTo(x)(NTuple{length(x)}(val))  # Tangent
end
215+
216+
# Arguments that hold a single value sum the whole cotangent.
function unbroadcast(f::Function, df)
    return sum(df)
end
function unbroadcast(x::Number, dx)
    return ProjectTo(x)(sum(dx))
end
function unbroadcast(x::Base.RefValue, dx)
    return ProjectTo(x)(Ref(sum(dx)))
end

# `Bool`, arrays of `Bool`, and `Val` arguments get no tangent.
function unbroadcast(::Bool, dx)
    return NoTangent()
end
function unbroadcast(::AbstractArray{Bool}, dx)
    return NoTangent()
end
function unbroadcast(::AbstractArray{Bool}, dx::AbstractZero)  # ambiguity
    return dx
end
function unbroadcast(::Val, dx)
    return NoTangent()
end
224+
225+
# Fallback for any other argument type. Returns `NoTangent()` when the cotangent
# is a zero or the projector for `x` is a zero projector. An `x` whose
# `broadcastable` form is a `Ref` gets the projected sum of `dx`. Anything else
# is an error.
function unbroadcast(x, dx)
    p = ProjectTo(x)
    (dx isa AbstractZero || p isa ProjectTo{<:AbstractZero}) && return NoTangent()
    b = Broadcast.broadcastable(x)
    # a `Ref` means x is scalar under broadcast
    b isa Ref && return p(sum(dx))
    return error("don't know how to handle broadcast gradient for x::$(typeof(x))")
end
237+
238+
#####
239+
##### For testing
240+
#####
241+
242+
# Rule for the composed function `copy∘broadcasted` (used for testing): reuse
# the rule for `broadcasted` and materialise its primal result with
# `_maybe_copy`. The pullback is returned unchanged.
# The name `copybroadcasted` is defined nowhere in this file; the `∘` between
# `copy` and `broadcasted` is restored here.
function rrule(cfg::RCR, ::typeof(copy∘broadcasted), f, args...)
    y, back = rrule(cfg, broadcasted, f, args...)
    return _maybe_copy(y), back
end
246+
247+
# Call `copy` on the result, except for tuples, which are returned as they are.
function _maybe_copy(y)
    return copy(y)
end
function _maybe_copy(y::Tuple)
    return y
end

src/rulesets/Base/fastmath_able.jl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,16 @@ let
167167
@scalar_rule x + y (true, true)
@scalar_rule x - y (true, -1)
# `Ω` is the primal output `x / y`, so the partial w.r.t. `y` is `-(Ω / y)`.
@scalar_rule x / y (one(x) / y, -(Ω / y))
170+
171+
## many-arg +
172+
# Forward rule for `+` with any number of numeric arguments: the primal is the
# sum of the inputs and the tangent is the sum of the input tangents. The first
# entry of the tangent tuple (for `+` itself) is ignored.
function frule((_, Δx, Δy...), ::typeof(+), x::Number, ys::Number...)
    primal = +(x, ys...)
    tangent = +(Δx, Δy...)
    return primal, tangent
end
175+
176+
# Reverse rule for `+` with any number of numeric arguments: every input
# receives the output cotangent `dz` unchanged.
function rrule(::typeof(+), x::Number, ys::Number...)
    total = +(x, ys...)
    function plus_back(dz)
        dys = map(Returns(dz), ys)
        return (NoTangent(), dz, dys...)
    end
    return total, plus_back
end
170180

171181
## power
172182
# literal_pow is in base.jl
@@ -276,6 +286,10 @@ let
276286
return Ω4, times_pullback4
277287
end
278288
rrule(::typeof(*), x::Number) = rrule(identity, x)
289+
290+
# This is used to choose a faster path in some broadcasting operations:
291+
# Partial derivatives of numeric `*`, written with adjoints (`'`) and wrapped in
# a one-element outer tuple. The output `Ω` is not used.
function ChainRulesCore.derivatives_given_output(Ω, ::typeof(*), x::Number, y::Number)
    return ((y', x'),)
end
function ChainRulesCore.derivatives_given_output(Ω, ::typeof(*), x::Number, y::Number, z::Number)
    return ((y' * z', x' * z', x' * y'),)
end
279293
end # fastable_ast
280294

281295
# Rewrite everything to use fast_math functions, including the type-constraints
@@ -288,12 +302,12 @@ let
288302
non_transformed_definitions = intersect(fastable_ast.args, fast_ast.args)
289303
filter!(expr->!(expr isa LineNumberNode), non_transformed_definitions)
290304
if !isempty(non_transformed_definitions)
291-
error(
292-
"Non-FastMath compatible rules defined in fastmath_able.jl. \n Definitions:\n" *
293-
join(non_transformed_definitions, "\n")
294-
)
305+
# error(
306+
# "Non-FastMath compatible rules defined in fastmath_able.jl. \n Definitions:\n" *
307+
# join(non_transformed_definitions, "\n")
308+
# )
295309
# This error() may not play well with Revise. But a warning @error does:
296-
# @error "Non-FastMath compatible rules defined in fastmath_able.jl." non_transformed_definitions
310+
@error "Non-FastMath compatible rules defined in fastmath_able.jl." non_transformed_definitions
297311
end
298312

299313
eval(fast_ast)

0 commit comments

Comments
 (0)