Skip to content

Commit 4892130

Browse files
committed
Add ChainRules Overloads to StaticArrays
1 parent d47c771 commit 4892130

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
name = "StaticArrays"
22
uuid = "90137ffa-7385-5640-81b9-e52037218182"
3-
version = "1.7.0"
3+
version = "1.7.1"
44

55
[deps]
6+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
67
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
78
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
89
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -17,6 +18,7 @@ StaticArraysStatisticsExt = "Statistics"
1718

1819
[compat]
1920
Aqua = "0.7"
21+
ChainRulesCore = "1"
2022
PrecompileTools = "1"
2123
StaticArraysCore = "~1.4.0"
2224
julia = "1.6"

src/StaticArrays.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ using LinearAlgebra: checksquare
2121

2222
using PrecompileTools
2323

24+
# ChainRulesCore imports
25+
import ChainRulesCore: ProjectTo, Tangent, project_type, rrule
26+
import ChainRulesCore as CRC
27+
2428
# StaticArraysCore imports
2529
# there is intentionally no "using StaticArraysCore" to not take all symbols exported
2630
# from StaticArraysCore to make transitioning definitions to StaticArraysCore easier.
@@ -133,6 +137,8 @@ include("flatten.jl")
133137
include("io.jl")
134138
include("pinv.jl")
135139

140+
include("chainrules.jl")
141+
136142
@static if !isdefined(Base, :get_extension) # VERSION < v"1.9-"
137143
include("../ext/StaticArraysStatisticsExt.jl")
138144
end

src/chainrules.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Projecting a tuple to SMatrix leads to ChainRulesCore._projection_mismatch by default, so
2+
# overloaded here
3+
function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::StaticArraysCore.SArray)
4+
dy = reshape(dx, axes(project.elements))
5+
dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements))
6+
return project_type(project)(dz...)
7+
end
8+
9+
# Project SArray to SArray
10+
function ProjectTo(x::SArray{S, T}) where {S, T}
11+
return ProjectTo{SArray}(; element = CRC._eltype_projectto(T), axes = S)
12+
end
13+
14+
function (project::ProjectTo{SArray})(dx::AbstractArray{S, M}) where {S, M}
15+
return SArray{project.axes}(dx)
16+
end
17+
18+
# Adjoint for SArray constructor
19+
function rrule(::Type{T}, x::Tuple) where {T <: SArray}
20+
project_x = ProjectTo(x)
21+
∇Array(∂y) = (NoTangent(), project_x(∂y))
22+
return T(x), ∇Array
23+
end

0 commit comments

Comments
 (0)