Skip to content

Commit c12e5bd

Browse files
ssikdar1simeonschaub
authored andcommitted
Fix Broadcasting of Bidiagonal (JuliaLang#35281)
1 parent 139cce4 commit c12e5bd

File tree

2 files changed

+59
-6
lines changed

2 files changed

+59
-6
lines changed

stdlib/LinearAlgebra/src/structuredbroadcast.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,21 @@ structured_broadcast_alloc(bc, ::Type{<:Diagonal}, ::Type{ElType}, n) where {ElT
6060
# Bidiagonal is tricky as we need to know if it's upper or lower. The promotion
6161
# system will return Tridiagonal when there's more than one Bidiagonal, but when
6262
# there's only one, we need to make figure out upper or lower
63-
find_bidiagonal() = throw(ArgumentError("could not find Bidiagonal within broadcast expression"))
64-
find_bidiagonal(a::Bidiagonal, rest...) = a
65-
find_bidiagonal(bc::Broadcast.Broadcasted, rest...) = find_bidiagonal(find_bidiagonal(bc.args...), rest...)
66-
find_bidiagonal(x, rest...) = find_bidiagonal(rest...)
63+
merge_uplos(::Nothing, ::Nothing) = nothing
64+
merge_uplos(a, ::Nothing) = a
65+
merge_uplos(::Nothing, b) = b
66+
merge_uplos(a, b) = a == b ? a : 'T'
67+
68+
find_uplo(a::Bidiagonal) = a.uplo
69+
find_uplo(a) = nothing
70+
find_uplo(bc::Broadcasted) = mapreduce(find_uplo, merge_uplos, bc.args, init=nothing)
71+
6772
function structured_broadcast_alloc(bc, ::Type{<:Bidiagonal}, ::Type{ElType}, n) where {ElType}
68-
ex = find_bidiagonal(bc)
69-
return Bidiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n-1), ex.uplo)
73+
uplo = find_uplo(bc)
74+
if uplo == 'T'
75+
return Tridiagonal(Array{ElType}(undef, n-1), Array{ElType}(undef, n), Array{ElType}(undef, n-1))
76+
end
77+
return Bidiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n-1), uplo)
7078
end
7179
structured_broadcast_alloc(bc, ::Type{<:SymTridiagonal}, ::Type{ElType}, n) where {ElType} =
7280
SymTridiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n-1))

stdlib/LinearAlgebra/test/structuredbroadcast.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,49 @@ end
160160
@test L .+ UnitL .+ UnitU .+ U .+ D == L + UnitL + UnitU + U + D
161161
@test L .+ U .+ D .+ D .+ D .+ D == L + U + D + D + D + D
162162
end
163+
@testset "Broadcast Returned Types" begin
164+
# Issue 35245
165+
N = 3
166+
dV = rand(N)
167+
evu = rand(N-1)
168+
evl = rand(N-1)
169+
170+
Bu = Bidiagonal(dV, evu, :U)
171+
Bl = Bidiagonal(dV, evl, :L)
172+
T = Tridiagonal(evl, dV * 2, evu)
173+
174+
@test typeof(Bu .+ Bl) <: Tridiagonal
175+
@test typeof(Bl .+ Bu) <: Tridiagonal
176+
@test typeof(Bu .+ Bu) <: Bidiagonal
177+
@test typeof(Bl .+ Bl) <: Bidiagonal
178+
@test Bu .+ Bl == T
179+
@test Bl .+ Bu == T
180+
@test Bu .+ Bu == Bidiagonal(dV * 2, evu * 2, :U)
181+
@test Bl .+ Bl == Bidiagonal(dV * 2, evl * 2, :L)
182+
183+
184+
@test typeof(Bu .* Bl) <: Tridiagonal
185+
@test typeof(Bl .* Bu) <: Tridiagonal
186+
@test typeof(Bu .* Bu) <: Bidiagonal
187+
@test typeof(Bl .* Bl) <: Bidiagonal
188+
189+
@test Bu .* Bl == Tridiagonal(zeros(N-1), dV .* dV, zeros(N-1))
190+
@test Bl .* Bu == Tridiagonal(zeros(N-1), dV .* dV, zeros(N-1))
191+
@test Bu .* Bu == Bidiagonal(dV .* dV, evu .* evu, :U)
192+
@test Bl .* Bl == Bidiagonal(dV .* dV, evl .* evl, :L)
193+
194+
Bu2 = Bu .* 2
195+
@test typeof(Bu2) <: Bidiagonal && Bu2.uplo == 'U'
196+
Bu2 = 2 .* Bu
197+
@test typeof(Bu2) <: Bidiagonal && Bu2.uplo == 'U'
198+
Bl2 = Bl .* 2
199+
@test typeof(Bl2) <: Bidiagonal && Bl2.uplo == 'L'
200+
Bu2 = 2 .* Bl
201+
@test typeof(Bl2) <: Bidiagonal && Bl2.uplo == 'L'
202+
203+
# Example of Nested Brodacasts
204+
tmp = (1 .* 2) .* (Bidiagonal(1:3, 1:2, 'U') .* (3 .* 4)) .* (5 .* Bidiagonal(1:3, 1:2, 'L'))
205+
@test typeof(tmp) <: Tridiagonal
206+
207+
end
163208
end

0 commit comments

Comments
 (0)