Skip to content

Commit 508e77f

Browse files
authored
Forward Diagonal rmul!/lmul! for adj/trans to parent (#1229)
Currently. `rmul!(A'::Adjoint, D::Diagonal)` loops over `CartesianIndices(A')`, which is not cache-friendly. In this PR, we use the fact that `A' = A' * D` may be rewritten as `A = D' * A`, and that `D' isa Diagonal` by design. The operation therefore becomes `lmul!(D'::Diagonal, A)`, which is cache-friendly. On master, ```julia julia> using LinearAlgebra, Chairmarks julia> D = Diagonal(rand(4000)); julia> A = rand(size(D)...); julia> @b (A', D) rmul!(_[1], _[2]) 103.276 ms ``` whereas, with this PR, ```julia julia> @b (A', D) rmul!(_[1], _[2]) 9.994 ms ```
1 parent f781708 commit 508e77f

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

src/diagonal.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,13 @@ function rmul!(A::AbstractMatrix, D::Diagonal)
360360
end
361361
return A
362362
end
363+
# A' = A' * D => A = D' * A
364+
# This uses the fact that D' is a Diagonal
365+
function rmul!(A::AdjOrTransAbsMat, D::Diagonal)
366+
f = wrapperop(A)
367+
lmul!(f(D), f(A))
368+
A
369+
end
363370
# T .= T * D
364371
function rmul!(T::Tridiagonal, D::Diagonal)
365372
matmul_size_check(size(T), size(D))
@@ -395,6 +402,13 @@ function lmul!(D::Diagonal, T::Tridiagonal)
395402
end
396403
return T
397404
end
405+
# A' = D * A' => A = A * D'
406+
# This uses the fact that D' is a Diagonal
407+
function lmul!(D::Diagonal, A::AdjOrTransAbsMat)
408+
f = wrapperop(A)
409+
rmul!(f(A), f(D))
410+
A
411+
end
398412

399413
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, alpha::Number, beta::Number)
400414
@inbounds for j in axes(B, 2)

test/diagonal.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,6 +1257,22 @@ end
12571257
end
12581258
end
12591259

1260+
@testset "rmul!/lmul! for adj/trans" begin
1261+
for T in (Float64, ComplexF64)
1262+
A = rand(T,5,4); B = similar(A)
1263+
for f in (adjoint, transpose)
1264+
D = Diagonal(rand(T, size(A,1)))
1265+
B .= A
1266+
rmul!(f(B), D)
1267+
@test f(B) == f(A) * D
1268+
D = Diagonal(rand(T, size(A,2)))
1269+
B .= A
1270+
lmul!(D, f(B))
1271+
@test f(B) == D * f(A)
1272+
end
1273+
end
1274+
end
1275+
12601276
struct SMatrix1{T} <: AbstractArray{T,2}
12611277
elt::T
12621278
end

0 commit comments

Comments
 (0)