Skip to content

Commit 27328c2

Browse files
committed
more on broadcasting
1 parent 63797fd commit 27328c2

File tree

4 files changed

+35
-20
lines changed

4 files changed

+35
-20
lines changed

src/extra_rules.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,16 +147,16 @@ function ChainRulesCore.rrule(::typeof(Core.tuple), args...)
147147
end
148148

149149
# TODO: What to do about these integer rules
150-
@ChainRulesCore.non_differentiable Base.rem(a::Integer, b::Type)
150+
# @ChainRulesCore.non_differentiable Base.rem(a::Integer, b::Type) # now in CR 1.18
151151

152152
ChainRulesCore.canonicalize(::ChainRulesCore.ZeroTangent) = ChainRulesCore.ZeroTangent()
153153

154-
# Skip AD'ing through the axis computation
155-
function ChainRules.rrule(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted)
156-
return Base.Broadcast.instantiate(bc), Δ->begin
157-
Core.tuple(NoTangent(), Δ)
158-
end
159-
end
154+
# # Skip AD'ing through the axis computation
155+
# function ChainRules.rrule(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted)
156+
# return Base.Broadcast.instantiate(bc), Δ->begin
157+
# Core.tuple(NoTangent(), Δ)
158+
# end
159+
# end
160160

161161

162162
using StaticArrays
@@ -268,3 +268,7 @@ end
268268
function ChainRulesCore.rrule(::Type{InplaceableThunk}, add!!, val)
269269
val, Δ->(NoTangent(), NoTangent(), Δ)
270270
end
271+
272+
# ERROR: ArgumentError: Tangent for the primal Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}} should be backed by a AbstractDict type, not by NamedTuple{(:data,), Tuple{ChainRulesCore.ZeroTangent}}.
273+
ChainRulesCore._backing_error(::Type{<:Base.Pairs{Symbol}}, ::Type{<:NamedTuple}, _) = nothing # solves that!
274+

src/stage1/broadcast.jl

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ end
4444

4545
(::∂⃖{1})(::typeof(broadcasted), f, args...) = split_bc_rule(f, args...)
4646
(::∂⃖{1})(::typeof(broadcasted), f, arg::Array) = split_bc_rule(f, arg) # ambiguity
47-
function split_bc_rule(f::F, args...) where {F}
47+
function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N}
4848
T = Broadcast.combine_eltypes(f, args)
4949
TΔ = Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(eltype, args)...})
5050
if eltype(T) == Bool
@@ -71,10 +71,11 @@ function split_bc_rule(f::F, args...) where {F}
7171
dargs = map(unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of splitcast?
7272
(NoTangent(), NoTangent(), dargs...)
7373
end
74-
return ys, length(args)==1 ? back_2_one : back_2_many
74+
return ys, N==1 ? back_2_one : back_2_many
7575
else
7676
# Slow path: collect all the pullbacks & apply them later.
77-
# Since broadcast makes no guarantee about order, this does not bother to try to reverse it.
77+
# (Since broadcast makes no guarantee about order of calls, and un-fusing
78+
# can change the number of calls, this does not bother to try to reverse.)
7879
_print("path 3")
7980
ys, backs = splitcast(∂⃖{1}(), f, args...)
8081
function back_3(dys)
@@ -84,15 +85,21 @@ function split_bc_rule(f::F, args...) where {F}
8485
dargs = map(unbroadcast, args, Base.tail(deltas)) # no real need to close over args here
8586
(NoTangent(), sum(first(deltas)), dargs...)
8687
end
88+
back_3(::AbstractZero) = (NoTangent(), map(Returns(ZeroTangent()), args)...)
8789
return ys, back_3
8890
end
8991
end
9092

91-
# This uses "mulltimap"-like constructs:
93+
# Skip AD'ing through the axis computation
94+
function (::∂⃖{1})(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted)
95+
uninstantiate(Δ) = Core.tuple(NoTangent(), Δ)
96+
return Base.Broadcast.instantiate(bc), uninstantiate
97+
end
98+
99+
# This uses "multimap"-like constructs:
92100

93101
using StructArrays
94102
splitmap(f, args...) = StructArrays.components(StructArray(Iterators.map(f, args...)))
95-
# warning: splitmap(identity, [1,2,3,4]) === NamedTuple()
96103
splitcast(f, args...) = StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...))))
97104

