@@ -8,8 +8,12 @@ struct StructuredMatrixStyle{T} <: Broadcast.AbstractArrayStyle{2} end
8
8
StructuredMatrixStyle {T} (:: Val{2} ) where {T} = StructuredMatrixStyle {T} ()
9
9
StructuredMatrixStyle {T} (:: Val{N} ) where {T,N} = Broadcast. DefaultArrayStyle {N} ()
10
10
11
- const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T}}
12
- for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular)
11
+ const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},
12
+ LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T},
13
+ UpperHessenberg{T}}
14
+ for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,
15
+ LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular,
16
+ UpperHessenberg)
13
17
@eval Broadcast. BroadcastStyle (:: Type{<:$ST} ) = $ (StructuredMatrixStyle {ST} ())
14
18
end
15
19
@@ -27,28 +31,46 @@ Broadcast.BroadcastStyle(::StructuredMatrixStyle{Diagonal}, ::StructuredMatrixSt
27
31
StructuredMatrixStyle {LowerTriangular} ()
28
32
Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Diagonal} , :: StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}} ) =
29
33
StructuredMatrixStyle {UpperTriangular} ()
34
+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Diagonal} , :: StructuredMatrixStyle{UpperHessenberg} ) =
35
+ StructuredMatrixStyle {UpperHessenberg} ()
30
36
31
37
Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Bidiagonal} , :: StructuredMatrixStyle{Diagonal} ) =
32
38
StructuredMatrixStyle {Bidiagonal} ()
33
39
Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Bidiagonal} , :: StructuredMatrixStyle{<:Union{Bidiagonal,SymTridiagonal,Tridiagonal}} ) =
34
40
StructuredMatrixStyle {Tridiagonal} ()
41
+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Bidiagonal} , :: StructuredMatrixStyle{UpperHessenberg} ) =
42
+ StructuredMatrixStyle {UpperHessenberg} ()
43
+
35
44
Broadcast. BroadcastStyle (:: StructuredMatrixStyle{SymTridiagonal} , :: StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}} ) =
36
45
StructuredMatrixStyle {Tridiagonal} ()
46
+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{SymTridiagonal} , :: StructuredMatrixStyle{UpperHessenberg} ) =
47
+ StructuredMatrixStyle {UpperHessenberg} ()
37
48
Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Tridiagonal} , :: StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}} ) =
38
49
StructuredMatrixStyle {Tridiagonal} ()
50
+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Tridiagonal} , :: StructuredMatrixStyle{UpperHessenberg} ) =
51
+ StructuredMatrixStyle {UpperHessenberg} ()
39
52
40
53
Broadcast. BroadcastStyle (:: StructuredMatrixStyle{LowerTriangular} , :: StructuredMatrixStyle{<:Union{Diagonal,LowerTriangular,UnitLowerTriangular}} ) =
41
54
StructuredMatrixStyle {LowerTriangular} ()
42
55
Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UpperTriangular} , :: StructuredMatrixStyle{<:Union{Diagonal,UpperTriangular,UnitUpperTriangular}} ) =
43
56
StructuredMatrixStyle {UpperTriangular} ()
57
+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UpperTriangular} , :: StructuredMatrixStyle{UpperHessenberg} ) =
58
+ StructuredMatrixStyle {UpperHessenberg} ()
44
59
Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UnitLowerTriangular} , :: StructuredMatrixStyle{<:Union{Diagonal,LowerTriangular,UnitLowerTriangular}} ) =
45
60
StructuredMatrixStyle {LowerTriangular} ()
46
61
Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UnitUpperTriangular} , :: StructuredMatrixStyle{<:Union{Diagonal,UpperTriangular,UnitUpperTriangular}} ) =
47
62
StructuredMatrixStyle {UpperTriangular} ()
63
+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UnitUpperTriangular} , :: StructuredMatrixStyle{UpperHessenberg} ) =
64
+ StructuredMatrixStyle {UpperHessenberg} ()
65
+
66
+ function Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UpperHessenberg} ,
67
+ :: StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal,UnitUpperTriangular,UpperTriangular}} )
68
+ StructuredMatrixStyle {UpperHessenberg} ()
69
+ end
48
70
49
- Broadcast. BroadcastStyle (:: StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}} , :: StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}} ) =
71
+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}} , :: StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular,UpperHessenberg }} ) =
50
72
StructuredMatrixStyle {Matrix} ()
51
- Broadcast. BroadcastStyle (:: StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}} , :: StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}} ) =
73
+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular,UpperHessenberg }} , :: StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}} ) =
52
74
StructuredMatrixStyle {Matrix} ()
53
75
54
76
# Make sure that `StructuredMatrixStyle{Matrix}` doesn't ever end up falling
@@ -95,6 +117,8 @@ structured_broadcast_alloc(bc, ::Type{UnitLowerTriangular}, ::Type{ElType}, n) w
95
117
UnitLowerTriangular (Array {ElType} (undef, n, n))
96
118
structured_broadcast_alloc (bc, :: Type{UnitUpperTriangular} , :: Type{ElType} , n) where {ElType} =
97
119
UnitUpperTriangular (Array {ElType} (undef, n, n))
120
+ structured_broadcast_alloc (bc, :: Type{UpperHessenberg} , :: Type{ElType} , n) where {ElType} =
121
+ UpperHessenberg (Array {ElType} (undef, n, n))
98
122
structured_broadcast_alloc (bc, :: Type{Matrix} , :: Type{ElType} , n) where {ElType} =
99
123
Array {ElType} (undef, n, n)
100
124
@@ -288,6 +312,18 @@ function copyto!(dest::UpperTriangular, bc::Broadcasted{<:StructuredMatrixStyle}
288
312
return dest
289
313
end
290
314
315
+ function copyto! (dest:: UpperHessenberg , bc:: Broadcasted{<:StructuredMatrixStyle} )
316
+ isvalidstructbc (dest, bc) || return copyto! (dest, convert (Broadcasted{Nothing}, bc))
317
+ axs = axes (dest)
318
+ axes (bc) == axs || Broadcast. throwdm (axes (bc), axs)
319
+ for j in axs[2 ]
320
+ for i in 1 : min (size (dest. data,1 ), j+ 1 )
321
+ @inbounds dest. data[i,j] = bc[CartesianIndex (i, j)]
322
+ end
323
+ end
324
+ return dest
325
+ end
326
+
291
327
# We can also implement `map` and its promotion in terms of broadcast with a stricter dimension check
292
328
function map (f, A:: StructuredMatrix , Bs:: StructuredMatrix... )
293
329
sz = size (A)
0 commit comments