From 9b4259bd286516c0a8c39e0925f96f0de0a220df Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Thu, 1 May 2025 19:16:25 +0530 Subject: [PATCH] Use matrix size in structured broadcast --- src/structuredbroadcast.jl | 39 +++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/src/structuredbroadcast.jl b/src/structuredbroadcast.jl index a8eeaa94..76e33bc2 100644 --- a/src/structuredbroadcast.jl +++ b/src/structuredbroadcast.jl @@ -61,8 +61,8 @@ Broadcast.BroadcastStyle(T::StructuredMatrixStyle{Matrix}, ::StructuredMatrixSty Broadcast.BroadcastStyle(::StructuredMatrixStyle, ::StructuredMatrixStyle) = DefaultArrayStyle{2}() # And a definition akin to similar using the structured type: -structured_broadcast_alloc(bc, ::Type{Diagonal}, ::Type{ElType}, n) where {ElType} = - Diagonal(Array{ElType}(undef, n)) +structured_broadcast_alloc(bc, ::Type{Diagonal}, ::Type{ElType}, sz::NTuple{2,Integer}) where {ElType} = + Diagonal(Array{ElType}(undef, sz[1])) # Bidiagonal is tricky as we need to know if it's upper or lower. The promotion # system will return Tridiagonal when there's more than one Bidiagonal, but when # there's only one, we need to make figure out upper or lower @@ -75,7 +75,9 @@ find_uplo(a::Bidiagonal) = a.uplo find_uplo(a) = nothing find_uplo(bc::Broadcasted) = mapfoldl(find_uplo, merge_uplos, Broadcast.cat_nested(bc), init=nothing) -function structured_broadcast_alloc(bc, ::Type{Bidiagonal}, ::Type{ElType}, n) where {ElType} +function structured_broadcast_alloc(bc, ::Type{Bidiagonal}, + ::Type{ElType}, sz::NTuple{2,Integer}) where {ElType} + n = sz[1] uplo = n > 0 ? find_uplo(bc) : 'U' n1 = max(n - 1, 0) if count_structedmatrix(Bidiagonal, bc) > 1 && uplo == 'T' @@ -83,20 +85,23 @@ function structured_broadcast_alloc(bc, ::Type{Bidiagonal}, ::Type{ElType}, n) w end return Bidiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n1), uplo) end -structured_broadcast_alloc(bc, ::Type{SymTridiagonal}, ::Type{ElType}, n) where {ElType} = +function structured_broadcast_alloc(bc, ::Type{SymTridiagonal}, + ::Type{ElType}, sz::NTuple{2,Integer}) where {ElType} + n = sz[1] SymTridiagonal(Array{ElType}(undef, n),Array{ElType}(undef, max(0,n-1))) -structured_broadcast_alloc(bc, ::Type{Tridiagonal}, ::Type{ElType}, n) where {ElType} = - Tridiagonal(Array{ElType}(undef, max(0,n-1)),Array{ElType}(undef, n),Array{ElType}(undef, max(0,n-1))) -structured_broadcast_alloc(bc, ::Type{LowerTriangular}, ::Type{ElType}, n) where {ElType} = - LowerTriangular(Array{ElType}(undef, n, n)) -structured_broadcast_alloc(bc, ::Type{UpperTriangular}, ::Type{ElType}, n) where {ElType} = - UpperTriangular(Array{ElType}(undef, n, n)) -structured_broadcast_alloc(bc, ::Type{UnitLowerTriangular}, ::Type{ElType}, n) where {ElType} = - UnitLowerTriangular(Array{ElType}(undef, n, n)) -structured_broadcast_alloc(bc, ::Type{UnitUpperTriangular}, ::Type{ElType}, n) where {ElType} = - UnitUpperTriangular(Array{ElType}(undef, n, n)) -structured_broadcast_alloc(bc, ::Type{Matrix}, ::Type{ElType}, n) where {ElType} = - Array{ElType}(undef, n, n) +end +function structured_broadcast_alloc(bc, ::Type{Tridiagonal}, + ::Type{ElType}, sz::NTuple{2,Integer}) where {ElType} + n = sz[1] + n1 = max(0,n-1) + Tridiagonal(Array{ElType}(undef, n1),Array{ElType}(undef, n),Array{ElType}(undef, n1)) +end +function structured_broadcast_alloc(bc, ::Type{T}, ::Type{ElType}, + sz::NTuple{2,Integer}) where {ElType,T<:UpperOrLowerTriangular} + T(Array{ElType}(undef, sz)) +end +structured_broadcast_alloc(bc, ::Type{Matrix}, ::Type{ElType}, sz::NTuple{2,Integer}) where {ElType} = + Array{ElType}(undef, sz) # A _very_ limited list of structure-preserving functions known at compile-time. This list is # derived from the formerly-implemented `broadcast` methods in 0.6. Note that this must @@ -172,7 +177,7 @@ function Base.similar(bc::Broadcasted{StructuredMatrixStyle{T}}, ::Type{ElType}) inds = axes(bc) fzerobc = fzeropreserving(bc) if isstructurepreserving(bc) || (fzerobc && !(T <: Union{UnitLowerTriangular,UnitUpperTriangular})) - return structured_broadcast_alloc(bc, T, ElType, length(inds[1])) + return structured_broadcast_alloc(bc, T, ElType, map(length, inds)) elseif fzerobc && T <: UnitLowerTriangular return similar(convert(Broadcasted{StructuredMatrixStyle{LowerTriangular}}, bc), ElType) elseif fzerobc && T <: UnitUpperTriangular