Skip to content

Commit 470a154

Browse files
committed
Convert to StaticArrayStyle
1 parent 1b3c1bd commit 470a154

File tree

2 files changed

+73
-3
lines changed

2 files changed

+73
-3
lines changed

src/staticarrays_support.jl

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import StaticArrays: StaticArray, FieldArray, tuple_prod
1+
using StaticArrays: StaticArrays, StaticArray, FieldArray, tuple_prod, StaticArrayStyle
22

33
"""
44
StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
@@ -26,4 +26,62 @@ StructArrays.component(s::StaticArray, i) = getindex(s, i)
2626
invoke(StructArrays.staticschema, Tuple{Type{<:Any}}, T)
2727
end
2828
StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i)
29-
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)
29+
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)
30+
31+
# Broadcast overload
32+
import StaticArrays: Size, isstatic, similar_type
33+
using StaticArrays: first_statictype, broadcast_sizes, SOneTo
34+
import Base.Broadcast: instantiate
35+
StructStaticArrayStyle{N} = StructArrayStyle{StaticArrayStyle{N}, N}
36+
function instantiate(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M}
37+
bc′ = instantiate(convert(Broadcasted{StaticArrayStyle{M}}, bc))
38+
return convert(Broadcasted{StructStaticArrayStyle{M}}, bc′)
39+
end
40+
function Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing)
41+
return StaticArrays.static_combine_axes(bc.args...)
42+
end
43+
44+
# StaticArrayStyle has no similar defined.
45+
# Overload `Base.copy` instead.
46+
@inline function Base.copy(B::Broadcasted{<:StructStaticArrayStyle})
47+
flat = Broadcast.flatten(B); as = flat.args; f = flat.f
48+
argsizes = broadcast_sizes(as...)
49+
ax = axes(B)
50+
ax isa Tuple{Vararg{SOneTo}} || error("Dimension is not static. Please file a bug.")
51+
return _broadcast(f, Size(map(length, ax)), argsizes, as...)
52+
end
53+
@inline function _broadcast(f, sz::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
54+
AT = first_statictype(a...)
55+
if prod(newsize) == 0
56+
# Use inference to get eltype in empty case (see also comments in _map)
57+
eltys = Tuple{map(eltype, a)...}
58+
T = Core.Compiler.return_type(f, eltys)
59+
return _struct_static_similar(T, AT, sz, ())
60+
end
61+
elements = StaticArrays.__broadcast(f, sz, s, a...)
62+
return _struct_static_similar(eltype(elements), AT, sz, elements)
63+
end
64+
function _struct_static_similar(::Type{ET}, ::Type{AT}, sz, elements::Tuple) where {ET, AT}
65+
if isnonemptystructtype(ET)
66+
arrs = ntuple(Val(fieldcount(ET))) do i
67+
similar_type(AT, fieldtype(ET, i), sz)(_getfields(elements, i))
68+
end
69+
return StructArray{ET}(arrs)
70+
else
71+
return similar_type(AT, ET, sz)(elements)
72+
end
73+
end
74+
75+
@inline function _getfields(x::Tuple, i::Int)
76+
if @generated
77+
return Expr(:tuple, (:(getfield(x[$j], i)) for j in 1:fieldcount(x))...)
78+
else
79+
return map(Base.Fix2(getfield, i), x)
80+
end
81+
end
82+
83+
Size(::Type{SA}) where {SA<:StructArray} = Size(fieldtype(array_types(SA), 1))
84+
isstatic(x::StructArray) = isstatic(component(x, 1))
85+
function similar_type(::Type{SA}, ::Type{T}, s::Size{S}) where {SA<:StructArray, T, S}
86+
return similar_type(fieldtype(array_types(SA), 1), T, s)
87+
end

test/runtests.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1160,12 +1160,24 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
11601160
end
11611161
testset = Any[StructArray([1;2+im]),
11621162
1:2,
1163-
(1,2),
1163+
(1,2),
1164+
StructArray(@SArray [1 1+2im]),
1165+
(@SArray [1 2])
11641166
]
11651167
for aa in testset, bb in testset, cc in testset
11661168
_test(aa, bb, cc)
11671169
end
11681170
end
1171+
1172+
@testset "StructStaticArray" begin
1173+
bclog(s) = log.(s)
1174+
test_allocated(f, s) = @test (@allocated f(s)) == 0
1175+
a = @SMatrix [float(i) for i in 1:10, j in 1:10]
1176+
b = @SMatrix [0. for i in 1:10, j in 1:10]
1177+
s = StructArray{ComplexF64}((a , b))
1178+
@test (@inferred bclog(s)) isa typeof(s)
1179+
test_allocated(bclog, s)
1180+
end
11691181
end
11701182

11711183
@testset "map" begin

0 commit comments

Comments
 (0)