@@ -325,32 +325,60 @@ function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx)
325
325
end
326
326
unbroadcast (x:: Base.AbstractArrayOrBroadcasted , dx:: AbstractZero ) = dx
327
327
328
- unbroadcast (x:: T , dx) where {T<: Tuple{Any} } = ProjectTo (x)(Tangent {T} (sum (dx)))
329
328
function unbroadcast (x:: T , dx) where {T<: Tuple{Vararg{Any,N}} } where {N}
330
329
val = if length (x) == length (dx)
331
330
dx
332
331
else
333
332
sum (dx; dims= 2 : ndims (dx))
334
333
end
334
+ eltype (val) <: AbstractZero && return NoTangent ()
335
335
return ProjectTo (x)(NTuple {length(x)} (val)) # Tangent
336
336
end
337
+ unbroadcast (x:: Tuple , dx:: AbstractZero ) = dx
338
+
339
+ # Scalar types
337
340
338
- unbroadcast (f:: Function , df) = sum (df)
339
341
unbroadcast (x:: Number , dx) = ProjectTo (x)(sum (dx))
340
- unbroadcast (x:: Base.RefValue , dx) = ProjectTo (x)(Ref (sum (dx)))
342
+
343
+ function unbroadcast (x:: T , dx) where {T<: Tuple{Any} }
344
+ p1 = ProjectTo (only (x))
345
+ p1 isa ProjectTo{<: AbstractZero } && return NoTangent ()
346
+ dx1 = p1 (sum (dx))
347
+ dx1 isa AbstractZero && return dx1
348
+ return Tangent {T} (dx1)
349
+ end
350
+ unbroadcast (x:: Tuple{Any} , dx:: AbstractZero ) = dx
351
+
352
+ function unbroadcast (x:: Base.RefValue , dx)
353
+ p1 = ProjectTo (x. x)
354
+ p1 isa ProjectTo{<: AbstractZero } && return NoTangent ()
355
+ dx1 = p1 (sum (dx))
356
+ dx1 isa AbstractZero && return dx1
357
+ return Tangent {typeof(x)} (; x = dx1)
358
+ end
359
+ unbroadcast (x:: Base.RefValue , dx:: AbstractZero ) = dx
360
+
361
+ # Zero types
341
362
342
363
unbroadcast (:: Bool , dx) = NoTangent ()
343
364
unbroadcast (:: AbstractArray{Bool} , dx) = NoTangent ()
344
365
unbroadcast (:: AbstractArray{Bool} , dx:: AbstractZero ) = dx # ambiguity
345
366
unbroadcast (:: Val , dx) = NoTangent ()
346
367
368
+ function unbroadcast (f:: Function , df)
369
+ Base. issingletontype (typeof (f)) && return NoTangent ()
370
+ return sum (df)
371
+ end
372
+
373
+ # Fallback
374
+
347
375
function unbroadcast (x, dx)
376
+ @info " last unbroadcast method!" x dx
377
+ dx isa AbstractZero && return dx
348
378
p = ProjectTo (x)
349
- if dx isa AbstractZero || p isa ProjectTo{<: AbstractZero }
379
+ if p isa ProjectTo{<: AbstractZero }
350
380
return NoTangent ()
351
- end
352
- b = Broadcast. broadcastable (x)
353
- if b isa Ref # then x is scalar under broadcast
381
+ elseif Broadcast. broadcastable (x) isa Ref # then x is scalar under broadcast
354
382
return p (sum (dx))
355
383
else
356
384
error (" don't know how to handle broadcast gradient for x::$(typeof (x)) " )
0 commit comments