@@ -209,10 +209,14 @@ function frule((_, ẋs...), ::typeof(hcat), xs...)
209
209
return hcat (xs... ), hcat (_instantiate_zeros (ẋs, xs)... )
210
210
end
211
211
212
- function rrule (:: typeof (hcat), Xs:: Union{AbstractArray, Number} ...)
212
+ # All the [hv]cat functions treat anything that's not an array as a scalar.
213
+ _catsize (x) = ()
214
+ _catsize (x:: AbstractArray ) = size (x)
215
+
216
+ function rrule (:: typeof (hcat), Xs... )
213
217
Y = hcat (Xs... ) # note that Y always has 1-based indexing, even if X isa OffsetArray
214
218
ndimsY = Val (ndims (Y)) # this avoids closing over Y, Val() is essential for type-stability
215
- sizes = map (size , Xs) # this avoids closing over Xs
219
+ sizes = map (_catsize , Xs) # this avoids closing over Xs
216
220
project_Xs = map (ProjectTo, Xs)
217
221
function hcat_pullback (ȳ)
218
222
dY = unthunk (ȳ)
@@ -279,10 +283,10 @@ function frule((_, ẋs...), ::typeof(vcat), xs...)
279
283
return vcat (xs... ), vcat (_instantiate_zeros (ẋs, xs)... )
280
284
end
281
285
282
- function rrule (:: typeof (vcat), Xs:: Union{AbstractArray, Number} ...)
286
+ function rrule (:: typeof (vcat), Xs... )
283
287
Y = vcat (Xs... )
284
288
ndimsY = Val (ndims (Y))
285
- sizes = map (size , Xs)
289
+ sizes = map (_catsize , Xs)
286
290
project_Xs = map (ProjectTo, Xs)
287
291
function vcat_pullback (ȳ)
288
292
dY = unthunk (ȳ)
@@ -342,11 +346,11 @@ function frule((_, ẋs...), ::typeof(cat), xs...; dims)
342
346
return cat (xs... ; dims), cat (_instantiate_zeros (ẋs, xs)... ; dims)
343
347
end
344
348
345
- function rrule (:: typeof (cat), Xs:: Union{AbstractArray, Number} ...; dims)
349
+ function rrule (:: typeof (cat), Xs... ; dims)
346
350
Y = cat (Xs... ; dims= dims)
347
351
cdims = dims isa Val ? Int (_val (dims)) : dims isa Integer ? Int (dims) : Tuple (dims)
348
352
ndimsY = Val (ndims (Y))
349
- sizes = map (size , Xs)
353
+ sizes = map (_catsize , Xs)
350
354
project_Xs = map (ProjectTo, Xs)
351
355
function cat_pullback (ȳ)
352
356
dY = unthunk (ȳ)
@@ -384,11 +388,11 @@ function frule((_, _, ẋs...), ::typeof(hvcat), rows, xs...)
384
388
return hvcat (rows, xs... ), hvcat (rows, _instantiate_zeros (ẋs, xs)... )
385
389
end
386
390
387
- function rrule (:: typeof (hvcat), rows, values:: Union{AbstractArray, Number} ...)
391
+ function rrule (:: typeof (hvcat), rows, values... )
388
392
Y = hvcat (rows, values... )
389
393
cols = size (Y,2 )
390
394
ndimsY = Val (ndims (Y))
391
- sizes = map (size , values)
395
+ sizes = map (_catsize , values)
392
396
project_Vs = map (ProjectTo, values)
393
397
function hvcat_pullback (dY)
394
398
prev = fill (0 , 2 )
0 commit comments