@@ -60,13 +60,21 @@ structured_broadcast_alloc(bc, ::Type{<:Diagonal}, ::Type{ElType}, n) where {ElT
60
60
# Bidiagonal is tricky as we need to know if it's upper or lower. The promotion
61
61
# system will return Tridiagonal when there's more than one Bidiagonal, but when
62
62
# there's only one, we need to make figure out upper or lower
63
- find_bidiagonal () = throw (ArgumentError (" could not find Bidiagonal within broadcast expression" ))
64
- find_bidiagonal (a:: Bidiagonal , rest... ) = a
65
- find_bidiagonal (bc:: Broadcast.Broadcasted , rest... ) = find_bidiagonal (find_bidiagonal (bc. args... ), rest... )
66
- find_bidiagonal (x, rest... ) = find_bidiagonal (rest... )
63
+ merge_uplos (:: Nothing , :: Nothing ) = nothing
64
+ merge_uplos (a, :: Nothing ) = a
65
+ merge_uplos (:: Nothing , b) = b
66
+ merge_uplos (a, b) = a == b ? a : ' T'
67
+
68
+ find_uplo (a:: Bidiagonal ) = a. uplo
69
+ find_uplo (a) = nothing
70
+ find_uplo (bc:: Broadcasted ) = mapreduce (find_uplo, merge_uplos, bc. args, init= nothing )
71
+
67
72
function structured_broadcast_alloc (bc, :: Type{<:Bidiagonal} , :: Type{ElType} , n) where {ElType}
68
- ex = find_bidiagonal (bc)
69
- return Bidiagonal (Array {ElType} (undef, n),Array {ElType} (undef, n- 1 ), ex. uplo)
73
+ uplo = find_uplo (bc)
74
+ if uplo == ' T'
75
+ return Tridiagonal (Array {ElType} (undef, n- 1 ), Array {ElType} (undef, n), Array {ElType} (undef, n- 1 ))
76
+ end
77
+ return Bidiagonal (Array {ElType} (undef, n),Array {ElType} (undef, n- 1 ), uplo)
70
78
end
71
79
structured_broadcast_alloc (bc, :: Type{<:SymTridiagonal} , :: Type{ElType} , n) where {ElType} =
72
80
SymTridiagonal (Array {ElType} (undef, n),Array {ElType} (undef, n- 1 ))
0 commit comments