|
1 |
| -import StaticArraysCore: StaticArray, FieldArray, tuple_prod |
| 1 | +using StaticArraysCore: StaticArray, FieldArray, tuple_prod |
2 | 2 |
|
3 | 3 | """
|
4 | 4 | StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
|
@@ -27,3 +27,42 @@ StructArrays.component(s::StaticArray, i) = getindex(s, i)
|
27 | 27 | end
|
28 | 28 | StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i)
|
29 | 29 | StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)
|
| 30 | + |
| 31 | +# Broadcast overload |
| 32 | +using StaticArraysCore: StaticArrayStyle |
| 33 | +import StaticArraysCore: Size, is_staticarray_like, similar_type |
| 34 | +StructStaticArrayStyle{N} = StructArrayStyle{StaticArrayStyle{N}, N} |
| 35 | +function Broadcast.instantiate(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M} |
| 36 | + bc′ = Broadcast.instantiate(convert(Broadcasted{StaticArrayStyle{M}}, bc)) |
| 37 | + return convert(Broadcasted{StructStaticArrayStyle{M}}, bc′) |
| 38 | +end |
| 39 | +function Broadcast._axes(bc::Broadcasted{StructStaticArrayStyle{M}}, ::Nothing) where {M} |
| 40 | + return Broadcast._axes(convert(Broadcasted{StaticArrayStyle{M}}, bc), nothing) |
| 41 | +end |
| 42 | + |
| 43 | +# StaticArrayStyle has no similar defined. |
| 44 | +# Overload `Base.copy` instead. |
| 45 | +@inline function Base.copy(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M} |
| 46 | + sa = copy(convert(Broadcasted{StaticArrayStyle{M}}, bc)) |
| 47 | + ET = eltype(sa) |
| 48 | + isnonemptystructtype(ET) || return sa |
| 49 | + elements = Tuple(sa) |
| 50 | + arrs = ntuple(Val(fieldcount(ET))) do i |
| 51 | + similar_type(sa, fieldtype(ET, i), Size(sa))(_getfields(elements, i)) |
| 52 | + end |
| 53 | + return StructArray{ET}(arrs) |
| 54 | +end |
| 55 | + |
| 56 | +@inline function _getfields(x::Tuple, i::Int) |
| 57 | + if @generated |
| 58 | + return Expr(:tuple, (:(getfield(x[$j], i)) for j in 1:fieldcount(x))...) |
| 59 | + else |
| 60 | + return map(Base.Fix2(getfield, i), x) |
| 61 | + end |
| 62 | +end |
| 63 | + |
| 64 | +Size(::Type{SA}) where {SA<:StructArray} = Size(fieldtype(array_types(SA), 1)) |
| 65 | +is_staticarray_like(x::StructArray) = any(is_staticarray_like, components(x)) |
| 66 | +function similar_type(::Type{SA}, ::Type{T}, s::Size{S}) where {SA<:StructArray, T, S} |
| 67 | + return similar_type(fieldtype(array_types(SA), 1), T, s) |
| 68 | +end |
0 commit comments