Skip to content

Commit 165a37b

Browse files
authored
Preserve structured matrix types more often when broadcasting (#32762)
* Preserve structured matrix types more often when broadcasting The primary motivation here is exponentiation (fixes #32759). Previously `X .^ 2` would return a Matrix but `two=2; X .^ two` would return the same structured matrix type as `X`. This simply teaches LinearAlgebra a little bit more about broadcast so it can safely compute the zero-preservation property of broadcasts involving `Ref`s, which is sufficient to handle the literal pow argument list. * fixup negative exponents
1 parent 3f9faf3 commit 165a37b

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

stdlib/LinearAlgebra/src/structuredbroadcast.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,22 @@ isstructurepreserving(bc::Broadcasted) = isstructurepreserving(bc.f, bc.args...)
9191
isstructurepreserving(::Union{typeof(abs),typeof(big)}, ::StructuredMatrix) = true
9292
isstructurepreserving(::TypeFuncs, ::StructuredMatrix) = true
9393
isstructurepreserving(::TypeFuncs, ::Ref{<:Type}, ::StructuredMatrix) = true
94+
function isstructurepreserving(::typeof(Base.literal_pow), ::Ref{typeof(^)}, ::StructuredMatrix, ::Ref{Val{N}}) where N
95+
return N isa Integer && N > 0
96+
end
9497
isstructurepreserving(f, args...) = false
9598

9699
_iszero(n::Number) = iszero(n)
97100
_iszero(x) = x == 0
98101
fzeropreserving(bc) = (v = fzero(bc); !ismissing(v) && _iszero(v))
99-
# Very conservatively only allow Numbers and Types in this speculative zero-test pass
102+
# Like sparse matrices, we assume that the zero-preservation property of a broadcasted
103+
# expression is stable. We can test the zero-preservability by applying the function
104+
# in cases where all other arguments are known scalars against a zero from the structured
105+
# matrix. If any non-structured matrix argument is not a known scalar, we give up.
100106
fzero(x::Number) = x
101107
fzero(::Type{T}) where T = T
108+
fzero(r::Ref) = r[]
109+
fzero(t::Tuple{Any}) = t[1]
102110
fzero(S::StructuredMatrix) = zero(eltype(S))
103111
fzero(x) = missing
104112
function fzero(bc::Broadcast.Broadcasted)

stdlib/LinearAlgebra/test/structuredbroadcast.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,20 @@ using Test, LinearAlgebra
2828
@test broadcast!(+, Z, fV, fA, X) == broadcast(+, fV, fA, fX)
2929
@test (Q = broadcast(*, s, fV, fA, X); Q isa Matrix && Q == broadcast(*, s, fV, fA, fX))
3030
@test broadcast!(*, Z, s, fV, fA, X) == broadcast(*, s, fV, fA, fX)
31+
32+
@test X .* 2.0 == X .* (2.0,) == fX .* 2.0
33+
@test X .* 2.0 isa typeof(X)
34+
@test X .* (2.0,) isa typeof(X)
35+
@test isequal(X .* Inf, fX .* Inf)
36+
37+
two = 2
38+
@test X .^ 2 == X .^ (2,) == fX .^ 2 == X .^ two
39+
@test X .^ 2 isa typeof(X)
40+
@test X .^ (2,) isa typeof(X)
41+
@test X .^ two isa typeof(X)
42+
@test X .^ 0 == fX .^ 0
43+
@test X .^ -1 == fX .^ -1
44+
3145
for (Y, fY) in zip(structuredarrays, fstructuredarrays)
3246
@test broadcast(+, X, Y) == broadcast(+, fX, fY)
3347
@test broadcast!(+, Z, X, Y) == broadcast(+, fX, fY)

0 commit comments

Comments
 (0)