Skip to content

Commit 86a2d9e

Browse files
committed
Make vect work with non-Number inputs
1 parent 08cfc36 commit 86a2d9e

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

src/rulesets/Base/array.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ end
3131

3232
@non_differentiable Base.vect()
3333

34-
function frule((_, ẋs...), ::typeof(Base.vect), xs::Number...)
34+
function frule((_, ẋs...), ::typeof(Base.vect), xs...)
3535
return Base.vect(xs...), Base.vect(_instantiate_zeros(ẋs, xs)...)
3636
end
3737

test/rulesets/Base/array.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,15 @@ end
3636
test_rrule(Base.vect)
3737
@testset "homogeneous type" begin
3838
test_rrule(Base.vect, (5.0, ), (4.0, ))
39+
test_frule(Base.vect, (5.0, ), (4.0, ))
3940
test_rrule(Base.vect, 5.0, 4.0, 3.0)
41+
test_frule(Base.vect, 5.0, 4.0, 3.0)
4042
test_rrule(Base.vect, randn(2, 2), randn(3, 3))
43+
test_frule(Base.vect, randn(2, 2), randn(3, 3))
44+
45+
# Nonnumber types
46+
test_frule(Base.vect, (1.0, 2.0), (1.0, 2.0))
47+
test_rrule(Base.vect, (1.0, 2.0), (1.0, 2.0))
4148
end
4249
@testset "inhomogeneous type" begin
4350
# fwd
@@ -52,7 +59,7 @@ end
5259
end
5360
@testset "_instantiate_zeros" begin
5461
# This is an internal function also used for `cat` etc.
55-
@eval using ChainRules: _instantiate_zeros
62+
_instantiate_zeros = ChainRules._instantiate_zeros
5663
# Check these hit the fast path, unrealistic input so that map would fail:
5764
@test _instantiate_zeros((true, 2 , 3.0), ()) == (1, 2, 3)
5865
@test _instantiate_zeros((1:2, [3, 4]), ()) == (1:2, 3:4)

0 commit comments

Comments
 (0)