Skip to content

Commit 2e428a5

Browse files
authored
Reduce allocations in triangular tests (#1309)
We may reuse some variables in some cases, and avoid allocating at other places.
1 parent fd115f4 commit 2e428a5

File tree

1 file changed

+112
-94
lines changed

1 file changed

+112
-94
lines changed

test/testtriag.jl

Lines changed: 112 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,16 @@ function test_triangular(elty1_types)
3131
@test full!(copy(A1)) == A1
3232

3333
# similar
34-
@test isa(similar(A1), t1)
35-
@test eltype(similar(A1)) == elty1
36-
@test isa(similar(A1, Int), t1)
37-
@test eltype(similar(A1, Int)) == Int
34+
simA1 = similar(A1)
35+
@test isa(simA1, t1)
36+
@test eltype(simA1) == elty1
37+
simA1Int = similar(A1, Int)
38+
@test isa(simA1Int, t1)
39+
@test eltype(simA1Int) == Int
3840
@test isa(similar(A1, (3, 2)), Matrix{elty1})
3941
@test isa(similar(A1, Int, (3, 2)), Matrix{Int})
4042

4143
#copyto!
42-
simA1 = similar(A1)
4344
copyto!(simA1, A1)
4445
@test simA1 == A1
4546

@@ -119,49 +120,60 @@ function test_triangular(elty1_types)
119120
@test tril(A1, 0) == A1
120121
@test tril(A1, -1) == LowerTriangular(tril(M1, -1))
121122
@test tril(A1, 1) == t1(tril(tril(M1, 1)))
122-
@test tril(A1, -n - 2) == zeros(size(A1))
123+
A1tril = tril(A1, -n - 2)
124+
@test iszero(A1tril) && size(A1tril) == size(A1)
123125
@test tril(A1, n) == A1
124-
@test triu(A1, 0) == t1(diagm(0 => diag(A1)))
126+
@test triu(A1, 0) == t1(Diagonal(diagview(A1)))
125127
@test triu(A1, -1) == t1(tril(triu(A1.data, -1)))
126-
@test triu(A1, 1) == zeros(size(A1)) # or just @test iszero(triu(A1,1))?
128+
A1triu = triu(A1, 1)
129+
@test iszero(A1triu) && size(A1triu) == size(A1)
127130
@test triu(A1, -n) == A1
128-
@test triu(A1, n + 2) == zeros(size(A1))
131+
A1triu = triu(A1, n + 2)
132+
@test iszero(A1triu) && size(A1triu) == size(A1)
129133
else
130134
@test triu(A1, 0) == A1
131135
@test triu(A1, 1) == UpperTriangular(triu(M1, 1))
132136
@test triu(A1, -1) == t1(triu(triu(M1, -1)))
133137
@test triu(A1, -n) == A1
134-
@test triu(A1, n + 2) == zeros(size(A1))
135-
@test tril(A1, 0) == t1(diagm(0 => diag(A1)))
138+
A1triu = triu(A1, n + 2)
139+
@test iszero(A1triu) && size(A1triu) == size(A1)
140+
@test tril(A1, 0) == t1(Diagonal(diagview(A1)))
136141
@test tril(A1, 1) == t1(triu(tril(A1.data, 1)))
137-
@test tril(A1, -1) == zeros(size(A1)) # or just @test iszero(tril(A1,-1))?
138-
@test tril(A1, -n - 2) == zeros(size(A1))
142+
A1tril = tril(A1, -1)
143+
@test iszero(A1tril) && size(A1tril) == size(A1)
144+
@test iszero(A1tril) && size(A1tril) == size(A1)
139145
@test tril(A1, n) == A1
140146
end
141147

142148
# factorize
143149
@test factorize(A1) == A1
144150

145151
# [c]transpose[!] (test views as well, see issue #14317)
152+
# transpose
153+
@test copy(transpose(A1)) == transpose(M1)
154+
A1ctr = transpose!(copy(A1))
155+
@test A1ctr == transpose(A1)
156+
@test typeof(A1ctr).name == typeof(transpose(A1)).name
157+
# adjoint
158+
@test copy(A1') == M1'
159+
A1cadj = adjoint!(copy(A1))
160+
@test A1cadj == adjoint(A1)
161+
@test typeof(A1cadj).name == typeof(adjoint(A1)).name
162+
146163
let vrange = 1:n-1, viewA1 = t1(view(A1.data, vrange, vrange))
164+
MviewA1 = Matrix(viewA1)
147165
# transpose
148-
@test copy(transpose(A1)) == transpose(M1)
149-
@test copy(transpose(viewA1)) == transpose(Matrix(viewA1))
166+
@test copy(transpose(viewA1)) == transpose(MviewA1)
150167
# adjoint
151-
@test copy(A1') == M1'
152-
@test copy(viewA1') == Matrix(viewA1)'
168+
@test copy(viewA1') == MviewA1'
153169
# transpose!
154-
@test transpose!(copy(A1)) == transpose(A1)
155-
@test typeof(transpose!(copy(A1))).name == typeof(transpose(A1)).name
156170
@test transpose!(t1(view(copy(A1).data, vrange, vrange))) == transpose(viewA1)
157171
# adjoint!
158-
@test adjoint!(copy(A1)) == adjoint(A1)
159-
@test typeof(adjoint!(copy(A1))).name == typeof(adjoint(A1)).name
160172
@test adjoint!(t1(view(copy(A1).data, vrange, vrange))) == adjoint(viewA1)
161173
end
162174

163175
# diag
164-
@test diag(A1) == diag(M1)
176+
@test diag(A1) == diagview(M1)
165177

166178
# tr
167179
@test tr(A1)::elty1 == tr(M1)
@@ -173,29 +185,32 @@ function test_triangular(elty1_types)
173185

174186
# zero
175187
if A1 isa UpperTriangular || A1 isa LowerTriangular
176-
@test zero(A1) == zero(parent(A1))
188+
Z = zero(A1)
189+
@test iszero(Z) && size(Z) == size(A1)
177190
end
178191

179192
# Unary operations
180193
@test -A1 == -M1
181194

182195
# copy and copyto! (test views as well, see issue #14317)
196+
@test copy(A1) == M1
197+
A1trc = copy(transpose(A1))
198+
B = similar(A1trc)
199+
copyto!(B, A1trc)
200+
@test B == A1trc
201+
B = similar(A1)
202+
copyto!(B, A1)
203+
@test B == A1
183204
let vrange = 1:n-1, viewA1 = t1(view(A1.data, vrange, vrange))
184205
# copy
185-
@test copy(A1) == copy(M1)
186-
@test copy(viewA1) == copy(Matrix(viewA1))
206+
@test copy(viewA1) == Matrix(viewA1)
187207
# copyto!
188-
B = similar(A1)
189-
copyto!(B, A1)
190-
@test B == A1
191-
B = similar(copy(transpose(A1)))
192-
copyto!(B, copy(transpose(A1)))
193-
@test B == copy(transpose(A1))
194208
B = similar(viewA1)
195209
copyto!(B, viewA1)
196210
@test B == viewA1
197-
B = similar(copy(transpose(viewA1)))
198-
copyto!(B, copy(transpose(viewA1)))
211+
viewA1trc = copy(transpose(viewA1))
212+
B = similar(viewA1trc)
213+
copyto!(B, viewA1trc)
199214
@test B == transpose(viewA1)
200215
end
201216

@@ -217,15 +232,12 @@ function test_triangular(elty1_types)
217232
A1tmp = copy(A1)
218233
rmul!(A1tmp, cr)
219234
@test A1tmp == cr * A1
220-
A1tmp = copy(A1)
235+
A1tmp .= A1
221236
lmul!(cr, A1tmp)
222237
@test A1tmp == cr * A1
223-
A1tmp = copy(A1)
224238
A2tmp = unitt(A1)
225239
mul!(A1tmp, A2tmp, cr)
226240
@test A1tmp == cr * A2tmp
227-
A1tmp = copy(A1)
228-
A2tmp = unitt(A1)
229241
mul!(A1tmp, cr, A2tmp)
230242
@test A1tmp == cr * A2tmp
231243

@@ -237,15 +249,12 @@ function test_triangular(elty1_types)
237249
A1tmp = copy(A1)
238250
rmul!(A1tmp, ci)
239251
@test A1tmp == ci * A1
240-
A1tmp = copy(A1)
252+
A1tmp .= A1
241253
lmul!(ci, A1tmp)
242254
@test A1tmp == ci * A1
243-
A1tmp = copy(A1)
244255
A2tmp = unitt(A1)
245256
mul!(A1tmp, ci, A2tmp)
246257
@test A1tmp == ci * A2tmp
247-
A1tmp = copy(A1)
248-
A2tmp = unitt(A1)
249258
mul!(A1tmp, A2tmp, ci)
250259
@test A1tmp == A2tmp * ci
251260
end
@@ -265,19 +274,22 @@ function test_triangular(elty1_types)
265274
@test 0.5 \ A1 == 0.5 \ M1
266275

267276
# inversion
268-
@test inv(A1) inv(lu(M1))
269-
inv(M1) # issue #11298
270-
@test isa(inv(A1), t1)
277+
invA1 = inv(A1)
278+
M1lu = lu(M1)
279+
@test invA1 inv(M1lu)
280+
@test invA1 inv(M1) # issue #11298
281+
@test isa(invA1, t1)
271282
# make sure the call to LAPACK works right
272283
if elty1 <: BlasFloat
273-
@test LinearAlgebra.inv!(copy(A1)) inv(lu(M1))
284+
@test LinearAlgebra.inv!(copy(A1)) inv(M1lu)
274285
end
275286

276287
# Determinant
277-
@test det(A1) det(lu(M1)) atol = sqrt(eps(real(float(one(elty1))))) * n * n
278-
@test logdet(A1) logdet(lu(M1)) atol = sqrt(eps(real(float(one(elty1))))) * n * n
288+
M1lu = lu(M1lu)
289+
@test det(A1) det(M1lu) atol = sqrt(eps(real(float(one(elty1))))) * n * n
290+
@test logdet(A1) logdet(M1lu) atol = sqrt(eps(real(float(one(elty1))))) * n * n
279291
lada, ladb = logabsdet(A1)
280-
flada, fladb = logabsdet(lu(M1))
292+
flada, fladb = logabsdet(M1lu)
281293
@test lada flada atol = sqrt(eps(real(float(one(elty1))))) * n * n
282294
@test ladb fladb atol = sqrt(eps(real(float(one(elty1))))) * n * n
283295

@@ -340,7 +352,10 @@ function test_triangular(elty1_types)
340352
@test kron(A1, A2) == kron(M1, M2)
341353

342354
# Triangular-Triangular multiplication and division
343-
@test A1 * A2 M1 * M2
355+
A1_mul_A2 = A1 * A2
356+
A1_rdiv_A2 = A1 / A2
357+
A1_ldiv_A2 = A1 \ A2
358+
@test A1_mul_A2 M1 * M2
344359
@test transpose(A1) * A2 transpose(M1) * M2
345360
@test transpose(A1) * adjoint(A2) transpose(M1) * adjoint(M2)
346361
@test adjoint(A1) * transpose(A2) adjoint(M1) * transpose(M2)
@@ -349,35 +364,35 @@ function test_triangular(elty1_types)
349364
@test A1 * A2' M1 * M2'
350365
@test transpose(A1) * transpose(A2) transpose(M1) * transpose(M2)
351366
@test A1'A2' M1'M2'
352-
@test A1 / A2 M1 / M2
353-
@test A1 \ A2 M1 \ M2
367+
@test A1_rdiv_A2 M1 / M2
368+
@test A1_ldiv_A2 M1 \ M2
354369
if uplo1 === :U && uplo2 === :U
355370
if t1 === UnitUpperTriangular && t2 === UnitUpperTriangular
356-
@test A1 * A2 isa UnitUpperTriangular
357-
@test A1 / A2 isa UnitUpperTriangular
358-
elty1 == Int && elty2 == Int && @test eltype(A1 / A2) == Int
359-
@test A1 \ A2 isa UnitUpperTriangular
360-
elty1 == Int && elty2 == Int && @test eltype(A1 \ A2) == Int
371+
@test A1_mul_A2 isa UnitUpperTriangular
372+
@test A1_rdiv_A2 isa UnitUpperTriangular
373+
elty1 == Int && elty2 == Int && @test eltype(A1_rdiv_A2) == Int
374+
@test A1_ldiv_A2 isa UnitUpperTriangular
375+
elty1 == Int && elty2 == Int && @test eltype(A1_ldiv_A2) == Int
361376
else
362-
@test A1 * A2 isa UpperTriangular
363-
@test A1 / A2 isa UpperTriangular
364-
elty1 == Int && elty2 == Int && t2 === UnitUpperTriangular && @test eltype(A1 / A2) == Int
365-
@test A1 \ A2 isa UpperTriangular
366-
elty1 == Int && elty2 == Int && t1 === UnitUpperTriangular && @test eltype(A1 \ A2) == Int
377+
@test A1_mul_A2 isa UpperTriangular
378+
@test A1_rdiv_A2 isa UpperTriangular
379+
elty1 == Int && elty2 == Int && t2 === UnitUpperTriangular && @test eltype(A1_rdiv_A2) == Int
380+
@test A1_ldiv_A2 isa UpperTriangular
381+
elty1 == Int && elty2 == Int && t1 === UnitUpperTriangular && @test eltype(A1_ldiv_A2) == Int
367382
end
368383
elseif uplo1 === :L && uplo2 === :L
369384
if t1 === UnitLowerTriangular && t2 === UnitLowerTriangular
370-
@test A1 * A2 isa UnitLowerTriangular
371-
@test A1 / A2 isa UnitLowerTriangular
372-
elty1 == Int && elty2 == Int && @test eltype(A1 / A2) == Int
373-
@test A1 \ A2 isa UnitLowerTriangular
374-
elty1 == Int && elty2 == Int && @test eltype(A1 \ A2) == Int
385+
@test A1_mul_A2 isa UnitLowerTriangular
386+
@test A1_rdiv_A2 isa UnitLowerTriangular
387+
elty1 == Int && elty2 == Int && @test eltype(A1_rdiv_A2) == Int
388+
@test A1_ldiv_A2 isa UnitLowerTriangular
389+
elty1 == Int && elty2 == Int && @test eltype(A1_ldiv_A2) == Int
375390
else
376-
@test A1 * A2 isa LowerTriangular
377-
@test A1 / A2 isa LowerTriangular
378-
elty1 == Int && elty2 == Int && t2 === UnitLowerTriangular && @test eltype(A1 / A2) == Int
379-
@test A1 \ A2 isa LowerTriangular
380-
elty1 == Int && elty2 == Int && t1 === UnitLowerTriangular && @test eltype(A1 \ A2) == Int
391+
@test A1_mul_A2 isa LowerTriangular
392+
@test A1_rdiv_A2 isa LowerTriangular
393+
elty1 == Int && elty2 == Int && t2 === UnitLowerTriangular && @test eltype(A1_rdiv_A2) == Int
394+
@test A1_ldiv_A2 isa LowerTriangular
395+
elty1 == Int && elty2 == Int && t1 === UnitLowerTriangular && @test eltype(A1_ldiv_A2) == Int
381396
end
382397
end
383398
offsizeA = Matrix{Float64}(I, n + 1, n + 1)
@@ -426,46 +441,49 @@ function test_triangular(elty1_types)
426441
mul!(C, A1, Tri)
427442
@test C M1 * Tri
428443

444+
bcol1 = B[:, 1]
445+
429446
# Triangular-dense Matrix/vector multiplication
430-
@test A1 * B[:, 1] M1 * B[:, 1]
447+
@test A1 * bcol1 M1 * bcol1
431448
@test A1 * B M1 * B
432-
@test transpose(A1) * B[:, 1] transpose(M1) * B[:, 1]
433-
@test A1'B[:, 1] M1'B[:, 1]
449+
@test transpose(A1) * bcol1 transpose(M1) * bcol1
450+
@test A1'bcol1 M1'bcol1
434451
@test transpose(A1) * B transpose(M1) * B
435452
@test A1'B M1'B
436453
@test A1 * transpose(B) M1 * transpose(B)
437454
@test adjoint(A1) * transpose(B) M1' * transpose(B)
438455
@test transpose(A1) * adjoint(B) transpose(M1) * adjoint(B)
439456
@test A1 * B' M1 * B'
440457
@test B * A1 B * M1
441-
@test transpose(B[:, 1]) * A1 transpose(B[:, 1]) * M1
442-
@test B[:, 1]'A1 B[:, 1]'M1
458+
@test transpose(bcol1) * A1 transpose(bcol1) * M1
459+
@test bcol1'A1 bcol1'M1
443460
@test transpose(B) * A1 transpose(B) * M1
444461
@test transpose(B) * adjoint(A1) transpose(B) * M1'
445462
@test adjoint(B) * transpose(A1) adjoint(B) * transpose(M1)
446463
@test B'A1 B'M1
447464
@test B * transpose(A1) B * transpose(M1)
448465
@test B * A1' B * M1'
449-
@test transpose(B[:, 1]) * transpose(A1) transpose(B[:, 1]) * transpose(M1)
450-
@test B[:, 1]'A1' B[:, 1]'M1'
466+
@test transpose(bcol1) * transpose(A1) transpose(bcol1) * transpose(M1)
467+
@test bcol1'A1' bcol1'M1'
451468
@test transpose(B) * transpose(A1) transpose(B) * transpose(M1)
452469
@test B'A1' B'M1'
453470

454471
if eltyB == elty1
455-
@test mul!(similar(B), A1, B) M1 * B
456-
@test mul!(similar(B), A1, adjoint(B)) M1 * B'
457-
@test mul!(similar(B), A1, transpose(B)) M1 * transpose(B)
458-
@test mul!(similar(B), adjoint(A1), adjoint(B)) M1' * B'
459-
@test mul!(similar(B), transpose(A1), transpose(B)) transpose(M1) * transpose(B)
460-
@test mul!(similar(B), transpose(A1), adjoint(B)) transpose(M1) * B'
461-
@test mul!(similar(B), adjoint(A1), transpose(B)) M1' * transpose(B)
462-
@test mul!(similar(B), adjoint(A1), B) M1' * B
463-
@test mul!(similar(B), transpose(A1), B) transpose(M1) * B
472+
Bsim = similar(B)
473+
@test mul!(Bsim, A1, B) M1 * B
474+
@test mul!(Bsim, A1, adjoint(B)) M1 * B'
475+
@test mul!(Bsim, A1, transpose(B)) M1 * transpose(B)
476+
@test mul!(Bsim, adjoint(A1), adjoint(B)) M1' * B'
477+
@test mul!(Bsim, transpose(A1), transpose(B)) transpose(M1) * transpose(B)
478+
@test mul!(Bsim, transpose(A1), adjoint(B)) transpose(M1) * B'
479+
@test mul!(Bsim, adjoint(A1), transpose(B)) M1' * transpose(B)
480+
@test mul!(Bsim, adjoint(A1), B) M1' * B
481+
@test mul!(Bsim, transpose(A1), B) transpose(M1) * B
464482
# test also vector methods
465-
B1 = vec(B[1, :])
466-
@test mul!(similar(B1), A1, B1) M1 * B1
467-
@test mul!(similar(B1), adjoint(A1), B1) M1' * B1
468-
@test mul!(similar(B1), transpose(A1), B1) transpose(M1) * B1
483+
bcol1sim = similar(bcol1)
484+
@test mul!(bcol1sim, A1, bcol1) M1 * bcol1
485+
@test mul!(bcol1sim, adjoint(A1), bcol1) M1' * bcol1
486+
@test mul!(bcol1sim, transpose(A1), bcol1) transpose(M1) * bcol1
469487
end
470488
#error handling
471489
Ann, Bmm, bm = A1, Matrix{eltyB}(undef, n + 1, n + 1), Vector{eltyB}(undef, n + 1)
@@ -477,10 +495,10 @@ function test_triangular(elty1_types)
477495
@test_throws DimensionMismatch rmul!(Bmm, transpose(Ann))
478496

479497
# ... and division
480-
@test A1 \ B[:, 1] M1 \ B[:, 1]
498+
@test A1 \ bcol1 M1 \ bcol1
481499
@test A1 \ B M1 \ B
482-
@test transpose(A1) \ B[:, 1] transpose(M1) \ B[:, 1]
483-
@test A1' \ B[:, 1] M1' \ B[:, 1]
500+
@test transpose(A1) \ bcol1 transpose(M1) \ bcol1
501+
@test A1' \ bcol1 M1' \ bcol1
484502
@test transpose(A1) \ B transpose(M1) \ B
485503
@test A1' \ B M1' \ B
486504
@test A1 \ transpose(B) M1 \ transpose(B)

0 commit comments

Comments
 (0)