Skip to content

Commit 79c2991

Browse files
authored
Handle the rrule for SVector{}(....) (#1226)
* Handle the rrule for SVector{}(....) * Bump version
1 parent 91f4857 commit 79c2991

File tree

3 files changed

+9
-1
lines changed

3 files changed

+9
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StaticArrays"
22
uuid = "90137ffa-7385-5640-81b9-e52037218182"
3-
version = "1.8.0"
3+
version = "1.8.1"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

ext/StaticArraysChainRulesCoreExt.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,10 @@ function rrule(::Type{T}, x::Tuple) where {T <: SArray}
2929
return T(x), ∇Array
3030
end
3131

32+
function rrule(::Type{T}, xs::Number...) where {T <: SVector}
33+
project_x = ProjectTo(xs)
34+
∇Array(∂y) = (NoTangent(), project_x(∂y)...)
35+
return T(xs...), ∇Array
36+
end
37+
3238
end

test/chainrules.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,7 @@ using StaticArrays, ChainRulesCore, ChainRulesTestUtils, Test
66
test_rrule(SMatrix{4, 1}, (1.0, 1.0, 1.0, 1.0))
77
test_rrule(SMatrix{2, 2}, (1.0, 1.0, 1.0, 1.0))
88
test_rrule(SVector{4}, (1.0, 1.0, 1.0, 1.0))
9+
test_rrule(SVector{4}, 1.0, 1.0, 1.0, 1.0)
10+
test_rrule(SVector{4}, 1.0, 1.0f0, 1.0, 1.0f0)
911
end
1012
end

0 commit comments

Comments
 (0)