Skip to content

Commit b1f9cb3

Browse files
authored
Reduce allocations in diagonal tests (#1377)
Some of these allocations are unnecessary, and we may re-use pre-allocated arrays instead.
1 parent 0768580 commit b1f9cb3

File tree

1 file changed

+106
-85
lines changed

1 file changed

+106
-85
lines changed

test/diagonal.jl

Lines changed: 106 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,9 @@ LinearAlgebra.istril(N::NotDiagonal) = istril(N.a)
9292
@test typeof(convert(Diagonal{ComplexF32},D)) <: Diagonal{ComplexF32}
9393
@test typeof(convert(AbstractMatrix{ComplexF32},D)) <: Diagonal{ComplexF32}
9494

95-
@test Array(real(D)) == real(M)
96-
@test Array(abs.(D)) == abs.(M)
97-
@test Array(imag(D)) == imag(M)
95+
@test convert(Array, real(D)) == real(M)
96+
@test convert(Array, abs.(D)) == abs.(M)
97+
@test convert(Array, imag(D)) == imag(M)
9898

9999
@test parent(D) == dd
100100
@test D[1,1] == dd[1]
@@ -170,8 +170,8 @@ LinearAlgebra.istril(N::NotDiagonal) = istril(N.a)
170170
@test D*v DM*v atol=n*eps(relty)*(1+(elty<:Complex))
171171
@test D*U DM*U atol=n^2*eps(relty)*(1+(elty<:Complex))
172172

173-
@test transpose(U)*D transpose(U)*Array(D)
174-
@test U'*D U'*Array(D)
173+
@test transpose(U)*D transpose(U)*M
174+
@test U'*D U'*M
175175

176176
if relty != BigFloat
177177
atol_two = 2n^2 * eps(relty) * (1 + (elty <: Complex))
@@ -206,12 +206,12 @@ LinearAlgebra.istril(N::NotDiagonal) = istril(N.a)
206206
@test_throws DimensionMismatch ldiv!(D, fill(elty(1), n + 1))
207207
@test_throws SingularException ldiv!(Diagonal(zeros(relty, n)), copy(v))
208208
b = rand(elty, n, n)
209-
@test ldiv!(D, copy(b)) Array(D)\Array(b)
209+
@test ldiv!(D, copy(b)) M\b
210210
@test_throws SingularException ldiv!(Diagonal(zeros(elty, n)), copy(b))
211211
b = view(rand(elty, n), Vector(1:n))
212212
b2 = copy(b)
213213
c = ldiv!(D, b)
214-
d = Array(D)\b2
214+
d = M\b2
215215
@test c d
216216
@test_throws SingularException ldiv!(Diagonal(zeros(elty, n)), b)
217217
b = rand(elty, n+1, n+1)
@@ -234,9 +234,9 @@ LinearAlgebra.istril(N::NotDiagonal) = istril(N.a)
234234
@test Array(D*a) DM*a
235235
@test Array(D/a) DM/a
236236
if elty <: Real
237-
@test Array(abs.(D)^a) abs.(DM)^a
237+
@test convert(Array, abs.(D)^a) abs.(DM)^a
238238
else
239-
@test Array(D^a) DM^a rtol=max(eps(relty), 1e-15) # TODO: improve precision
239+
@test convert(Array, D^a) DM^a rtol=max(eps(relty), 1e-15) # TODO: improve precision
240240
end
241241
@test Diagonal(1:100)^2 == Diagonal((1:100).^2)
242242
p = 3
@@ -248,17 +248,17 @@ LinearAlgebra.istril(N::NotDiagonal) = istril(N.a)
248248

249249
if relty <: BlasFloat
250250
for b in (rand(elty,n,n), rand(elty,n))
251-
@test lmul!(copy(D), copy(b)) Array(D)*Array(b)
252-
@test lmul!(transpose(copy(D)), copy(b)) transpose(Array(D))*Array(b)
253-
@test lmul!(adjoint(copy(D)), copy(b)) Array(D)'*Array(b)
251+
@test lmul!(copy(D), copy(b)) M*b
252+
@test lmul!(transpose(copy(D)), copy(b)) transpose(M)*b
253+
@test lmul!(adjoint(copy(D)), copy(b)) M'*b
254254
end
255255
end
256256

257257
#a few missing mults
258258
bd = Bidiagonal(D2)
259-
@test D*transpose(D2) Array(D)*transpose(Array(D2))
260-
@test D2*transpose(D) Array(D2)*transpose(Array(D))
261-
@test D2*D' Array(D2)*Array(D)'
259+
@test D*transpose(D2) M*transpose(DM2)
260+
@test D2*transpose(D) DM2*transpose(M)
261+
@test D2*D' DM2*M'
262262

263263
#division of two Diagonals
264264
@test D/D2 Diagonal(D.diag./D2.diag)
@@ -273,33 +273,37 @@ LinearAlgebra.istril(N::NotDiagonal) = istril(N.a)
273273
A = rand(elty, n, n)
274274
Asym = Symmetric(A + transpose(A), :U)
275275
Aherm = Hermitian(A + adjoint(A), :U)
276+
Msym = Array(Asym)
277+
Mherm = Array(Aherm)
276278
for op in (+, -)
277279
@test op(Asym, D) isa Symmetric
278-
@test Array(op(Asym, D)) Array(Symmetric(op(Array(Asym), Array(D))))
280+
@test convert(Array, op(Asym, D)) Array(Symmetric(op(Msym, M)))
279281
@test op(D, Asym) isa Symmetric
280-
@test Array(op(D, Asym)) Array(Symmetric(op(Array(D), Array(Asym))))
282+
@test convert(Array, op(D, Asym)) Array(Symmetric(op(M, Msym)))
281283
if !(elty <: Real)
282284
Dr = real(D)
285+
Mr = Array(Dr)
283286
@test op(Aherm, Dr) isa Hermitian
284-
@test Array(op(Aherm, Dr)) Array(Hermitian(op(Array(Aherm), Array(Dr))))
287+
@test convert(Array, op(Aherm, Dr)) Array(Hermitian(op(Mherm, Mr)))
285288
@test op(Dr, Aherm) isa Hermitian
286-
@test Array(op(Dr, Aherm)) Array(Hermitian(op(Array(Dr), Array(Aherm))))
289+
@test convert(Array, op(Dr, Aherm)) Array(Hermitian(op(Mr, Mherm)))
287290
end
288291
end
289-
@test Array(D*transpose(Asym)) Array(D) * Array(transpose(Asym))
290-
@test Array(D*adjoint(Asym)) Array(D) * Array(adjoint(Asym))
291-
@test Array(D*transpose(Aherm)) Array(D) * Array(transpose(Aherm))
292-
@test Array(D*adjoint(Aherm)) Array(D) * Array(adjoint(Aherm))
293-
@test Array(transpose(Asym)*transpose(D)) Array(transpose(Asym)) * Array(transpose(D))
294-
@test Array(transpose(D)*transpose(Asym)) Array(transpose(D)) * Array(transpose(Asym))
295-
@test Array(adjoint(Aherm)*adjoint(D)) Array(adjoint(Aherm)) * Array(adjoint(D))
296-
@test Array(adjoint(D)*adjoint(Aherm)) Array(adjoint(D)) * Array(adjoint(Aherm))
292+
Msym = Array(Asym)
293+
@test convert(Array, D*transpose(Asym)) M * convert(Array, transpose(Msym))
294+
@test convert(Array, D*adjoint(Asym)) M * convert(Array, adjoint(Asym))
295+
@test convert(Array, D*transpose(Aherm)) M * convert(Array, transpose(Aherm))
296+
@test convert(Array, D*adjoint(Aherm)) M * convert(Array, adjoint(Aherm))
297+
@test convert(Array, Asym*transpose(D)) Msym * convert(Array, transpose(D))
298+
@test convert(Array, transpose(D)*Asym) convert(Array, transpose(D)) * Msym
299+
@test convert(Array, adjoint(Aherm)*adjoint(D)) convert(Array, adjoint(Aherm)) * convert(Array, adjoint(D))
300+
@test convert(Array, adjoint(D)*adjoint(Aherm)) convert(Array, adjoint(D)) * convert(Array, adjoint(Aherm))
297301

298302
# Performance specialisations for A*_mul_B!
299303
vvv = similar(vv)
300-
@test (r = Matrix(D) * vv ; mul!(vvv, D, vv) r vvv)
301-
@test (r = Matrix(D)' * vv ; mul!(vvv, adjoint(D), vv) r vvv)
302-
@test (r = transpose(Matrix(D)) * vv ; mul!(vvv, transpose(D), vv) r vvv)
304+
@test (r = M * vv ; mul!(vvv, D, vv) r vvv)
305+
@test (r = M' * vv ; mul!(vvv, adjoint(D), vv) r vvv)
306+
@test (r = transpose(M) * vv ; mul!(vvv, transpose(D), vv) r vvv)
303307

304308
UUU = similar(UU)
305309
for transformA in (identity, adjoint, transpose)
@@ -311,55 +315,62 @@ LinearAlgebra.istril(N::NotDiagonal) = istril(N.a)
311315

312316
alpha = elty(randn()) # randn(elty) does not work with BigFloat
313317
beta = elty(randn())
314-
@test begin
318+
@testset begin
315319
vvv = similar(vv)
316320
vvv .= randn(size(vvv)) # randn!(vvv) does not work with BigFloat
317-
r = alpha * Matrix(D) * vv + beta * vvv
318-
mul!(vvv, D, vv, alpha, beta) r vvv
321+
r = alpha * M * vv + beta * vvv
322+
@test mul!(vvv, D, vv, alpha, beta) === vvv
323+
@test r vvv
319324
end
320-
@test begin
325+
@testset begin
321326
vvv = similar(vv)
322327
vvv .= randn(size(vvv)) # randn!(vvv) does not work with BigFloat
323-
r = alpha * Matrix(D)' * vv + beta * vvv
324-
mul!(vvv, adjoint(D), vv, alpha, beta) r vvv
328+
r = alpha * M' * vv + beta * vvv
329+
@test mul!(vvv, adjoint(D), vv, alpha, beta) === vvv
330+
@test r vvv
325331
end
326-
@test begin
332+
@testset begin
327333
vvv = similar(vv)
328334
vvv .= randn(size(vvv)) # randn!(vvv) does not work with BigFloat
329-
r = alpha * transpose(Matrix(D)) * vv + beta * vvv
330-
mul!(vvv, transpose(D), vv, alpha, beta) r vvv
335+
r = alpha * transpose(M) * vv + beta * vvv
336+
@test mul!(vvv, transpose(D), vv, alpha, beta) === vvv
337+
@test r vvv
331338
end
332339

333-
@test begin
340+
@testset begin
334341
UUU = similar(UU)
335342
UUU .= randn(size(UUU)) # randn!(UUU) does not work with BigFloat
336-
r = alpha * Matrix(D) * UU + beta * UUU
337-
mul!(UUU, D, UU, alpha, beta) r UUU
343+
r = alpha * M * UU + beta * UUU
344+
@test mul!(UUU, D, UU, alpha, beta) === UUU
345+
@test r UUU
338346
end
339-
@test begin
347+
@testset begin
340348
UUU = similar(UU)
341349
UUU .= randn(size(UUU)) # randn!(UUU) does not work with BigFloat
342-
r = alpha * Matrix(D)' * UU + beta * UUU
343-
mul!(UUU, adjoint(D), UU, alpha, beta) r UUU
350+
r = alpha * M' * UU + beta * UUU
351+
@test mul!(UUU, adjoint(D), UU, alpha, beta) === UUU
352+
@test r UUU
344353
end
345-
@test begin
354+
@testset begin
346355
UUU = similar(UU)
347356
UUU .= randn(size(UUU)) # randn!(UUU) does not work with BigFloat
348-
r = alpha * transpose(Matrix(D)) * UU + beta * UUU
349-
mul!(UUU, transpose(D), UU, alpha, beta) r UUU
357+
r = alpha * transpose(M) * UU + beta * UUU
358+
@test mul!(UUU, transpose(D), UU, alpha, beta) === UUU
359+
@test r UUU
350360
end
351361

352362
# make sure that mul!(A, {Adj|Trans}(B)) works with B as a Diagonal
353363
VV = Array(D)
354-
DD = copy(D)
355-
r = VV * Matrix(D)
356-
@test Array(rmul!(VV, DD)) r Array(D)*Array(D)
357-
DD = copy(D)
358-
r = VV * transpose(Array(D))
359-
@test Array(rmul!(VV, transpose(DD))) r
360-
DD = copy(D)
361-
r = VV * Array(D)'
362-
@test Array(rmul!(VV, adjoint(DD))) r
364+
r = VV * M
365+
@test rmul!(VV, D) r M*M
366+
if transpose(D) !== D
367+
r = VV * transpose(M)
368+
@test rmul!(VV, transpose(D)) r
369+
end
370+
if adjoint(D) !== D
371+
r = VV * M'
372+
@test rmul!(VV, adjoint(D)) r
373+
end
363374

364375
# kron
365376
D3 = Diagonal(convert(Vector{elty}, rand(n÷2)))
@@ -537,16 +548,17 @@ Base.size(x::SimpleVector) = size(x.vec)
537548

538549
@testset "kron (issue #46456)" for repr in Any[identity, SimpleVector]
539550
A = Diagonal(repr(randn(10)))
551+
M = Array(A)
540552
BL = Bidiagonal(repr(randn(10)), repr(randn(9)), :L)
541553
BU = Bidiagonal(repr(randn(10)), repr(randn(9)), :U)
542554
C = SymTridiagonal(repr(randn(10)), repr(randn(9)))
543555
Cl = SymTridiagonal(repr(randn(10)), repr(randn(10)))
544556
D = Tridiagonal(repr(randn(9)), repr(randn(10)), repr(randn(9)))
545-
@test kron(A, BL)::Bidiagonal == kron(Array(A), Array(BL))
546-
@test kron(A, BU)::Bidiagonal == kron(Array(A), Array(BU))
547-
@test kron(A, C)::SymTridiagonal == kron(Array(A), Array(C))
548-
@test kron(A, Cl)::SymTridiagonal == kron(Array(A), Array(Cl))
549-
@test kron(A, D)::Tridiagonal == kron(Array(A), Array(D))
557+
@test kron(A, BL)::Bidiagonal == kron(M, Array(BL))
558+
@test kron(A, BU)::Bidiagonal == kron(M, Array(BU))
559+
@test kron(A, C)::SymTridiagonal == kron(M, Array(C))
560+
@test kron(A, Cl)::SymTridiagonal == kron(M, Array(Cl))
561+
@test kron(A, D)::Tridiagonal == kron(M, Array(D))
550562
end
551563

552564
@testset "svdvals and eigvals (#11120/#11247)" begin
@@ -619,9 +631,10 @@ end
619631

620632
@testset "Test reverse" begin
621633
D = Diagonal(randn(5))
622-
@test reverse(D, dims=1) == reverse(Matrix(D), dims=1)
623-
@test reverse(D, dims=2) == reverse(Matrix(D), dims=2)
624-
@test reverse(D)::Diagonal == reverse(Matrix(D))
634+
M = Matrix(D)
635+
@test reverse(D, dims=1) == reverse(M, dims=1)
636+
@test reverse(D, dims=2) == reverse(M, dims=2)
637+
@test reverse(D)::Diagonal == reverse(M)
625638
end
626639

627640
@testset "inverse" begin
@@ -637,8 +650,9 @@ end
637650
@testset "pseudoinverse" begin
638651
for d in Any[randn(n), zeros(n), Int[], [0, 2, 0.003], [0im, 1+2im, 0.003im], [0//1, 2//1, 3//100], [0//1, 1//1+2im, 3im//100]]
639652
D = Diagonal(d)
640-
@test pinv(D) pinv(Array(D))
641-
@test pinv(D, 1.0e-2) pinv(Array(D), 1.0e-2)
653+
M = Array(D)
654+
@test pinv(D) pinv(M)
655+
@test pinv(D, 1.0e-2) pinv(M, 1.0e-2)
642656
end
643657
end
644658

@@ -654,51 +668,54 @@ end
654668
@test Matrix(1.0I, 5, 5) \ Diagonal(fill(1.,5)) == Matrix(I, 5, 5)
655669

656670
@testset "Triangular and Diagonal" begin
657-
function _test_matrix(type)
671+
function _randomarray(type, ::Val{N} = Val(2)) where {N}
672+
sz = ntuple(_->5, N)
658673
if type == Int
659-
return rand(1:9, 5, 5)
674+
return rand(1:9, sz...)
660675
else
661-
return randn(type, 5, 5)
676+
return randn(type, sz...)
662677
end
663678
end
664679
types = (Float64, Int, ComplexF64)
665680
for ta in types
666-
D = Diagonal(_test_matrix(ta))
681+
D = Diagonal(_randomarray(ta, Val(1)))
682+
M = Matrix(D)
667683
for tb in types
668-
B = _test_matrix(tb)
684+
B = _randomarray(tb, Val(2))
669685
Tmats = (LowerTriangular(B), UnitLowerTriangular(B), UpperTriangular(B), UnitUpperTriangular(B))
670686
restypes = (LowerTriangular, LowerTriangular, UpperTriangular, UpperTriangular)
671687
for (T, rtype) in zip(Tmats, restypes)
672688
adjtype = (rtype == LowerTriangular) ? UpperTriangular : LowerTriangular
673689

674690
# Triangular * Diagonal
675691
R = T * D
676-
@test R Array(T) * Array(D)
692+
TA = Array(T)
693+
@test R TA * M
677694
@test isa(R, rtype)
678695

679696
# Diagonal * Triangular
680697
R = D * T
681-
@test R Array(D) * Array(T)
698+
@test R M * TA
682699
@test isa(R, rtype)
683700

684701
# Adjoint of Triangular * Diagonal
685702
R = T' * D
686-
@test R Array(T)' * Array(D)
703+
@test R TA' * M
687704
@test isa(R, adjtype)
688705

689706
# Diagonal * Adjoint of Triangular
690707
R = D * T'
691-
@test R Array(D) * Array(T)'
708+
@test R M * TA'
692709
@test isa(R, adjtype)
693710

694711
# Transpose of Triangular * Diagonal
695712
R = transpose(T) * D
696-
@test R transpose(Array(T)) * Array(D)
713+
@test R transpose(TA) * M
697714
@test isa(R, adjtype)
698715

699716
# Diagonal * Transpose of Triangular
700717
R = D * transpose(T)
701-
@test R Array(D) * transpose(Array(T))
718+
@test R M * transpose(TA)
702719
@test isa(R, adjtype)
703720
end
704721
end
@@ -1325,7 +1342,7 @@ end
13251342
end
13261343

13271344
@testset "diagonal triple multiplication (#49005)" begin
1328-
n = 10
1345+
local n = 10
13291346
@test *(Diagonal(ones(n)), Diagonal(1:n), Diagonal(ones(n))) isa Diagonal
13301347
@test_throws DimensionMismatch (*(Diagonal(ones(n)), Diagonal(1:n), Diagonal(ones(n+1))))
13311348
@test_throws DimensionMismatch (*(Diagonal(ones(n)), Diagonal(1:n+1), Diagonal(ones(n+1))))
@@ -1441,10 +1458,12 @@ end
14411458
for p in ([1 2; 3 4], [1 2+im; 2-im 4+2im])
14421459
m = SizedArrays.SizedArray{(2,2)}(p)
14431460
D = Diagonal(fill(m, 2))
1461+
M = Matrix(D)
14441462
for T in (Symmetric, Hermitian)
14451463
S = T(fill(m, 2, 2))
1446-
@test D + S == Array(D) + Array(S)
1447-
@test S + D == Array(S) + Array(D)
1464+
SA = Array(S)
1465+
@test D + S == M + SA
1466+
@test S + D == SA + M
14481467
end
14491468
end
14501469
end
@@ -1456,12 +1475,14 @@ end
14561475

14571476
@testset "zeros in kron with block matrices" begin
14581477
D = Diagonal(1:4)
1478+
M = Matrix(D)
14591479
B = reshape([ones(2,2), ones(3,2), ones(2,3), ones(3,3)], 2, 2)
1460-
@test kron(D, B) == kron(Array(D), B)
1461-
@test kron(B, D) == kron(B, Array(D))
1480+
@test kron(D, B) == kron(M, B)
1481+
@test kron(B, D) == kron(B, M)
14621482
D2 = Diagonal([ones(2,2), ones(3,3)])
1463-
@test kron(D, D2) == kron(D, Array{eltype(D2)}(D2))
1464-
@test kron(D2, D) == kron(Array{eltype(D2)}(D2), D)
1483+
M2 = Array{eltype(D2)}(D2)
1484+
@test kron(D, D2) == kron(D, M2)
1485+
@test kron(D2, D) == kron(M2, D)
14651486
end
14661487

14671488
@testset "opnorms" begin

0 commit comments

Comments
 (0)