Skip to content

Commit 9b4259b

Browse files
committed
Use matrix size in structured broadcast
1 parent 9d139f5 commit 9b4259b

File tree

1 file changed

+22
-17
lines changed

1 file changed

+22
-17
lines changed

src/structuredbroadcast.jl

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ Broadcast.BroadcastStyle(T::StructuredMatrixStyle{Matrix}, ::StructuredMatrixSty
6161
Broadcast.BroadcastStyle(::StructuredMatrixStyle, ::StructuredMatrixStyle) = DefaultArrayStyle{2}()
6262

6363
# And a definition akin to similar using the structured type:
64-
structured_broadcast_alloc(bc, ::Type{Diagonal}, ::Type{ElType}, n) where {ElType} =
65-
Diagonal(Array{ElType}(undef, n))
64+
structured_broadcast_alloc(bc, ::Type{Diagonal}, ::Type{ElType}, sz::NTuple{2,Integer}) where {ElType} =
65+
Diagonal(Array{ElType}(undef, sz[1]))
6666
# Bidiagonal is tricky as we need to know if it's upper or lower. The promotion
6767
# system will return Tridiagonal when there's more than one Bidiagonal, but when
6868
# there's only one, we need to make figure out upper or lower
@@ -75,28 +75,33 @@ find_uplo(a::Bidiagonal) = a.uplo
7575
find_uplo(a) = nothing
7676
find_uplo(bc::Broadcasted) = mapfoldl(find_uplo, merge_uplos, Broadcast.cat_nested(bc), init=nothing)
7777

78-
function structured_broadcast_alloc(bc, ::Type{Bidiagonal}, ::Type{ElType}, n) where {ElType}
78+
function structured_broadcast_alloc(bc, ::Type{Bidiagonal},
79+
::Type{ElType}, sz::NTuple{2,Integer}) where {ElType}
80+
n = sz[1]
7981
uplo = n > 0 ? find_uplo(bc) : 'U'
8082
n1 = max(n - 1, 0)
8183
if count_structedmatrix(Bidiagonal, bc) > 1 && uplo == 'T'
8284
return Tridiagonal(Array{ElType}(undef, n1), Array{ElType}(undef, n), Array{ElType}(undef, n1))
8385
end
8486
return Bidiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n1), uplo)
8587
end
86-
structured_broadcast_alloc(bc, ::Type{SymTridiagonal}, ::Type{ElType}, n) where {ElType} =
88+
function structured_broadcast_alloc(bc, ::Type{SymTridiagonal},
89+
::Type{ElType}, sz::NTuple{2,Integer}) where {ElType}
90+
n = sz[1]
8791
SymTridiagonal(Array{ElType}(undef, n),Array{ElType}(undef, max(0,n-1)))
88-
structured_broadcast_alloc(bc, ::Type{Tridiagonal}, ::Type{ElType}, n) where {ElType} =
89-
Tridiagonal(Array{ElType}(undef, max(0,n-1)),Array{ElType}(undef, n),Array{ElType}(undef, max(0,n-1)))
90-
structured_broadcast_alloc(bc, ::Type{LowerTriangular}, ::Type{ElType}, n) where {ElType} =
91-
LowerTriangular(Array{ElType}(undef, n, n))
92-
structured_broadcast_alloc(bc, ::Type{UpperTriangular}, ::Type{ElType}, n) where {ElType} =
93-
UpperTriangular(Array{ElType}(undef, n, n))
94-
structured_broadcast_alloc(bc, ::Type{UnitLowerTriangular}, ::Type{ElType}, n) where {ElType} =
95-
UnitLowerTriangular(Array{ElType}(undef, n, n))
96-
structured_broadcast_alloc(bc, ::Type{UnitUpperTriangular}, ::Type{ElType}, n) where {ElType} =
97-
UnitUpperTriangular(Array{ElType}(undef, n, n))
98-
structured_broadcast_alloc(bc, ::Type{Matrix}, ::Type{ElType}, n) where {ElType} =
99-
Array{ElType}(undef, n, n)
92+
end
93+
function structured_broadcast_alloc(bc, ::Type{Tridiagonal},
94+
::Type{ElType}, sz::NTuple{2,Integer}) where {ElType}
95+
n = sz[1]
96+
n1 = max(0,n-1)
97+
Tridiagonal(Array{ElType}(undef, n1),Array{ElType}(undef, n),Array{ElType}(undef, n1))
98+
end
99+
function structured_broadcast_alloc(bc, ::Type{T}, ::Type{ElType},
100+
sz::NTuple{2,Integer}) where {ElType,T<:UpperOrLowerTriangular}
101+
T(Array{ElType}(undef, sz))
102+
end
103+
structured_broadcast_alloc(bc, ::Type{Matrix}, ::Type{ElType}, sz::NTuple{2,Integer}) where {ElType} =
104+
Array{ElType}(undef, sz)
100105

101106
# A _very_ limited list of structure-preserving functions known at compile-time. This list is
102107
# derived from the formerly-implemented `broadcast` methods in 0.6. Note that this must
@@ -172,7 +177,7 @@ function Base.similar(bc::Broadcasted{StructuredMatrixStyle{T}}, ::Type{ElType})
172177
inds = axes(bc)
173178
fzerobc = fzeropreserving(bc)
174179
if isstructurepreserving(bc) || (fzerobc && !(T <: Union{UnitLowerTriangular,UnitUpperTriangular}))
175-
return structured_broadcast_alloc(bc, T, ElType, length(inds[1]))
180+
return structured_broadcast_alloc(bc, T, ElType, map(length, inds))
176181
elseif fzerobc && T <: UnitLowerTriangular
177182
return similar(convert(Broadcasted{StructuredMatrixStyle{LowerTriangular}}, bc), ElType)
178183
elseif fzerobc && T <: UnitUpperTriangular

0 commit comments

Comments
 (0)