Skip to content

Commit 262be3c

Browse files
willtebbuttmcabbottoxinabox
authored
More vect (#496)
* Add generic vect method * Minor version bump * Style fix * Fix test * Improve inferability Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> * Improve comment Co-authored-by: Lyndon White <oxinabox@ucc.asn.au> * Update comment * Change comment * Bump minor version Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
1 parent d892da5 commit 262be3c

File tree

3 files changed

+8
-1
lines changed

3 files changed

+8
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.6.0"
3+
version = "1.7.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/array.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ function rrule(
3737
return Base.vect(X...), vect_pullback
3838
end
3939

40+
# Data is unmodified, so no need to project.
41+
function rrule(::typeof(Base.vect), X::Vararg{Any,N}) where {N}
42+
vect_pullback(ȳ) = (NoTangent(), ntuple(n -> ȳ[n], N)...)
43+
return Base.vect(X...), vect_pullback
44+
end
45+
4046
#####
4147
##### `reshape`
4248
#####

test/rulesets/Base/array.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ end
3232
atol=1e-6, rtol=1e-6, check_inferred=VERSION>=v"1.6",
3333
) # tolerance due to Float32.
3434
test_rrule(Base.vect, 5.0, randn(3, 3); check_inferred=false)
35+
test_rrule(Base.vect, (5.0, 4.0), (y=randn(3),); check_inferred=false)
3536
end
3637
end
3738

0 commit comments

Comments
 (0)