Skip to content

Commit 653b3a2

Browse files
committed
Structured broadcasting for UpperHessenberg
1 parent 41db513 commit 653b3a2

File tree

2 files changed

+49
-6
lines changed

2 files changed

+49
-6
lines changed

src/structuredbroadcast.jl

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@ struct StructuredMatrixStyle{T} <: Broadcast.AbstractArrayStyle{2} end
88
StructuredMatrixStyle{T}(::Val{2}) where {T} = StructuredMatrixStyle{T}()
99
StructuredMatrixStyle{T}(::Val{N}) where {T,N} = Broadcast.DefaultArrayStyle{N}()
1010

11-
const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T}}
12-
for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular)
11+
const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},
12+
LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T},
13+
UpperHessenberg{T}}
14+
for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,
15+
LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular,
16+
UpperHessenberg)
1317
@eval Broadcast.BroadcastStyle(::Type{<:$ST}) = $(StructuredMatrixStyle{ST}())
1418
end
1519

@@ -27,28 +31,46 @@ Broadcast.BroadcastStyle(::StructuredMatrixStyle{Diagonal}, ::StructuredMatrixSt
2731
StructuredMatrixStyle{LowerTriangular}()
2832
Broadcast.BroadcastStyle(::StructuredMatrixStyle{Diagonal}, ::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}}) =
2933
StructuredMatrixStyle{UpperTriangular}()
34+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{Diagonal}, ::StructuredMatrixStyle{UpperHessenberg}) =
35+
StructuredMatrixStyle{UpperHessenberg}()
3036

3137
Broadcast.BroadcastStyle(::StructuredMatrixStyle{Bidiagonal}, ::StructuredMatrixStyle{Diagonal}) =
3238
StructuredMatrixStyle{Bidiagonal}()
3339
Broadcast.BroadcastStyle(::StructuredMatrixStyle{Bidiagonal}, ::StructuredMatrixStyle{<:Union{Bidiagonal,SymTridiagonal,Tridiagonal}}) =
3440
StructuredMatrixStyle{Tridiagonal}()
41+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{Bidiagonal}, ::StructuredMatrixStyle{UpperHessenberg}) =
42+
StructuredMatrixStyle{UpperHessenberg}()
43+
3544
Broadcast.BroadcastStyle(::StructuredMatrixStyle{SymTridiagonal}, ::StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}}) =
3645
StructuredMatrixStyle{Tridiagonal}()
46+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{SymTridiagonal}, ::StructuredMatrixStyle{UpperHessenberg}) =
47+
StructuredMatrixStyle{UpperHessenberg}()
3748
Broadcast.BroadcastStyle(::StructuredMatrixStyle{Tridiagonal}, ::StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}}) =
3849
StructuredMatrixStyle{Tridiagonal}()
50+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{Tridiagonal}, ::StructuredMatrixStyle{UpperHessenberg}) =
51+
StructuredMatrixStyle{UpperHessenberg}()
3952

4053
Broadcast.BroadcastStyle(::StructuredMatrixStyle{LowerTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,LowerTriangular,UnitLowerTriangular}}) =
4154
StructuredMatrixStyle{LowerTriangular}()
4255
Broadcast.BroadcastStyle(::StructuredMatrixStyle{UpperTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,UpperTriangular,UnitUpperTriangular}}) =
4356
StructuredMatrixStyle{UpperTriangular}()
57+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{UpperTriangular}, ::StructuredMatrixStyle{UpperHessenberg}) =
58+
StructuredMatrixStyle{UpperHessenberg}()
4459
Broadcast.BroadcastStyle(::StructuredMatrixStyle{UnitLowerTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,LowerTriangular,UnitLowerTriangular}}) =
4560
StructuredMatrixStyle{LowerTriangular}()
4661
Broadcast.BroadcastStyle(::StructuredMatrixStyle{UnitUpperTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,UpperTriangular,UnitUpperTriangular}}) =
4762
StructuredMatrixStyle{UpperTriangular}()
63+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{UnitUpperTriangular}, ::StructuredMatrixStyle{UpperHessenberg}) =
64+
StructuredMatrixStyle{UpperHessenberg}()
65+
66+
function Broadcast.BroadcastStyle(::StructuredMatrixStyle{UpperHessenberg},
67+
::StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal,UnitUpperTriangular,UpperTriangular}})
68+
StructuredMatrixStyle{UpperHessenberg}()
69+
end
4870

