@@ -61,8 +61,8 @@ Broadcast.BroadcastStyle(T::StructuredMatrixStyle{Matrix}, ::StructuredMatrixSty
61
61
Broadcast. BroadcastStyle (:: StructuredMatrixStyle , :: StructuredMatrixStyle ) = DefaultArrayStyle {2} ()
62
62
63
63
# And a definition akin to similar using the structured type:
64
- structured_broadcast_alloc (bc, :: Type{Diagonal} , :: Type{ElType} , n ) where {ElType} =
65
- Diagonal (Array {ElType} (undef, n ))
64
+ structured_broadcast_alloc (bc, :: Type{Diagonal} , :: Type{ElType} , sz :: NTuple{2,Integer} ) where {ElType} =
65
+ Diagonal (Array {ElType} (undef, sz[ 1 ] ))
66
66
# Bidiagonal is tricky as we need to know if it's upper or lower. The promotion
67
67
# system will return Tridiagonal when there's more than one Bidiagonal, but when
68
68
# there's only one, we need to make figure out upper or lower
@@ -75,28 +75,33 @@ find_uplo(a::Bidiagonal) = a.uplo
75
75
find_uplo (a) = nothing
76
76
find_uplo (bc:: Broadcasted ) = mapfoldl (find_uplo, merge_uplos, Broadcast. cat_nested (bc), init= nothing )
77
77
78
- function structured_broadcast_alloc (bc, :: Type{Bidiagonal} , :: Type{ElType} , n) where {ElType}
78
+ function structured_broadcast_alloc (bc, :: Type{Bidiagonal} ,
79
+ :: Type{ElType} , sz:: NTuple{2,Integer} ) where {ElType}
80
+ n = sz[1 ]
79
81
uplo = n > 0 ? find_uplo (bc) : ' U'
80
82
n1 = max (n - 1 , 0 )
81
83
if count_structedmatrix (Bidiagonal, bc) > 1 && uplo == ' T'
82
84
return Tridiagonal (Array {ElType} (undef, n1), Array {ElType} (undef, n), Array {ElType} (undef, n1))
83
85
end
84
86
return Bidiagonal (Array {ElType} (undef, n),Array {ElType} (undef, n1), uplo)
85
87
end
86
- structured_broadcast_alloc (bc, :: Type{SymTridiagonal} , :: Type{ElType} , n) where {ElType} =
88
+ function structured_broadcast_alloc (bc, :: Type{SymTridiagonal} ,
89
+ :: Type{ElType} , sz:: NTuple{2,Integer} ) where {ElType}
90
+ n = sz[1 ]
87
91
SymTridiagonal (Array {ElType} (undef, n),Array {ElType} (undef, max (0 ,n- 1 )))
88
- structured_broadcast_alloc (bc, :: Type{Tridiagonal} , :: Type{ElType} , n) where {ElType} =
89
- Tridiagonal (Array {ElType} (undef, max (0 ,n- 1 )),Array {ElType} (undef, n),Array {ElType} (undef, max (0 ,n- 1 )))
90
- structured_broadcast_alloc (bc, :: Type{LowerTriangular} , :: Type{ElType} , n) where {ElType} =
91
- LowerTriangular (Array {ElType} (undef, n, n))
92
- structured_broadcast_alloc (bc, :: Type{UpperTriangular} , :: Type{ElType} , n) where {ElType} =
93
- UpperTriangular (Array {ElType} (undef, n, n))
94
- structured_broadcast_alloc (bc, :: Type{UnitLowerTriangular} , :: Type{ElType} , n) where {ElType} =
95
- UnitLowerTriangular (Array {ElType} (undef, n, n))
96
- structured_broadcast_alloc (bc, :: Type{UnitUpperTriangular} , :: Type{ElType} , n) where {ElType} =
97
- UnitUpperTriangular (Array {ElType} (undef, n, n))
98
- structured_broadcast_alloc (bc, :: Type{Matrix} , :: Type{ElType} , n) where {ElType} =
99
- Array {ElType} (undef, n, n)
92
+ end
93
+ function structured_broadcast_alloc (bc, :: Type{Tridiagonal} ,
94
+ :: Type{ElType} , sz:: NTuple{2,Integer} ) where {ElType}
95
+ n = sz[1 ]
96
+ n1 = max (0 ,n- 1 )
97
+ Tridiagonal (Array {ElType} (undef, n1),Array {ElType} (undef, n),Array {ElType} (undef, n1))
98
+ end
99
+ function structured_broadcast_alloc (bc, :: Type{T} , :: Type{ElType} ,
100
+ sz:: NTuple{2,Integer} ) where {ElType,T<: UpperOrLowerTriangular }
101
+ T (Array {ElType} (undef, sz))
102
+ end
103
+ structured_broadcast_alloc (bc, :: Type{Matrix} , :: Type{ElType} , sz:: NTuple{2,Integer} ) where {ElType} =
104
+ Array {ElType} (undef, sz)
100
105
101
106
# A _very_ limited list of structure-preserving functions known at compile-time. This list is
102
107
# derived from the formerly-implemented `broadcast` methods in 0.6. Note that this must
@@ -172,7 +177,7 @@ function Base.similar(bc::Broadcasted{StructuredMatrixStyle{T}}, ::Type{ElType})
172
177
inds = axes (bc)
173
178
fzerobc = fzeropreserving (bc)
174
179
if isstructurepreserving (bc) || (fzerobc && ! (T <: Union{UnitLowerTriangular,UnitUpperTriangular} ))
175
- return structured_broadcast_alloc (bc, T, ElType, length ( inds[ 1 ] ))
180
+ return structured_broadcast_alloc (bc, T, ElType, map (length, inds))
176
181
elseif fzerobc && T <: UnitLowerTriangular
177
182
return similar (convert (Broadcasted{StructuredMatrixStyle{LowerTriangular}}, bc), ElType)
178
183
elseif fzerobc && T <: UnitUpperTriangular
0 commit comments