1
- import StaticArrays: StaticArray, FieldArray, tuple_prod
1
+ using StaticArrays: StaticArrays, StaticArray, FieldArray, tuple_prod, StaticArrayStyle
2
2
3
3
"""
4
4
StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
@@ -26,4 +26,62 @@ StructArrays.component(s::StaticArray, i) = getindex(s, i)
26
26
invoke (StructArrays. staticschema, Tuple{Type{<: Any }}, T)
27
27
end
28
28
StructArrays. component (s:: FieldArray , i) = invoke (StructArrays. component, Tuple{Any, Any}, s, i)
29
- StructArrays. createinstance (T:: Type{<:FieldArray} , args... ) = invoke (createinstance, Tuple{Type{<: Any }, Vararg}, T, args... )
29
+ StructArrays. createinstance (T:: Type{<:FieldArray} , args... ) = invoke (createinstance, Tuple{Type{<: Any }, Vararg}, T, args... )
30
+
31
+ # Broadcast overload
32
+ import StaticArrays: Size, isstatic, similar_type
33
+ using StaticArrays: first_statictype, broadcast_sizes, SOneTo
34
+ import Base. Broadcast: instantiate
35
+ StructStaticArrayStyle{N} = StructArrayStyle{StaticArrayStyle{N}, N}
36
+ function instantiate (bc:: Broadcasted{StructStaticArrayStyle{M}} ) where {M}
37
+ bc′ = instantiate (convert (Broadcasted{StaticArrayStyle{M}}, bc))
38
+ return convert (Broadcasted{StructStaticArrayStyle{M}}, bc′)
39
+ end
40
+ function Broadcast. _axes (bc:: Broadcasted{<:StructStaticArrayStyle} , :: Nothing )
41
+ return StaticArrays. static_combine_axes (bc. args... )
42
+ end
43
+
44
+ # StaticArrayStyle has no similar defined.
45
+ # Overload `Base.copy` instead.
46
+ @inline function Base. copy (B:: Broadcasted{<:StructStaticArrayStyle} )
47
+ flat = Broadcast. flatten (B); as = flat. args; f = flat. f
48
+ argsizes = broadcast_sizes (as... )
49
+ ax = axes (B)
50
+ ax isa Tuple{Vararg{SOneTo}} || error (" Dimension is not static. Please file a bug." )
51
+ return _broadcast (f, Size (map (length, ax)), argsizes, as... )
52
+ end
53
+ @inline function _broadcast (f, sz:: Size{newsize} , s:: Tuple{Vararg{Size}} , a... ) where newsize
54
+ AT = first_statictype (a... )
55
+ if prod (newsize) == 0
56
+ # Use inference to get eltype in empty case (see also comments in _map)
57
+ eltys = Tuple{map (eltype, a)... }
58
+ T = Core. Compiler. return_type (f, eltys)
59
+ return _struct_static_similar (T, AT, sz, ())
60
+ end
61
+ elements = StaticArrays. __broadcast (f, sz, s, a... )
62
+ return _struct_static_similar (eltype (elements), AT, sz, elements)
63
+ end
64
+ function _struct_static_similar (:: Type{ET} , :: Type{AT} , sz, elements:: Tuple ) where {ET, AT}
65
+ if isnonemptystructtype (ET)
66
+ arrs = ntuple (Val (fieldcount (ET))) do i
67
+ similar_type (AT, fieldtype (ET, i), sz)(_getfields (elements, i))
68
+ end
69
+ return StructArray {ET} (arrs)
70
+ else
71
+ return similar_type (AT, ET, sz)(elements)
72
+ end
73
+ end
74
+
75
+ @inline function _getfields (x:: Tuple , i:: Int )
76
+ if @generated
77
+ return Expr (:tuple , (:(getfield (x[$ j], i)) for j in 1 : fieldcount (x)). .. )
78
+ else
79
+ return map (Base. Fix2 (getfield, i), x)
80
+ end
81
+ end
82
+
83
+ Size (:: Type{SA} ) where {SA<: StructArray } = Size (fieldtype (array_types (SA), 1 ))
84
+ isstatic (x:: StructArray ) = isstatic (component (x, 1 ))
85
+ function similar_type (:: Type{SA} , :: Type{T} , s:: Size{S} ) where {SA<: StructArray , T, S}
86
+ return similar_type (fieldtype (array_types (SA), 1 ), T, s)
87
+ end
0 commit comments