Skip to content

Commit afbc7bc

Browse files
Merge pull request #449 from oscardssmith/os/refactor-StaticArray-ext
Make extension depend on `StaticArrays` and optimize `lu_instance`
2 parents 0187aa6 + 50e2776 commit afbc7bc

File tree

5 files changed

+38
-34
lines changed

5 files changed

+38
-34
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
1515
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1616
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1717
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
18-
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
18+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1919
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2020

2121
[extensions]
@@ -27,7 +27,7 @@ ArrayInterfaceChainRulesExt = "ChainRules"
2727
ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore"
2828
ArrayInterfaceReverseDiffExt = "ReverseDiff"
2929
ArrayInterfaceSparseArraysExt = "SparseArrays"
30-
ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore"
30+
ArrayInterfaceStaticArraysExt = "StaticArrays"
3131
ArrayInterfaceTrackerExt = "Tracker"
3232

3333
[compat]
@@ -66,4 +66,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6666
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
6767

6868
[targets]
69-
test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "StaticArrays", "StaticArraysCore", "Tracker", "ReverseDiff", "ChainRules", "FillArrays", "ComponentArrays"]
69+
test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "StaticArrays", "Tracker", "ReverseDiff", "ChainRules", "FillArrays", "ComponentArrays"]

ext/ArrayInterfaceStaticArraysCoreExt.jl

Lines changed: 0 additions & 30 deletions
This file was deleted.

ext/ArrayInterfaceStaticArraysExt.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
module ArrayInterfaceStaticArraysExt
2+
3+
import ArrayInterface
4+
using LinearAlgebra
5+
import StaticArrays: SArray, SMatrix, SVector, StaticMatrix, StaticArray, SizedArray, MArray, MMatrix, LU
6+
7+
function ArrayInterface.undefmatrix(::MArray{S, T, N, L}) where {S, T, N, L}
8+
return MMatrix{L, L, T, L*L}(undef)
9+
end
10+
# SArray doesn't have an undef constructor and is going to be small enough that this is fine.
11+
function ArrayInterface.undefmatrix(s::SArray)
12+
v = vec(s)
13+
return v.*v'
14+
end
15+
16+
ArrayInterface.ismutable(::Type{<:StaticArray}) = false
17+
ArrayInterface.ismutable(::Type{<:MArray}) = true
18+
ArrayInterface.ismutable(::Type{<:SizedArray}) = true
19+
20+
ArrayInterface.can_setindex(::Type{<:StaticArray}) = false
21+
ArrayInterface.can_setindex(::Type{<:MArray}) = true
22+
ArrayInterface.buffer(A::Union{SArray, MArray}) = getfield(A, :data)
23+
24+
function ArrayInterface.lu_instance(A::SMatrix{N,N}) where {N}
25+
LU(LowerTriangular(A), UpperTriangular(A), SVector{N}(1:N))
26+
end
27+
28+
function ArrayInterface.lu_instance(A::StaticMatrix{N,N}) where {N}
29+
lu(one(A))
30+
end
31+
32+
ArrayInterface.restructure(x::SArray{S}, y) where {S} = SArray{S}(y)
33+
34+
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ end
1414
@time @safetestset "BlockBandedMatrices" begin include("blockbandedmatrices.jl") end
1515
@time @safetestset "Core" begin include("core.jl") end
1616
@time @safetestset "AD Integration" begin include("ad.jl") end
17-
@time @safetestset "StaticArraysCore" begin include("staticarrayscore.jl") end
17+
@time @safetestset "StaticArrays" begin include("staticarrays.jl") end
1818
@time @safetestset "ChainRules" begin include("chainrules.jl") end
1919
end
2020

File renamed without changes.

0 commit comments

Comments
 (0)