Skip to content

Commit 91f4857

Browse files
Merge pull request #1224 from avik-pal/ap/chainrules
Add ChainRules Overloads to StaticArrays
2 parents d47c771 + f5aef04 commit 91f4857

File tree

4 files changed

+54
-2
lines changed

4 files changed

+54
-2
lines changed

Project.toml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StaticArrays"
22
uuid = "90137ffa-7385-5640-81b9-e52037218182"
3-
version = "1.7.0"
3+
version = "1.8.0"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -10,25 +10,30 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1010
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1111

1212
[weakdeps]
13+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1314
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1415

1516
[extensions]
17+
StaticArraysChainRulesCoreExt = "ChainRulesCore"
1618
StaticArraysStatisticsExt = "Statistics"
1719

1820
[compat]
1921
Aqua = "0.7"
22+
ChainRulesCore = "1"
2023
PrecompileTools = "1"
2124
StaticArraysCore = "~1.4.0"
2225
julia = "1.6"
2326

2427
[extras]
2528
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
2629
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
30+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
31+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
2732
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
2833
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
2934
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3035
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3136
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
3237

3338
[targets]
34-
test = ["InteractiveUtils", "Test", "BenchmarkTools", "OffsetArrays", "Statistics", "Unitful", "Aqua"]
39+
test = ["InteractiveUtils", "Test", "BenchmarkTools", "OffsetArrays", "Statistics", "Unitful", "Aqua", "ChainRulesTestUtils", "ChainRulesCore"]

ext/StaticArraysChainRulesCoreExt.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
module StaticArraysChainRulesCoreExt
2+
3+
using StaticArrays
4+
# ChainRulesCore imports
5+
import ChainRulesCore: NoTangent, ProjectTo, Tangent, project_type, rrule
6+
import ChainRulesCore as CRC
7+
8+
# Projecting a tuple to SMatrix leads to CRC._projection_mismatch by default, so
9+
# overloaded here
10+
function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::StaticArraysCore.SArray)
11+
dy = reshape(dx, axes(project.elements))
12+
dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements))
13+
return project_type(project)(dz...)
14+
end
15+
16+
# Project SArray to SArray
17+
function ProjectTo(x::SArray{S, T}) where {S, T}
18+
return ProjectTo{SArray}(; element = CRC._eltype_projectto(T), axes = S)
19+
end
20+
21+
function (project::ProjectTo{SArray})(dx::AbstractArray{S, M}) where {S, M}
22+
return SArray{project.axes}(dx)
23+
end
24+
25+
# Adjoint for SArray constructor
26+
function rrule(::Type{T}, x::Tuple) where {T <: SArray}
27+
project_x = ProjectTo(x)
28+
∇Array(∂y) = (NoTangent(), project_x(∂y))
29+
return T(x), ∇Array
30+
end
31+
32+
end

test/chainrules.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using StaticArrays, ChainRulesCore, ChainRulesTestUtils, Test
2+
3+
@testset "Chain Rules Integration" begin
4+
@testset "Projection" begin
5+
test_rrule(SMatrix{1, 4}, (1.0, 1.0, 1.0, 1.0))
6+
test_rrule(SMatrix{4, 1}, (1.0, 1.0, 1.0, 1.0))
7+
test_rrule(SMatrix{2, 2}, (1.0, 1.0, 1.0, 1.0))
8+
test_rrule(SVector{4}, (1.0, 1.0, 1.0, 1.0))
9+
end
10+
end

test/runtests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,9 @@ if TEST_GROUP ∈ ["", "all", "group-B"]
8888
addtests("io.jl")
8989
addtests("svd.jl")
9090
addtests("unitful.jl")
91+
92+
# chain rules integration via pkg extensions is available only in Julia 1.9+
93+
if VERSION v"1.9-"
94+
addtests("chainrules.jl")
95+
end
9196
end

0 commit comments

Comments
 (0)