Skip to content

Commit 5be3e27

Browse files
authored
Let muladd accept a more restricted set of arrays (#38250)
This adjusts #37065 to be much more cautious about what arrays it acts on: it calls mul! on StridedArrays, treats a few special types like Diagonal, UpperTriangular, and UniformScaling, and sends anything else to muladd(A,y,z) = A*y .+ z. However this broadcasting restricts the shape of z, mostly such that A*y .= z would work. That ensures you should get the same error from the mul!(::StridedMatrix, ...) method, as from the fallback broadcasting one. Both allow z of lower dimension than the existing muladd(x,y,z) = x*y+z. But x*y+z also allows z to have trailing dimensions, as long as they are of size 1. I made the broadcasting method allow these too, which I think should make this non-breaking. (I presume this is rarely used, and thus not worth sending to the fast method.) Structured matrices such as UpperTriangular should all go to x*y+z. Some combinations could be made more efficient but it gets complicated. Only the case of 3 diagonals is handled.
1 parent 4fa9e32 commit 5be3e27

File tree

4 files changed

+147
-52
lines changed

4 files changed

+147
-52
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,3 +752,7 @@ function logabsdet(A::Diagonal)
752752
mapreduce(x -> (log(abs(x)), sign(x)), ((d1, s1), (d2, s2)) -> (d1 + d2, s1 * s2),
753753
A.diag)
754754
end
755+
756+
function Base.muladd(A::Diagonal, B::Diagonal, z::Diagonal)
757+
Diagonal(A.diag .* B.diag .+ z.diag)
758+
end

stdlib/LinearAlgebra/src/matmul.jl

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -201,34 +201,54 @@ julia> muladd(A, B, z)
201201
107.0 107.0
202202
```
203203
"""
204-
function Base.muladd(A::AbstractMatrix{TA}, y::AbstractVector{Ty}, z) where {TA, Ty}
205-
T = promote_type(TA, Ty, eltype(z))
204+
function Base.muladd(A::AbstractMatrix, y::AbstractVecOrMat, z::Union{Number, AbstractArray})
205+
Ay = A * y
206+
for d in 1:ndims(Ay)
207+
# Same error as Ay .+= z would give, to match StridedMatrix method:
208+
size(z,d) > size(Ay,d) && throw(DimensionMismatch("array could not be broadcast to match destination"))
209+
end
210+
for d in ndims(Ay)+1:ndims(z)
211+
# Similar error to what Ay + z would give, to match (Any,Any,Any) method:
212+
size(z,d) > 1 && throw(DimensionMismatch(string("dimensions must match: z has dims ",
213+
axes(z), ", must have singleton at dim ", d)))
214+
end
215+
Ay .+ z
216+
end
217+
218+
function Base.muladd(u::AbstractVector, v::AdjOrTransAbsVec, z::Union{Number, AbstractArray})
219+
if size(z,1) > length(u) || size(z,2) > length(v)
220+
# Same error as (u*v) .+= z:
221+
throw(DimensionMismatch("array could not be broadcast to match destination"))
222+
end
223+
for d in 3:ndims(z)
224+
# Similar error to (u*v) + z:
225+
size(z,d) > 1 && throw(DimensionMismatch(string("dimensions must match: z has dims ",
226+
axes(z), ", must have singleton at dim ", d)))
227+
end
228+
(u .* v) .+ z
229+
end
230+
231+
Base.muladd(x::AdjointAbsVec, A::AbstractMatrix, z::Union{Number, AbstractVecOrMat}) =
232+
muladd(A', x', z')'
233+
Base.muladd(x::TransposeAbsVec, A::AbstractMatrix, z::Union{Number, AbstractVecOrMat}) =
234+
transpose(muladd(transpose(A), transpose(x), transpose(z)))
235+
236+
StridedMaybeAdjOrTransMat{T} = Union{StridedMatrix{T}, Adjoint{T, <:StridedMatrix}, Transpose{T, <:StridedMatrix}}
237+
238+
function Base.muladd(A::StridedMaybeAdjOrTransMat{<:Number}, y::AbstractVector{<:Number}, z::Union{Number, AbstractVector})
239+
T = promote_type(eltype(A), eltype(y), eltype(z))
206240
C = similar(A, T, axes(A,1))
207241
C .= z
208242
mul!(C, A, y, true, true)
209243
end
210244

211-
function Base.muladd(A::AbstractMatrix{TA}, B::AbstractMatrix{TB}, z) where {TA, TB}
212-
T = promote_type(TA, TB, eltype(z))
245+
function Base.muladd(A::StridedMaybeAdjOrTransMat{<:Number}, B::StridedMaybeAdjOrTransMat{<:Number}, z::Union{Number, AbstractVecOrMat})
246+
T = promote_type(eltype(A), eltype(B), eltype(z))
213247
C = similar(A, T, axes(A,1), axes(B,2))
214248
C .= z
215249
mul!(C, A, B, true, true)
216250
end
217251

218-
Base.muladd(x::AdjointAbsVec, A::AbstractMatrix, z) = muladd(A', x', z')'
219-
Base.muladd(x::TransposeAbsVec, A::AbstractMatrix, z) = transpose(muladd(transpose(A), transpose(x), transpose(z)))
220-
221-
function Base.muladd(u::AbstractVector, v::AdjOrTransAbsVec, z)
222-
ndims(z) > 2 && throw(DimensionMismatch("cannot broadcast array to have fewer dimensions"))
223-
(u .* v) .+ z
224-
end
225-
226-
function Base.muladd(u::AdjOrTransAbsVec, v::AbstractVector, z)
227-
uv = _dot_nonrecursive(u, v)
228-
ndims(z) > ndims(uv) && throw(DimensionMismatch("cannot broadcast array to have fewer dimensions"))
229-
uv .+ z
230-
end
231-
232252
"""
233253
mul!(Y, A, B) -> Y
234254

stdlib/LinearAlgebra/src/uniformscaling.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,3 +483,12 @@ Diagonal(s::UniformScaling, m::Integer) = Diagonal{eltype(s)}(s, m)
483483
dot(x::AbstractVector, J::UniformScaling, y::AbstractVector) = dot(x, J.λ, y)
484484
dot(x::AbstractVector, a::Number, y::AbstractVector) = sum(t -> dot(t[1], a, t[2]), zip(x, y))
485485
dot(x::AbstractVector, a::Union{Real,Complex}, y::AbstractVector) = a*dot(x, y)
486+
487+
# muladd
488+
Base.muladd(A::UniformScaling, B::UniformScaling, z::UniformScaling) =
489+
UniformScaling(A.λ * B.λ + z.λ)
490+
Base.muladd(A::Union{Diagonal, UniformScaling}, B::Union{Diagonal, UniformScaling}, z::Union{Diagonal, UniformScaling}) =
491+
Diagonal(_diag_or_value(A) .* _diag_or_value(B) .+ _diag_or_value(z))
492+
493+
_diag_or_value(A::Diagonal) = A.diag
494+
_diag_or_value(A::UniformScaling) = A.λ

stdlib/LinearAlgebra/test/matmul.jl

Lines changed: 96 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -293,45 +293,107 @@ end
293293
end
294294

295295
@testset "muladd" begin
296-
A23 = reshape(1:6, 2,3)
296+
A23 = reshape(1:6, 2,3) .+ 0
297297
B34 = reshape(1:12, 3,4) .+ im
298298
u2 = [10,20]
299299
v3 = [3,5,7] .+ im
300300
w4 = [11,13,17,19im]
301301

302-
@test muladd(A23, B34, 100) == A23 * B34 .+ 100
303-
@test muladd(A23, B34, u2) == A23 * B34 .+ u2
304-
@test muladd(A23, B34, w4') == A23 * B34 .+ w4'
305-
@test_throws DimensionMismatch muladd(B34, A23, 1)
306-
@test_throws DimensionMismatch muladd(A23, B34, ones(2,4,1))
307-
308-
@test muladd(A23, v3, 100) == A23 * v3 .+ 100
309-
@test muladd(A23, v3, u2) == A23 * v3 .+ u2
310-
@test muladd(A23, v3, im) isa Vector{Complex{Int}}
311-
@test_throws DimensionMismatch muladd(A23, v3, ones(2,2))
312-
313-
@test muladd(v3', B34, 0) isa Adjoint
314-
@test muladd(v3', B34, 2im) == v3' * B34 .+ 2im
315-
@test muladd(v3', B34, w4') == v3' * B34 .+ w4'
316-
@test_throws DimensionMismatch muladd(v3', B34, ones(1,4))
317-
318-
@test muladd(u2, v3', 0) isa Matrix
319-
@test muladd(u2, v3', 99) == u2 * v3' .+ 99
320-
@test muladd(u2, v3', A23) == u2 * v3' .+ A23
321-
@test_throws DimensionMismatch muladd(u2, v3', ones(2,3,4))
322-
323-
@test muladd(u2', u2, 0) isa Number
324-
@test muladd(v3', v3, im) == dot(v3,v3) + im
325-
@test_throws DimensionMismatch muladd(v3', v3, [1])
326-
327-
vofm = [rand(1:9,2,2) for _ in 1:3]
328-
Mofm = [rand(1:9,2,2) for _ in 1:3, _ in 1:3]
329-
330-
@test muladd(vofm', vofm, vofm[1]) == vofm' * vofm .+ vofm[1] # inner
331-
@test muladd(vofm, vofm', Mofm) == vofm * vofm' .+ Mofm # outer
332-
@test muladd(vofm', Mofm, vofm') == vofm' * Mofm .+ vofm' # bra-mat
333-
@test muladd(Mofm, Mofm, vofm) == Mofm * Mofm .+ vofm # mat-mat
334-
@test_broken muladd(Mofm, vofm, vofm) == Mofm * vofm .+ vofm # mat-vec
302+
@testset "matrix-matrix" begin
303+
@test muladd(A23, B34, 0) == A23 * B34
304+
@test muladd(A23, B34, 100) == A23 * B34 .+ 100
305+
@test muladd(A23, B34, u2) == A23 * B34 .+ u2
306+
@test muladd(A23, B34, w4') == A23 * B34 .+ w4'
307+
@test_throws DimensionMismatch muladd(B34, A23, 1)
308+
@test muladd(ones(1,3), ones(3,4), ones(1,4)) == fill(4.0,1,4)
309+
@test_throws DimensionMismatch muladd(ones(1,3), ones(3,4), ones(9,4))
310+
311+
# broadcasting fallback method allows trailing dims
312+
@test muladd(A23, B34, ones(2,4,1)) == A23 * B34 + ones(2,4,1)
313+
@test_throws DimensionMismatch muladd(ones(1,3), ones(3,4), ones(9,4,1))
314+
@test_throws DimensionMismatch muladd(ones(1,3), ones(3,4), ones(1,4,9))
315+
# and catches z::Array{T,0}
316+
@test muladd(A23, B34, fill(0)) == A23 * B34
317+
end
318+
@testset "matrix-vector" begin
319+
@test muladd(A23, v3, 0) == A23 * v3
320+
@test muladd(A23, v3, 100) == A23 * v3 .+ 100
321+
@test muladd(A23, v3, u2) == A23 * v3 .+ u2
322+
@test muladd(A23, v3, im) isa Vector{Complex{Int}}
323+
@test muladd(ones(1,3), ones(3), ones(1)) == [4]
324+
@test_throws DimensionMismatch muladd(ones(1,3), ones(3), ones(7))
325+
326+
# fallback
327+
@test muladd(A23, v3, ones(2,1,1)) == A23 * v3 + ones(2,1,1)
328+
@test_throws DimensionMismatch muladd(A23, v3, ones(2,2))
329+
@test_throws DimensionMismatch muladd(ones(1,3), ones(3), ones(7,1))
330+
@test_throws DimensionMismatch muladd(ones(1,3), ones(3), ones(1,7))
331+
@test muladd(A23, v3, fill(0)) == A23 * v3
332+
end
333+
@testset "adjoint-matrix" begin
334+
@test muladd(v3', B34, 0) isa Adjoint
335+
@test muladd(v3', B34, 2im) == v3' * B34 .+ 2im
336+
@test muladd(v3', B34, w4') == v3' * B34 .+ w4'
337+
338+
# via fallback
339+
@test muladd(v3', B34, ones(1,4)) == (B34' * v3 + ones(4,1))'
340+
@test_throws DimensionMismatch muladd(v3', B34, ones(7,4))
341+
@test_throws DimensionMismatch muladd(v3', B34, ones(1,4,7))
342+
@test muladd(v3', B34, fill(0)) == v3' * B34 # does not make an Adjoint
343+
end
344+
@testset "vector-adjoint" begin
345+
@test muladd(u2, v3', 0) isa Matrix
346+
@test muladd(u2, v3', 99) == u2 * v3' .+ 99
347+
@test muladd(u2, v3', A23) == u2 * v3' .+ A23
348+
349+
@test muladd(u2, v3', ones(2,3,1)) == u2 * v3' + ones(2,3,1)
350+
@test_throws DimensionMismatch muladd(u2, v3', ones(2,3,4))
351+
@test_throws DimensionMismatch muladd([1], v3', ones(7,3))
352+
@test muladd(u2, v3', fill(0)) == u2 * v3'
353+
end
354+
@testset "dot" begin # all use muladd(::Any, ::Any, ::Any)
355+
@test muladd(u2', u2, 0) isa Number
356+
@test muladd(v3', v3, im) == dot(v3,v3) + im
357+
@test muladd(u2', u2, [1]) == [dot(u2,u2) + 1]
358+
@test_throws DimensionMismatch muladd(u2', u2, [1,1]) == [dot(u2,u2) + 1]
359+
@test muladd(u2', u2, fill(0)) == dot(u2,u2)
360+
end
361+
@testset "arrays of arrays" begin
362+
vofm = [rand(1:9,2,2) for _ in 1:3]
363+
Mofm = [rand(1:9,2,2) for _ in 1:3, _ in 1:3]
364+
365+
@test muladd(vofm', vofm, vofm[1]) == vofm' * vofm .+ vofm[1] # inner
366+
@test muladd(vofm, vofm', Mofm) == vofm * vofm' .+ Mofm # outer
367+
@test muladd(vofm', Mofm, vofm') == vofm' * Mofm .+ vofm' # bra-mat
368+
@test muladd(Mofm, Mofm, vofm) == Mofm * Mofm .+ vofm # mat-mat
369+
@test muladd(Mofm, vofm, vofm) == Mofm * vofm .+ vofm # mat-vec
370+
end
371+
end
372+
373+
@testset "muladd & structured matrices" begin
374+
A33 = reshape(1:9, 3,3) .+ im
375+
v3 = [3,5,7im]
376+
377+
# no special treatment
378+
@test muladd(Symmetric(A33), Symmetric(A33), 1) == Symmetric(A33) * Symmetric(A33) .+ 1
379+
@test muladd(Hermitian(A33), Hermitian(A33), v3) == Hermitian(A33) * Hermitian(A33) .+ v3
380+
@test muladd(adjoint(A33), transpose(A33), A33) == A33' * transpose(A33) .+ A33
381+
382+
u1 = muladd(UpperTriangular(A33), UpperTriangular(A33), Diagonal(v3))
383+
@test u1 isa UpperTriangular
384+
@test u1 == UpperTriangular(A33) * UpperTriangular(A33) + Diagonal(v3)
385+
386+
# diagonal
387+
@test muladd(Diagonal(v3), Diagonal(A33), Diagonal(v3)).diag == ([1,5,9] .+ im .+ 1) .* v3
388+
389+
# uniformscaling
390+
@test muladd(Diagonal(v3), I, I).diag == v3 .+ 1
391+
@test muladd(2*I, 3*I, I).λ == 7
392+
@test muladd(A33, A33', I) == A33 * A33' + I
393+
394+
# https://github.com/JuliaLang/julia/issues/38426
395+
@test @evalpoly(A33, 1.0*I, 1.0*I) == I + A33
396+
@test @evalpoly(A33, 1.0*I, 1.0*I, 1.0*I) == I + A33 + A33^2
335397
end
336398

337399
# issue #6450

0 commit comments

Comments
 (0)