Skip to content

Commit 7b5f4d1

Browse files
authored
Improve type stability of cat_pullback (#610)
1 parent c294a9a commit 7b5f4d1

File tree

3 files changed

+4
-2
lines changed

3 files changed

+4
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.28.2"
3+
version = "1.28.3"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/array.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ function rrule(::typeof(cat), Xs::Union{AbstractArray, Number}...; dims)
357357
if d in cdims
358358
d > ndimsX ? (prev[d]+1) : (prev[d]+1:prev[d]+sizeX[d])
359359
else
360-
d > ndimsX ? 1 : (:)
360+
d > ndimsX ? 1 : 1:sizeX[d]
361361
end
362362
end
363363
for d in cdims

test/rulesets/Base/array.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,8 @@ end
233233
test_rrule(cat, rand(1), rand(3, 2, 1); fkwargs=(dims=(1,2),), check_inferred=false) # infers Tuple{Zero, Vector{Float64}, Any}
234234

235235
test_rrule(cat, rand(2, 2), rand(2, 2)'; fkwargs=(dims=1,))
236+
# inference on exotic array types
237+
test_rrule(cat, @SArray(rand(3, 2, 1)), @SArray(rand(3, 2, 1)); fkwargs=(dims=Val(2),))
236238
end
237239

238240
@testset "hvcat" begin

0 commit comments

Comments
 (0)