Skip to content

Commit de398ae

Browse files
authored
Merge pull request #785 from JuliaDiff/ox/izt
support arbitary types for vect -- use zero_tangent for _instantiate_zeros
2 parents c69ee3e + c128d8e commit de398ae

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

src/rulesets/Base/array.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ end
3131

3232
@non_differentiable Base.vect()
3333

34-
function frule((_, ẋs...), ::typeof(Base.vect), xs::Number...)
34+
function frule((_, ẋs...), ::typeof(Base.vect), xs...)
3535
return Base.vect(xs...), Base.vect(_instantiate_zeros(ẋs, xs)...)
3636
end
3737

@@ -71,9 +71,7 @@ materialises each zero `ẋ` to be `zero(x)`.
7171
"""
7272
_instantiate_zeros(ẋs, xs) = map(_i_zero, ẋs, xs)
7373
_i_zero(ẋ, x) =
74-
_i_zero(ẋ::AbstractZero, x) = zero(x)
75-
# Possibly this won't work for partly non-diff arrays, something like `gradient(x -> ["abc", x][end], 1)`
76-
# may give a MethodError for `zero` but won't be wrong.
74+
_i_zero(ẋ::AbstractZero, x) = zero_tangent(x)
7775

7876
# Fast paths. Should it also collapse all-Zero cases?
7977
_instantiate_zeros(ẋs::Tuple{Vararg{Number}}, xs) = ẋs

test/rulesets/Base/array.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,16 @@ end
3535
@testset "vect" begin
3636
test_rrule(Base.vect)
3737
@testset "homogeneous type" begin
38-
test_rrule(Base.vect, (5.0, ), (4.0, ))
38+
test_rrule(Base.vect, (5.0,), (4.0,))
39+
test_frule(Base.vect, (5.0,), (4.0,))
3940
test_rrule(Base.vect, 5.0, 4.0, 3.0)
41+
test_frule(Base.vect, 5.0, 4.0, 3.0)
4042
test_rrule(Base.vect, randn(2, 2), randn(3, 3))
43+
test_frule(Base.vect, randn(2, 2), randn(3, 3))
44+
45+
# Nonnumber types
46+
test_frule(Base.vect, (1.0, 2.0), (1.0, 2.0))
47+
test_rrule(Base.vect, (1.0, 2.0), (1.0, 2.0))
4148
end
4249
@testset "inhomogeneous type" begin
4350
# fwd
@@ -52,7 +59,7 @@ end
5259
end
5360
@testset "_instantiate_zeros" begin
5461
# This is an internal function also used for `cat` etc.
55-
@eval using ChainRules: _instantiate_zeros
62+
_instantiate_zeros = ChainRules._instantiate_zeros
5663
# Check these hit the fast path, unrealistic input so that map would fail:
5764
@test _instantiate_zeros((true, 2 , 3.0), ()) == (1, 2, 3)
5865
@test _instantiate_zeros((1:2, [3, 4]), ()) == (1:2, 3:4)

0 commit comments

Comments
 (0)