98105
#=
@@ -156,9 +163,9 @@ end
156163
(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args...) = split_bc_plus(args...)
157164
(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), arg::Array) = split_bc_plus(arg) # ambiguity
158165
function split_bc_plus(xs...) where {F}
159-
broadcasted(+, xs...), Δ -> let Δun = unthunk(Δ)
166+
broadcasted(+, xs...), Δraw -> let Δ = unthunk(Δraw)
160167
_print("broadcast +")
161-
(NoTangent(), NoTangent(), map(x -> unbroadcast(x, Δun), xs)...)
168+
(NoTangent(), NoTangent(), map(x -> unbroadcast(x, Δ), xs)...)
162169
end
163170
end
164171
Base.eltype(bc::Broadcast.Broadcasted{<:Any, <:Any, typeof(+), <:Tuple}) =
@@ -167,20 +174,20 @@ Base.eltype(bc::Broadcast.Broadcasted{<:Any, <:Any, typeof(+), <:Tuple}) =
167174
(::∂⃖{1})(::typeof(copy), bc::Broadcast.Broadcasted) = copy(bc), Δ -> (NoTangent(), Δ)
168175

169176
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x, y)
170-
broadcasted(-, x, y), Δ -> let Δun = unthunk(Δ)
177+
broadcasted(-, x, y), Δraw -> let Δ = unthunk(Δraw)
171178
_print("broadcast -")
172-
(NoTangent(), NoTangent(), unbroadcast(x, Δun), -unbroadcast(y, Δun))
179+
(NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ))
173180
# Ideally you could fuse the - into unbroadcast, mapreduce() not sum, when y is a smaller array
174181
end
175182
end
176183

177184
using LinearAlgebra: dot
178185

179186
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y) # should this be vararg, or will laziness handle it?
180-
broadcasted(*, x, y), Δ -> let Δun = unthunk(Δ)
187+
broadcasted(*, x, y), Δraw -> let Δ = unthunk(Δraw)
181188
_print("broadcast *")
182-
dx = eltype(x)==Bool ? NoTangent() : x isa Number ? dot(y, Δun) : unbroadcast(x, Δun .* conj.(y))
183-
dy = eltype(y)==Bool ? NoTangent() : y isa Number ? dot(x, Δun) : unbroadcast(y, Δun .* conj.(x))
189+
dx = eltype(x)==Bool ? NoTangent() : x isa Number ? dot(y, Δ) : unbroadcast(x, Δ .* conj.(y))
190+
dy = eltype(y)==Bool ? NoTangent() : y isa Number ? dot(x, Δ) : unbroadcast(y, Δ .* conj.(x))
184191
# When x is an array but a smaller one, instead of dot you may be able to use mapreduce()
185192
# Will things like this work? Ref([1,2]) .* [1,2,3]
186193
(NoTangent(), NoTangent(), dx, dy)

src/stage1/generated.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ end
4949
Base.getindex(o::OpticBundle, i::Int) = i == 1 ? o.x :
5050
i == 2 ? o.clos :
5151
throw(BoundsError(o, i))
52+
Base.lastindex(o::OpticBundle) = 2
53+
5254
Base.iterate(o::OpticBundle) = (o.x, nothing)
5355
Base.iterate(o::OpticBundle, ::Nothing) = (o.clos, missing)
5456
Base.iterate(o::OpticBundle, ::Missing) = nothing

test/runtests.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,9 @@ exp_log(x) = exp(log(x))
231231
@test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3
232232

233233
@test unthunk.(gradient(x -> sum(x ./ 4), [1,2,3])) == ([0.25, 0.25, 0.25],)
234-
@test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,)
234+
@test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,) # x/y rule
235+
@test gradient(x -> sum(x.^2), [1,2,3]) == ([2.0, 4.0, 6.0],) # x.^2 rule
236+
@test gradient(x -> sum([1,2,3] ./ x.^2), 4) == (-0.1875,) # scalar^2 rule
235237

236238
@test gradient(x -> sum(x .> 2), [1,2,3]) == (ZeroTangent(),) # Bool output
237239
@test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) == (ZeroTangent(),)

0 commit comments

Comments
 (0)