Skip to content

Commit 088636d

Browse files
authored
broadcasting for Diagonal StaticArrays (#914)
* broadcasting for Diagonal StaticArrays * Add tests for SizedArrays * add tests for inplace broadcasting * version bump to v1.2.3
1 parent c29a9e2 commit 088636d

File tree

3 files changed

+33
-2
lines changed

3 files changed

+33
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StaticArrays"
22
uuid = "90137ffa-7385-5640-81b9-e52037218182"
3-
version = "1.2.2"
3+
version = "1.2.3"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/broadcast.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ StaticArrayStyle{M}(::Val{N}) where {M,N} = StaticArrayStyle{N}()
1313
BroadcastStyle(::Type{<:StaticArray{<:Tuple, <:Any, N}}) where {N} = StaticArrayStyle{N}()
1414
BroadcastStyle(::Type{<:Transpose{<:Any, <:StaticArray{<:Tuple, <:Any, N}}}) where {N} = StaticArrayStyle{N}()
1515
BroadcastStyle(::Type{<:Adjoint{<:Any, <:StaticArray{<:Tuple, <:Any, N}}}) where {N} = StaticArrayStyle{N}()
16+
BroadcastStyle(::Type{<:Diagonal{<:Any, <:StaticArray{<:Tuple, <:Any, 1}}}) = StaticArrayStyle{2}()
1617
# Precedence rules
1718
BroadcastStyle(::StaticArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
1819
DefaultArrayStyle(Val(max(M, N)))
@@ -97,7 +98,7 @@ scalar_getindex(x) = x
9798
scalar_getindex(x::Ref) = x[]
9899

99100
@generated function _broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
100-
first_staticarray = a[findfirst(ai -> ai <: Union{StaticArray, Transpose{<:Any, <:StaticArray}, Adjoint{<:Any, <:StaticArray}}, a)]
101+
first_staticarray = a[findfirst(ai -> ai <: Union{StaticArray, Transpose{<:Any, <:StaticArray}, Adjoint{<:Any, <:StaticArray}, Diagonal{<:Any, <:StaticArray}}, a)]
101102

102103
if prod(newsize) == 0
103104
# Use inference to get eltype in empty case (see also comments in _map)

test/broadcast.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,4 +240,34 @@ end
240240
# Unfortunately this case of nested broadcasting is not inferred
241241
@test_broken @inferred(SA[1,2,3] .* (SA[1,0],))
242242
end
243+
244+
@testset "SDiagonal" begin
245+
for DS in Any[Diagonal(SVector{2}(1:2)), Diagonal(MVector{2}(1:2)), Diagonal(SizedArray{Tuple{2}}(1:2))],
246+
S in Any[SVector{2}(1:2), MVector{2}(1:2), SizedArray{Tuple{2}}(1:2)]
247+
@test DS .* S isa StaticArray
248+
@test DS .* S == collect(DS) .* collect(S)
249+
@test DS .* collect(S) == collect(DS) .* collect(S)
250+
@test DS .* S' isa StaticArray
251+
@test DS .* S' == collect(DS) .* collect(S)
252+
@test DS .* collect(S') == collect(DS) .* collect(S)
253+
@test S .* DS .* S' isa StaticArray
254+
@test S .* DS .* S' == collect(S) .* collect(DS) .* collect(S)
255+
DS2 = Diagonal(S)
256+
@test DS .* DS2 isa StaticArray
257+
@test DS .* DS2 == collect(DS) .* collect(DS2)
258+
@test DS .* collect(DS2) == collect(DS) .* collect(DS2)
259+
260+
# inplace broadcasting for mutable diagonal types
261+
DS2 = Diagonal(MVector{2}(diag(DS)))
262+
DS2 .*= S
263+
@test DS2 == DS .* S
264+
DS2 .= DS
265+
@test DS2 == DS
266+
DS2 .+= DS
267+
@test DS2 == DS .+ DS
268+
DS2 .= DS
269+
DS2 .*= DS
270+
@test DS2 == DS .* DS
271+
end
272+
end
243273
end

0 commit comments

Comments
 (0)