Skip to content

Commit 3fd8fb9

Browse files
author
Avik Pal
authored
Fix Type Instability in SArray Projection (#1227)
* Fix Type Instability * Add test using JLArrays
1 parent 79c2991 commit 3fd8fb9

File tree

3 files changed

+32
-7
lines changed

3 files changed

+32
-7
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StaticArrays"
22
uuid = "90137ffa-7385-5640-81b9-e52037218182"
3-
version = "1.8.1"
3+
version = "1.8.2"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -30,10 +30,11 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3030
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3131
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3232
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
33+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
3334
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
3435
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3536
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3637
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
3738

3839
[targets]
39-
test = ["InteractiveUtils", "Test", "BenchmarkTools", "OffsetArrays", "Statistics", "Unitful", "Aqua", "ChainRulesTestUtils", "ChainRulesCore"]
40+
test = ["InteractiveUtils", "Test", "BenchmarkTools", "OffsetArrays", "Statistics", "Unitful", "Aqua", "ChainRulesTestUtils", "ChainRulesCore", "JLArrays"]

ext/StaticArraysChainRulesCoreExt.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ end
1515

1616
# Project SArray to SArray
1717
function ProjectTo(x::SArray{S, T}) where {S, T}
18-
return ProjectTo{SArray}(; element = CRC._eltype_projectto(T), axes = S)
18+
return ProjectTo{SArray}(; element = CRC._eltype_projectto(T), axes = Size(x))
1919
end
2020

21-
function (project::ProjectTo{SArray})(dx::AbstractArray{S, M}) where {S, M}
22-
return SArray{project.axes}(dx)
23-
end
21+
@inline _sarray_from_array(::Size{T}, dx::AbstractArray) where {T} = SArray{Tuple{T...}}(dx)
22+
23+
(project::ProjectTo{SArray})(dx::AbstractArray) = _sarray_from_array(project.axes, dx)
2424

2525
# Adjoint for SArray constructor
2626
function rrule(::Type{T}, x::Tuple) where {T <: SArray}

test/chainrules.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using StaticArrays, ChainRulesCore, ChainRulesTestUtils, Test
1+
using StaticArrays, ChainRulesCore, ChainRulesTestUtils, JLArrays, Test
22

33
@testset "Chain Rules Integration" begin
44
@testset "Projection" begin
@@ -9,4 +9,28 @@ using StaticArrays, ChainRulesCore, ChainRulesTestUtils, Test
99
test_rrule(SVector{4}, 1.0, 1.0, 1.0, 1.0)
1010
test_rrule(SVector{4}, 1.0, 1.0f0, 1.0, 1.0f0)
1111
end
12+
13+
@testset "Type Stability" begin
14+
x = ones(SMatrix{2, 2})
15+
y = ones(SVector{4})
16+
17+
@inferred ProjectTo(x)
18+
@inferred ProjectTo(y)
19+
@inferred ProjectTo(x)(y)
20+
@inferred ProjectTo(y)(x)
21+
22+
x = ones(SMatrix{2, 2, Float32})
23+
y = ones(SVector{4})
24+
25+
@inferred ProjectTo(x)
26+
@inferred ProjectTo(x)(y)
27+
@inferred ProjectTo(y)(x)
28+
end
29+
30+
@testset "Array of Structs Projection" begin
31+
x = JLArray(rand(SVector{3, Float64}, 10))
32+
@inferred ProjectTo(x)
33+
@inferred Union{Nothing, JLVector{SVector{3, Float64}}, DenseJLVector{SVector{3, Float64}}} ProjectTo(x)(x)
34+
@test ProjectTo(x)(x) isa JLArray
35+
end
1236
end

0 commit comments

Comments
 (0)