Skip to content

Commit 2583c8c

Browse files
committed
improve unbroadcast
1 parent 4b47516 commit 2583c8c

File tree

1 file changed

+35
-7
lines changed

1 file changed

+35
-7
lines changed

src/rulesets/Base/broadcast.jl

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -325,32 +325,60 @@ function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx)
325325
end
326326
unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx
327327

328-
unbroadcast(x::T, dx) where {T<:Tuple{Any}} = ProjectTo(x)(Tangent{T}(sum(dx)))
329328
function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N}
330329
val = if length(x) == length(dx)
331330
dx
332331
else
333332
sum(dx; dims=2:ndims(dx))
334333
end
334+
eltype(val) <: AbstractZero && return NoTangent()
335335
return ProjectTo(x)(NTuple{length(x)}(val)) # Tangent
336336
end
337+
unbroadcast(x::Tuple, dx::AbstractZero) = dx
338+
339+
# Scalar types
337340

338-
unbroadcast(f::Function, df) = sum(df)
339341
unbroadcast(x::Number, dx) = ProjectTo(x)(sum(dx))
340-
unbroadcast(x::Base.RefValue, dx) = ProjectTo(x)(Ref(sum(dx)))
342+
343+
function unbroadcast(x::T, dx) where {T<:Tuple{Any}}
344+
p1 = ProjectTo(only(x))
345+
p1 isa ProjectTo{<:AbstractZero} && return NoTangent()
346+
dx1 = p1(sum(dx))
347+
dx1 isa AbstractZero && return dx1
348+
return Tangent{T}(dx1)
349+
end
350+
unbroadcast(x::Tuple{Any}, dx::AbstractZero) = dx
351+
352+
function unbroadcast(x::Base.RefValue, dx)
353+
p1 = ProjectTo(x.x)
354+
p1 isa ProjectTo{<:AbstractZero} && return NoTangent()
355+
dx1 = p1(sum(dx))
356+
dx1 isa AbstractZero && return dx1
357+
return Tangent{typeof(x)}(; x = dx1)
358+
end
359+
unbroadcast(x::Base.RefValue, dx::AbstractZero) = dx
360+
361+
# Zero types
341362

342363
unbroadcast(::Bool, dx) = NoTangent()
343364
unbroadcast(::AbstractArray{Bool}, dx) = NoTangent()
344365
unbroadcast(::AbstractArray{Bool}, dx::AbstractZero) = dx # ambiguity
345366
unbroadcast(::Val, dx) = NoTangent()
346367

368+
function unbroadcast(f::Function, df)
369+
Base.issingletontype(typeof(f)) && return NoTangent()
370+
return sum(df)
371+
end
372+
373+
# Fallback
374+
347375
function unbroadcast(x, dx)
376+
@info "last unbroadcast method!" x dx
377+
dx isa AbstractZero && return dx
348378
p = ProjectTo(x)
349-
if dx isa AbstractZero || p isa ProjectTo{<:AbstractZero}
379+
if p isa ProjectTo{<:AbstractZero}
350380
return NoTangent()
351-
end
352-
b = Broadcast.broadcastable(x)
353-
if b isa Ref # then x is scalar under broadcast
381+
elseif Broadcast.broadcastable(x) isa Ref # then x is scalar under broadcast
354382
return p(sum(dx))
355383
else
356384
error("don't know how to handle broadcast gradient for x::$(typeof(x))")

0 commit comments

Comments
 (0)