Skip to content

Commit 3b3791f

Browse files
authored
Widen cat rules (#614)
* widen cat rules * cat+version
1 parent ffefa07 commit 3b3791f

File tree

3 files changed

+19
-11
lines changed

3 files changed

+19
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.29.0"
3+
version = "1.30.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/array.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -209,10 +209,14 @@ function frule((_, ẋs...), ::typeof(hcat), xs...)
209209
return hcat(xs...), hcat(_instantiate_zeros(ẋs, xs)...)
210210
end
211211

212-
function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...)
212+
# All the [hv]cat functions treat anything that's not an array as a scalar.
213+
_catsize(x) = ()
214+
_catsize(x::AbstractArray) = size(x)
215+
216+
function rrule(::typeof(hcat), Xs...)
213217
Y = hcat(Xs...) # note that Y always has 1-based indexing, even if X isa OffsetArray
214218
ndimsY = Val(ndims(Y)) # this avoids closing over Y, Val() is essential for type-stability
215-
sizes = map(size, Xs) # this avoids closing over Xs
219+
sizes = map(_catsize, Xs) # this avoids closing over Xs
216220
project_Xs = map(ProjectTo, Xs)
217221
function hcat_pullback(ȳ)
218222
dY = unthunk(ȳ)
@@ -279,10 +283,10 @@ function frule((_, ẋs...), ::typeof(vcat), xs...)
279283
return vcat(xs...), vcat(_instantiate_zeros(ẋs, xs)...)
280284
end
281285

282-
function rrule(::typeof(vcat), Xs::Union{AbstractArray, Number}...)
286+
function rrule(::typeof(vcat), Xs...)
283287
Y = vcat(Xs...)
284288
ndimsY = Val(ndims(Y))
285-
sizes = map(size, Xs)
289+
sizes = map(_catsize, Xs)
286290
project_Xs = map(ProjectTo, Xs)
287291
function vcat_pullback(ȳ)
288292
dY = unthunk(ȳ)
@@ -342,11 +346,11 @@ function frule((_, ẋs...), ::typeof(cat), xs...; dims)
342346
return cat(xs...; dims), cat(_instantiate_zeros(ẋs, xs)...; dims)
343347
end
344348

345-
function rrule(::typeof(cat), Xs::Union{AbstractArray, Number}...; dims)
349+
function rrule(::typeof(cat), Xs...; dims)
346350
Y = cat(Xs...; dims=dims)
347351
cdims = dims isa Val ? Int(_val(dims)) : dims isa Integer ? Int(dims) : Tuple(dims)
348352
ndimsY = Val(ndims(Y))
349-
sizes = map(size, Xs)
353+
sizes = map(_catsize, Xs)
350354
project_Xs = map(ProjectTo, Xs)
351355
function cat_pullback(ȳ)
352356
dY = unthunk(ȳ)
@@ -384,11 +388,11 @@ function frule((_, _, ẋs...), ::typeof(hvcat), rows, xs...)
384388
return hvcat(rows, xs...), hvcat(rows, _instantiate_zeros(ẋs, xs)...)
385389
end
386390

387-
function rrule(::typeof(hvcat), rows, values::Union{AbstractArray, Number}...)
391+
function rrule(::typeof(hvcat), rows, values...)
388392
Y = hvcat(rows, values...)
389393
cols = size(Y,2)
390394
ndimsY = Val(ndims(Y))
391-
sizes = map(size, values)
395+
sizes = map(_catsize, values)
392396
project_Vs = map(ProjectTo, values)
393397
function hvcat_pullback(dY)
394398
prev = fill(0, 2)

test/rulesets/Base/array.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,8 @@ end
166166
test_rrule(hcat, rand(3,1,1,2), rand(3,3,1,2))
167167

168168
# mix types
169-
test_rrule(hcat, rand(2, 2), rand(2, 2)')
169+
test_rrule(hcat, rand(1, 3), rand(2)')
170+
test_rrule(hcat, rand(1), (nothing, rand()), check_inferred=false)
170171
end
171172

172173
@testset "reduce hcat" begin
@@ -203,9 +204,10 @@ end
203204
test_rrule(vcat, rand(), rand(3), rand(3,1,1))
204205
test_rrule(vcat, rand(3,1,2), rand(4,1,2))
205206

206-
207207
# mix types
208208
test_rrule(vcat, rand(2, 2), rand(2, 2)')
209+
test_rrule(vcat, rand(), rand() => rand(); check_inferred=false)
210+
test_rrule(vcat, rand(3), (rand(), nothing), pi/2; check_inferred=false)
209211
end
210212

211213
@testset "reduce vcat" begin
@@ -235,6 +237,7 @@ end
235237
test_rrule(cat, rand(2, 2), rand(2, 2)'; fkwargs=(dims=1,))
236238
# inference on exotic array types
237239
test_rrule(cat, @SArray(rand(3, 2, 1)), @SArray(rand(3, 2, 1)); fkwargs=(dims=Val(2),))
240+
test_rrule(cat, pi/2, rand(1,3), (4.5,); fkwargs=(;dims=(2,)), check_inferred=false)
238241
end
239242

240243
@testset "hvcat" begin
@@ -249,6 +252,7 @@ end
249252

250253
# mix types (adjoint and transpose)
251254
test_rrule(hvcat, 1, rand(3)', transpose(rand(3)) rand(1,3))
255+
test_rrule(hvcat, (1,2), rand(2)', (3.4, 5.6), 7.8; check_inferred=false)
252256
end
253257

254258
@testset "reverse" begin

0 commit comments

Comments
 (0)