Skip to content

Commit 3b38c22

Browse files
N5N3oschulz
andauthored
Move StaticArrays support to extension (#265)
* Use weakdeps on Julia v1.9 Update Project.toml * move StructStaticArray broadcast to ext * fix doctest * move `Adapt` to ext And use curried adapter to avoid possible instability * Apply suggestions from code review * Adopt code style suggestion. * restrict Adapt compat * Add empty bc test. * define `__broadcast` ourselves --------- Co-authored-by: Oliver Schulz <oschulz@mpp.mpg.de>
1 parent 99f0556 commit 3b38c22

11 files changed

+209
-206
lines changed

Project.toml

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,32 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
77
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
88
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
99
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
10-
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
10+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1111
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1212

13+
[weakdeps]
14+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
15+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
16+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
17+
18+
[extensions]
19+
StructArraysAdaptExt = "Adapt"
20+
StructArraysGPUArraysCoreExt = "GPUArraysCore"
21+
StructArraysStaticArraysExt = "StaticArrays"
22+
1323
[compat]
14-
Adapt = "1, 2, 3"
24+
Adapt = "3.4"
1525
ConstructionBase = "1"
1626
DataAPI = "1"
1727
GPUArraysCore = "0.1.2"
1828
StaticArrays = "1.5.6"
19-
StaticArraysCore = "1.3"
2029
Tables = "1"
2130
julia = "1.6"
2231

2332
[extras]
33+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
2434
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
35+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
2536
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
2637
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
2738
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
@@ -32,4 +43,4 @@ TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
3243
WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"
3344

3445
[targets]
35-
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays"]
46+
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays", "GPUArraysCore", "Adapt"]

docs/src/advanced.md

Lines changed: 32 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -6,89 +6,56 @@ StructArrays support structures with custom data layout. The user is required to
66

77
Here is an example of a type `MyType` that has as custom fields either its field `data` or fields of its field `rest` (which is a named tuple):
88

9-
```jldoctest advanced1
10-
julia> using StructArrays
9+
```@repl advanced1
10+
using StructArrays
1111
12-
julia> struct MyType{T, NT<:NamedTuple}
13-
data::T
14-
rest::NT
15-
end
12+
struct MyType{T, NT<:NamedTuple}
13+
data::T
14+
rest::NT
15+
end
1616
17-
julia> MyType(x; kwargs...) = MyType(x, values(kwargs))
18-
MyType
17+
MyType(x; kwargs...) = MyType(x, values(kwargs))
1918
```
2019

2120
Let's create a small array of these objects:
2221

23-
```jldoctest advanced1
24-
julia> s = [MyType(i/5, a=6-i, b=2) for i in 1:5]
25-
5-element Vector{MyType{Float64, NamedTuple{(:a, :b), Tuple{Int64, Int64}}}}:
26-
MyType{Float64, NamedTuple{(:a, :b), Tuple{Int64, Int64}}}(0.2, (a = 5, b = 2))
27-
MyType{Float64, NamedTuple{(:a, :b), Tuple{Int64, Int64}}}(0.4, (a = 4, b = 2))
28-
MyType{Float64, NamedTuple{(:a, :b), Tuple{Int64, Int64}}}(0.6, (a = 3, b = 2))
29-
MyType{Float64, NamedTuple{(:a, :b), Tuple{Int64, Int64}}}(0.8, (a = 2, b = 2))
30-
MyType{Float64, NamedTuple{(:a, :b), Tuple{Int64, Int64}}}(1.0, (a = 1, b = 2))
22+
```@repl advanced1
23+
s = [MyType(i/5, a=6-i, b=2) for i in 1:5]
3124
```
3225

3326
The default `StructArray` does not unpack the `NamedTuple`:
3427

35-
```jldoctest advanced1
36-
julia> sa = StructArray(s);
37-
38-
julia> sa.rest
39-
5-element Vector{NamedTuple{(:a, :b), Tuple{Int64, Int64}}}:
40-
(a = 5, b = 2)
41-
(a = 4, b = 2)
42-
(a = 3, b = 2)
43-
(a = 2, b = 2)
44-
(a = 1, b = 2)
45-
46-
julia> sa.a
47-
ERROR: type NamedTuple has no field a
48-
Stacktrace:
49-
[1] component
50-
[...]
28+
```@repl advanced1
29+
sa = StructArray(s);
30+
sa.rest
31+
sa.a
5132
```
5233

5334
Suppose we wish to give the keywords their own fields. We can define custom `staticschema`, `component`, and `createinstance` methods for `MyType`:
5435

55-
```jldoctest advanced1
56-
julia> function StructArrays.staticschema(::Type{MyType{T, NamedTuple{names, types}}}) where {T, names, types}
57-
# Define the desired names and eltypes of the "fields"
58-
return NamedTuple{(:data, names...), Base.tuple_type_cons(T, types)}
59-
end;
60-
61-
julia> function StructArrays.component(m::MyType, key::Symbol)
62-
# Define a component-extractor
63-
return key === :data ? getfield(m, 1) : getfield(getfield(m, 2), key)
64-
end;
65-
66-
julia> function StructArrays.createinstance(::Type{MyType{T, NT}}, x, args...) where {T, NT}
67-
# Generate an instance of MyType from components
68-
return MyType(x, NT(args))
69-
end;
36+
```@repl advanced1
37+
function StructArrays.staticschema(::Type{MyType{T, NamedTuple{names, types}}}) where {T, names, types}
38+
# Define the desired names and eltypes of the "fields"
39+
return NamedTuple{(:data, names...), Base.tuple_type_cons(T, types)}
40+
end;
41+
42+
function StructArrays.component(m::MyType, key::Symbol)
43+
# Define a component-extractor
44+
return key === :data ? getfield(m, 1) : getfield(getfield(m, 2), key)
45+
end;
46+
47+
function StructArrays.createinstance(::Type{MyType{T, NT}}, x, args...) where {T, NT}
48+
# Generate an instance of MyType from components
49+
return MyType(x, NT(args))
50+
end;
7051
```
7152

7253
and now:
7354

74-
```jldoctest advanced1
75-
julia> sa = StructArray(s);
76-
77-
julia> sa.a
78-
5-element Vector{Int64}:
79-
5
80-
4
81-
3
82-
2
83-
1
84-
85-
julia> sa.b
86-
5-element Vector{Int64}:
87-
2
88-
2
89-
2
90-
2
91-
2
55+
```@repl advanced1
56+
sa = StructArray(s);
57+
sa.a
58+
sa.b
9259
```
9360

9461
The above strategy has been tested and implemented in [GeometryBasics.jl](https://github.com/JuliaGeometry/GeometryBasics.jl).

docs/src/index.md

Lines changed: 18 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -9,73 +9,39 @@ The package was largely inspired by the `Columns` type in [IndexedTables](https:
99
## Collection and initialization
1010

1111
One can create a `StructArray` by providing the struct type and a tuple or NamedTuple of field arrays:
12-
```jldoctest intro
13-
julia> using StructArrays
14-
15-
julia> struct Foo{T}
16-
a::T
17-
b::T
18-
end
19-
20-
julia> adata = [1 2; 3 4]; bdata = [10 20; 30 40];
21-
22-
julia> x = StructArray{Foo}((adata, bdata))
23-
2×2 StructArray(::Matrix{Int64}, ::Matrix{Int64}) with eltype Foo:
24-
Foo{Int64}(1, 10) Foo{Int64}(2, 20)
25-
Foo{Int64}(3, 30) Foo{Int64}(4, 40)
12+
```@repl intro
13+
using StructArrays
14+
struct Foo{T}
15+
a::T
16+
b::T
17+
end
18+
adata = [1 2; 3 4]; bdata = [10 20; 30 40];
19+
x = StructArray{Foo}((adata, bdata))
2620
```
2721

2822
You can also initialze a StructArray by passing in a NamedTuple, in which case the name (rather than the order) specifies how the input arrays are assigned to fields:
2923

30-
```jldoctest intro
31-
julia> x = StructArray{Foo}((b = adata, a = bdata)) # initialize a with bdata and vice versa
32-
2×2 StructArray(::Matrix{Int64}, ::Matrix{Int64}) with eltype Foo:
33-
Foo{Int64}(10, 1) Foo{Int64}(20, 2)
34-
Foo{Int64}(30, 3) Foo{Int64}(40, 4)
24+
```@repl intro
25+
x = StructArray{Foo}((b = adata, a = bdata)) # initialize a with bdata and vice versa
3526
```
3627

3728
If a struct is not specified, a StructArray with Tuple or NamedTuple elements will be created:
38-
```jldoctest intro
39-
julia> x = StructArray((adata, bdata))
40-
2×2 StructArray(::Matrix{Int64}, ::Matrix{Int64}) with eltype Tuple{Int64, Int64}:
41-
(1, 10) (2, 20)
42-
(3, 30) (4, 40)
43-
44-
julia> x = StructArray((a = adata, b = bdata))
45-
2×2 StructArray(::Matrix{Int64}, ::Matrix{Int64}) with eltype NamedTuple{(:a, :b), Tuple{Int64, Int64}}:
46-
(a = 1, b = 10) (a = 2, b = 20)
47-
(a = 3, b = 30) (a = 4, b = 40)
29+
```@repl intro
30+
x = StructArray((adata, bdata))
31+
x = StructArray((a = adata, b = bdata))
4832
```
4933

5034
It's also possible to create a `StructArray` by choosing a particular dimension to interpret as the components of a struct:
5135

52-
```jldoctest intro
53-
julia> x = StructArray{Complex{Int}}(adata; dims=1) # along dimension 1, the first item `re` and the second is `im`
54-
2-element StructArray(view(::Matrix{Int64}, 1, :), view(::Matrix{Int64}, 2, :)) with eltype Complex{Int64}:
55-
1 + 3im
56-
2 + 4im
57-
58-
julia> x = StructArray{Complex{Int}}(adata; dims=2) # along dimension 2, the first item `re` and the second is `im`
59-
2-element StructArray(view(::Matrix{Int64}, :, 1), view(::Matrix{Int64}, :, 2)) with eltype Complex{Int64}:
60-
1 + 2im
61-
3 + 4im
36+
```@repl intro
37+
x = StructArray{Complex{Int}}(adata; dims=1) # along dimension 1, the first item `re` and the second is `im`
38+
x = StructArray{Complex{Int}}(adata; dims=2) # along dimension 2, the first item `re` and the second is `im`
6239
```
6340

6441
One can also create a `StructArray` from an iterable of structs without creating an intermediate `Array`:
6542

66-
```jldoctest intro
67-
julia> StructArray(log(j+2.0*im) for j in 1:10)
68-
10-element StructArray(::Vector{Float64}, ::Vector{Float64}) with eltype ComplexF64:
69-
0.8047189562170501 + 1.1071487177940904im
70-
1.0397207708399179 + 0.7853981633974483im
71-
1.2824746787307684 + 0.5880026035475675im
72-
1.4978661367769954 + 0.4636476090008061im
73-
1.683647914993237 + 0.3805063771123649im
74-
1.8444397270569681 + 0.3217505543966422im
75-
1.985145956776061 + 0.27829965900511133im
76-
2.1097538525880535 + 0.24497866312686414im
77-
2.2213256282451583 + 0.21866894587394195im
78-
2.3221954495706862 + 0.19739555984988078im
43+
```@repl intro
44+
StructArray(log(j+2.0*im) for j in 1:10)
7945
```
8046

8147
Another option is to create an uninitialized `StructArray` and then fill it with data. Just like in normal arrays, this is done with the `undef` syntax:

ext/StructArraysAdaptExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
module StructArraysAdaptExt
2+
# Use Adapt allows for automatic conversion of CPU to GPU StructArrays
3+
using Adapt, StructArrays
4+
Adapt.adapt_structure(to, s::StructArray) = replace_storage(adapt(to), s)
5+
end

ext/StructArraysGPUArraysCoreExt.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
module StructArraysGPUArraysCoreExt
2+
3+
using StructArrays
4+
using StructArrays: map_params, array_types
5+
6+
using Base: tail
7+
8+
import GPUArraysCore
9+
10+
# for GPU broadcast
11+
import GPUArraysCore
12+
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
13+
backends = map_params(GPUArraysCore.backend, array_types(T))
14+
backend, others = backends[1], tail(backends)
15+
isconsistent = mapfoldl(isequal(backend), &, others; init=true)
16+
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
17+
return backend
18+
end
19+
StructArrays.always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true
20+
21+
end # module

ext/StructArraysStaticArraysExt.jl

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
module StructArraysStaticArraysExt
2+
3+
using StructArrays
4+
using StaticArrays: StaticArray, FieldArray, tuple_prod
5+
6+
"""
7+
StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
8+
9+
The `staticschema` of a `StaticArray` element type is the `staticschema` of the underlying `Tuple`.
10+
```julia
11+
julia> StructArrays.staticschema(SVector{2, Float64})
12+
Tuple{Float64, Float64}
13+
```
14+
The one exception to this rule is `<:StaticArrays.FieldArray`, since `FieldArray` is based on a
15+
struct. In this case, `staticschema(<:FieldArray)` returns the `staticschema` for the struct
16+
which subtypes `FieldArray`.
17+
"""
18+
@generated function StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
19+
return quote
20+
Base.@_inline_meta
21+
return NTuple{$(tuple_prod(S)), T}
22+
end
23+
end
24+
StructArrays.createinstance(::Type{T}, args...) where {T<:StaticArray} = T(args)
25+
StructArrays.component(s::StaticArray, i) = getindex(s, i)
26+
27+
# invoke general fallbacks for a `FieldArray` type.
28+
@inline function StructArrays.staticschema(T::Type{<:FieldArray})
29+
invoke(StructArrays.staticschema, Tuple{Type{<:Any}}, T)
30+
end
31+
StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i)
32+
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(StructArrays.createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)
33+
34+
# Broadcast overload
35+
using StaticArrays: StaticArrayStyle, similar_type, Size, SOneTo
36+
using StaticArrays: broadcast_flatten, broadcast_sizes, first_statictype
37+
using StructArrays: isnonemptystructtype
38+
using Base.Broadcast: Broadcasted, _broadcast_getindex
39+
40+
# StaticArrayStyle has no similar defined.
41+
# Overload `try_struct_copy` instead.
42+
@inline function StructArrays.try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M}
43+
flat = broadcast_flatten(bc); as = flat.args; f = flat.f
44+
argsizes = broadcast_sizes(as...)
45+
ax = axes(bc)
46+
ax isa Tuple{Vararg{SOneTo}} || error("Dimension is not static. Please file a bug at `StaticArrays.jl`.")
47+
return _broadcast(f, Size(map(length, ax)), argsizes, as...)
48+
end
49+
50+
# A functor generates the ith component of StructStaticBroadcast.
51+
struct Similar_ith{SA, E<:Tuple}
52+
elements::E
53+
Similar_ith{SA}(elements::Tuple) where {SA} = new{SA, typeof(elements)}(elements)
54+
end
55+
function (s::Similar_ith{SA})(i::Int) where {SA}
56+
ith_elements = ntuple(Val(length(s.elements))) do j
57+
getfield(s.elements[j], i)
58+
end
59+
ith_SA = similar_type(SA, fieldtype(eltype(SA), i))
60+
return @inbounds ith_SA(ith_elements)
61+
end
62+
63+
@inline function _broadcast(f, sz::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where {newsize}
64+
first_staticarray = first_statictype(a...)
65+
elements, ET = if prod(newsize) == 0
66+
# Use inference to get eltype in empty case (following StaticBroadcast defined in StaticArrays.jl)
67+
eltys = Tuple{map(eltype, a)...}
68+
(), Core.Compiler.return_type(f, eltys)
69+
else
70+
temp = __broadcast(f, sz, s, a...)
71+
temp, eltype(temp)
72+
end
73+
if isnonemptystructtype(ET)
74+
SA = similar_type(first_staticarray, ET, sz)
75+
arrs = ntuple(Similar_ith{SA}(elements), Val(fieldcount(ET)))
76+
return StructArray{ET}(arrs)
77+
else
78+
@inbounds return similar_type(first_staticarray, ET, sz)(elements)
79+
end
80+
end
81+
82+
# The `__broadcast` kernal is copied from `StaticArrays.jl`.
83+
# see https://github.com/JuliaArrays/StaticArrays.jl/blob/master/src/broadcast.jl
84+
@generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
85+
sizes = [sz.parameters[1] for sz s.parameters]
86+
87+
indices = CartesianIndices(newsize)
88+
exprs = similar(indices, Expr)
89+
for (j, current_ind) enumerate(indices)
90+
exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes))
91+
exprs[j] = :(f($(exprs_vals...)))
92+
end
93+
94+
return quote
95+
Base.@_inline_meta
96+
return tuple($(exprs...))
97+
end
98+
end
99+
100+
broadcast_getindex(::Tuple{}, i::Int, I::CartesianIndex) = return :(_broadcast_getindex(a[$i], $I))
101+
function broadcast_getindex(oldsize::Tuple, i::Int, newindex::CartesianIndex)
102+
li = LinearIndices(oldsize)
103+
ind = _broadcast_getindex(li, newindex)
104+
return :(a[$i][$ind])
105+
end
106+
107+
end

0 commit comments

Comments
 (0)