49-
Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}}, ::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}}) =
71+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}}, ::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular,UpperHessenberg}}) =
5072
StructuredMatrixStyle{Matrix}()
51-
Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}}, ::StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}}) =
73+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular,UpperHessenberg}}, ::StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}}) =
5274
StructuredMatrixStyle{Matrix}()
5375

5476
# Make sure that `StructuredMatrixStyle{Matrix}` doesn't ever end up falling
@@ -95,6 +117,8 @@ structured_broadcast_alloc(bc, ::Type{UnitLowerTriangular}, ::Type{ElType}, n) w
95117
UnitLowerTriangular(Array{ElType}(undef, n, n))
96118
structured_broadcast_alloc(bc, ::Type{UnitUpperTriangular}, ::Type{ElType}, n) where {ElType} =
97119
UnitUpperTriangular(Array{ElType}(undef, n, n))
120+
structured_broadcast_alloc(bc, ::Type{UpperHessenberg}, ::Type{ElType}, n) where {ElType} =
121+
UpperHessenberg(Array{ElType}(undef, n, n))
98122
structured_broadcast_alloc(bc, ::Type{Matrix}, ::Type{ElType}, n) where {ElType} =
99123
Array{ElType}(undef, n, n)
100124

@@ -288,6 +312,18 @@ function copyto!(dest::UpperTriangular, bc::Broadcasted{<:StructuredMatrixStyle}
288312
return dest
289313
end
290314

315+
function copyto!(dest::UpperHessenberg, bc::Broadcasted{<:StructuredMatrixStyle})
316+
isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc))
317+
axs = axes(dest)
318+
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
319+
for j in axs[2]
320+
for i in 1:min(size(dest.data,1), j+1)
321+
@inbounds dest.data[i,j] = bc[CartesianIndex(i, j)]
322+
end
323+
end
324+
return dest
325+
end
326+
291327
# We can also implement `map` and its promotion in terms of broadcast with a stricter dimension check
292328
function map(f, A::StructuredMatrix, Bs::StructuredMatrix...)
293329
sz = size(A)

test/structuredbroadcast.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ using .Main.SizedArrays
2121
T = Tridiagonal(rand(N - 1), rand(N), rand(N - 1))
2222
S = SymTridiagonal(rand(N), rand(N - 1))
2323
U = UpperTriangular(rand(N,N))
24+
UH = UpperHessenberg(rand(N,N))
2425
L = LowerTriangular(rand(N,N))
2526
M = Matrix(rand(N,N))
26-
structuredarrays = (D, B, T, U, L, M, S)
27+
structuredarrays = (D, B, T, U, UH, L, M, S)
2728
fstructuredarrays = map(Array, structuredarrays)
2829
for (X, fX) in zip(structuredarrays, fstructuredarrays)
2930
@test (Q = broadcast(sin, X); typeof(Q) == typeof(X) && Q == broadcast(sin, fX))
@@ -134,6 +135,7 @@ end
134135
T = Tridiagonal(rand(N - 1), rand(N), rand(N - 1))
135136
= LowerTriangular(rand(N,N))
136137
= UpperTriangular(rand(N,N))
138+
UH = UpperHessenberg(rand(N,N))
137139
M = Matrix(rand(N,N))
138140

139141
@test broadcast!(sin, copy(D), D) == Diagonal(sin.(D))
@@ -142,13 +144,15 @@ end
142144
@test broadcast!(sin, copy(T), T) == Tridiagonal(sin.(T))
143145
@test broadcast!(sin, copy(◣), ◣) == LowerTriangular(sin.(◣))
144146
@test broadcast!(sin, copy(◥), ◥) == UpperTriangular(sin.(◥))
147+
@test broadcast!(sin, copy(UH), UH) == UpperHessenberg(sin.(UH))
145148
@test broadcast!(sin, copy(M), M) == Matrix(sin.(M))
146149
@test broadcast!(*, copy(D), D, A) == Diagonal(broadcast(*, D, A))
147150
@test broadcast!(*, copy(Bu), Bu, A) == Bidiagonal(broadcast(*, Bu, A), :U)
148151
@test broadcast!(*, copy(Bl), Bl, A) == Bidiagonal(broadcast(*, Bl, A), :L)
149152
@test broadcast!(*, copy(T), T, A) == Tridiagonal(broadcast(*, T, A))
150153
@test broadcast!(*, copy(◣), ◣, A) == LowerTriangular(broadcast(*, ◣, A))
151154
@test broadcast!(*, copy(◥), ◥, A) == UpperTriangular(broadcast(*, ◥, A))
155+
@test broadcast!(*, copy(UH), UH, A) == UpperHessenberg(broadcast(*, UH, A))
152156
@test broadcast!(*, copy(M), M, A) == Matrix(broadcast(*, M, A))
153157

154158
@test_throws ArgumentError broadcast!(cos, copy(D), D) == Diagonal(sin.(D))
@@ -157,12 +161,14 @@ end
157161
@test_throws ArgumentError broadcast!(cos, copy(T), T) == Tridiagonal(sin.(T))
158162
@test_throws ArgumentError broadcast!(cos, copy(◣), ◣) == LowerTriangular(sin.(◣))
159163
@test_throws ArgumentError broadcast!(cos, copy(◥), ◥) == UpperTriangular(sin.(◥))
164+
@test_throws ArgumentError broadcast!(cos, copy(UH), UH)
160165
@test_throws ArgumentError broadcast!(+, copy(D), D, A) == Diagonal(broadcast(*, D, A))
161166
@test_throws ArgumentError broadcast!(+, copy(Bu), Bu, A) == Bidiagonal(broadcast(*, Bu, A), :U)
162167
@test_throws ArgumentError broadcast!(+, copy(Bl), Bl, A) == Bidiagonal(broadcast(*, Bl, A), :L)
163168
@test_throws ArgumentError broadcast!(+, copy(T), T, A) == Tridiagonal(broadcast(*, T, A))
164169
@test_throws ArgumentError broadcast!(+, copy(◣), ◣, A) == LowerTriangular(broadcast(*, ◣, A))
165170
@test_throws ArgumentError broadcast!(+, copy(◥), ◥, A) == UpperTriangular(broadcast(*, ◥, A))
171+
@test_throws ArgumentError broadcast!(+, copy(UH), UH, A)
166172
@test_throws ArgumentError broadcast!(*, copy(◥), ◣, 2)
167173
@test_throws ArgumentError broadcast!(*, copy(Bu), Bl, 2)
168174
end
@@ -177,8 +183,9 @@ end
177183
S = SymTridiagonal(rand(N), rand(N - 1))
178184
U = UpperTriangular(rand(N,N))
179185
L = LowerTriangular(rand(N,N))
186+
UH = UpperHessenberg(rand(N,N))
180187
M = Matrix(rand(N,N))
181-
structuredarrays = (M, D, B, T, S, U, L)
188+
structuredarrays = (M, D, B, T, S, U, L, UH)
182189
fstructuredarrays = map(Array, structuredarrays)
183190
for (X, fX) in zip(structuredarrays, fstructuredarrays)
184191
@test (Q = map(sin, X); typeof(Q) == typeof(X) && Q == map(sin, fX))

0 commit comments

Comments
 (0)