Skip to content

Commit 57785c7

Browse files
authored
More resizing for truncating return values from LAPACK (#1190)
1 parent b464203 commit 57785c7

File tree

2 files changed

+40
-20
lines changed

2 files changed

+40
-20
lines changed

src/lapack.jl

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2835,7 +2835,7 @@ for (orglq, orgqr, orgql, orgrq, ormlq, ormqr, ormql, ormrq, gemqrt, elty) in
28352835
end
28362836
end
28372837
if n < size(A,2)
2838-
A[:,1:n]
2838+
reshape(resize!(vec(A), m * n), m, n)
28392839
else
28402840
A
28412841
end
@@ -2871,7 +2871,7 @@ for (orglq, orgqr, orgql, orgrq, ormlq, ormqr, ormql, ormrq, gemqrt, elty) in
28712871
end
28722872
end
28732873
if n < size(A,2)
2874-
A[:,1:n]
2874+
reshape(resize!(vec(A), m * n), m, n)
28752875
else
28762876
A
28772877
end
@@ -3736,22 +3736,24 @@ for (trcon, trevc, trrfs, elty) in
37363736
work, info, 1, 1)
37373737
chklapackerror(info[])
37383738

3739+
VLn = size(VL, 1)
3740+
VRn = size(VR, 1)
37393741
#Decide what exactly to return
37403742
if howmny == 'S' #compute selected eigenvectors
37413743
if side == 'L' #left eigenvectors only
3742-
return select, VL[:,1:m[]]
3744+
return select, reshape(resize!(vec(VL), VLn * m[]), VLn, m[])
37433745
elseif side == 'R' #right eigenvectors only
3744-
return select, VR[:,1:m[]]
3746+
return select, reshape(resize!(vec(VR), VRn * m[]), VRn, m[])
37453747
else #side == 'B' #both eigenvectors
3746-
return select, VL[:,1:m[]], VR[:,1:m[]]
3748+
return select, reshape(resize!(vec(VL), VLn * m[]), VLn, m[]), reshape(resize!(vec(VR), VRn * m[]), VRn, m[])
37473749
end
37483750
else #compute all eigenvectors
37493751
if side == 'L' #left eigenvectors only
3750-
return VL[:,1:m[]]
3752+
return reshape(resize!(vec(VL), VLn * m[]), VLn, m[])
37513753
elseif side == 'R' #right eigenvectors only
3752-
return VR[:,1:m[]]
3754+
return reshape(resize!(vec(VR), VRn * m[]), VRn, m[])
37533755
else #side == 'B' #both eigenvectors
3754-
return VL[:,1:m[]], VR[:,1:m[]]
3756+
return reshape(resize!(vec(VL), VLn * m[]), VLn, m[]), reshape(resize!(vec(VR), VRn * m[]), VRn, m[])
37553757
end
37563758
end
37573759
end
@@ -3873,22 +3875,24 @@ for (trcon, trevc, trrfs, elty, relty) in
38733875
work, rwork, info, 1, 1)
38743876
chklapackerror(info[])
38753877

3878+
VLn = size(VL, 1)
3879+
VRn = size(VR, 1)
38763880
#Decide what exactly to return
38773881
if howmny == 'S' #compute selected eigenvectors
38783882
if side == 'L' #left eigenvectors only
3879-
return select, VL[:,1:m[]]
3883+
return select, reshape(resize!(vec(VL), VLn * m[]), VLn, m[])
38803884
elseif side == 'R' #right eigenvectors only
3881-
return select, VR[:,1:m[]]
3882-
else #side=='B' #both eigenvectors
3883-
return select, VL[:,1:m[]], VR[:,1:m[]]
3885+
return select, reshape(resize!(vec(VR), VRn * m[]), VRn, m[])
3886+
else #side == 'B' #both eigenvectors
3887+
return select, reshape(resize!(vec(VL), VLn * m[]), VLn, m[]), reshape(resize!(vec(VR), VRn * m[]), VRn, m[])
38843888
end
38853889
else #compute all eigenvectors
38863890
if side == 'L' #left eigenvectors only
3887-
return VL[:,1:m[]]
3891+
return reshape(resize!(vec(VL), VLn * m[]), VLn, m[])
38883892
elseif side == 'R' #right eigenvectors only
3889-
return VR[:,1:m[]]
3890-
else #side=='B' #both eigenvectors
3891-
return VL[:,1:m[]], VR[:,1:m[]]
3893+
return reshape(resize!(vec(VR), VRn * m[]), VRn, m[])
3894+
else #side == 'B' #both eigenvectors
3895+
return reshape(resize!(vec(VL), VLn * m[]), VLn, m[]), reshape(resize!(vec(VR), VRn * m[]), VRn, m[])
38923896
end
38933897
end
38943898
end
@@ -4033,7 +4037,7 @@ for (stev, stebz, stegr, stein, elty) in
40334037
w, iblock, isplit, work,
40344038
iwork, info, 1, 1)
40354039
chklapackerror(info[])
4036-
w[1:m[]], iblock[1:m[]], isplit[1:nsplit[1]]
4040+
resize!(w, m[]), resize!(iblock, m[]), resize!(isplit, nsplit[1])
40374041
end
40384042

40394043
function stegr!(jobz::AbstractChar, range::AbstractChar, dv::AbstractVector{$elty}, ev::AbstractVector{$elty}, vl::Real, vu::Real, il::Integer, iu::Integer)
@@ -4056,8 +4060,9 @@ for (stev, stebz, stegr, stein, elty) in
40564060
m = Ref{BlasInt}()
40574061
w = similar(dv, $elty, n)
40584062
ldz = jobz == 'N' ? 1 : n
4059-
Z = similar(dv, $elty, ldz, range == 'I' ? iu-il+1 : n)
4060-
isuppz = similar(dv, BlasInt, 2*size(Z, 2))
4063+
Zn = range == 'I' ? iu-il+1 : n
4064+
Z = similar(dv, $elty, ldz * Zn)
4065+
isuppz = similar(dv, BlasInt, 2 * Zn)
40614066
work = Vector{$elty}(undef, 1)
40624067
lwork = BlasInt(-1)
40634068
iwork = Vector{BlasInt}(undef, 1)
@@ -4085,7 +4090,7 @@ for (stev, stebz, stegr, stein, elty) in
40854090
resize!(iwork, liwork)
40864091
end
40874092
end
4088-
m[] == length(w) ? w : w[1:m[]], m[] == size(Z, 2) ? Z : Z[:,1:m[]]
4093+
return resize!(w, m[]), reshape(resize!(Z, ldz * m[]), ldz, m[])
40894094
end
40904095

40914096
function stein!(dv::AbstractVector{$elty}, ev_in::AbstractVector{$elty}, w_in::AbstractVector{$elty}, iblock_in::AbstractVector{BlasInt}, isplit_in::AbstractVector{BlasInt})

test/lapack.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,11 @@ end
354354
@test_throws DimensionMismatch LAPACK.ormqr!('L','N',A,temp,B)
355355
@test_throws ArgumentError LAPACK.ormqr!('X','N',A,temp,B)
356356
@test_throws ArgumentError LAPACK.ormqr!('L','X',A,temp,B)
357+
358+
A = rand(elty,10,11)
359+
A,tau = LAPACK.geqrf!(A)
360+
B = copy(A)
361+
@test LAPACK.orgqr!(B,tau) LAPACK.ormqr!('R','N',A,tau,Matrix{elty}(I, 10, 10))
357362

358363
A = rand(elty,10,10)
359364
A,tau = LAPACK.geqlf!(A)
@@ -372,6 +377,11 @@ end
372377
@test_throws ArgumentError LAPACK.ormql!('X','N',A,temp,B)
373378
@test_throws ArgumentError LAPACK.ormql!('L','X',A,temp,B)
374379

380+
A = rand(elty,10,11)
381+
A,tau = LAPACK.geqlf!(A)
382+
B = copy(A)
383+
@test LAPACK.orgql!(B,tau) LAPACK.ormql!('R','N',A,tau,Matrix{elty}(I, 10, 10))
384+
375385
A = rand(elty,10,10)
376386
A,tau = LAPACK.gerqf!(A)
377387
@test_throws DimensionMismatch LAPACK.orgrq!(A,tau,11)
@@ -733,6 +743,11 @@ end
733743
select,Vln,Vrn = LAPACK.trevc!('B','S',select,copy(T))
734744
@test Vrn v
735745
@test Vln Vl
746+
Vl = LAPACK.trevc!('L','A',select,copy(T))
747+
Vr = LAPACK.trevc!('R','A',select,copy(T))
748+
Vla, Vra = LAPACK.trevc!('B','A',select,copy(T))
749+
@test Vr Vra
750+
@test Vl Vla
736751
@test_throws ArgumentError LAPACK.trevc!('V','S',select,T)
737752
@test_throws ArgumentError LAPACK.trevc!('R','X',select,T)
738753
temp1010 = rand(elty,10,10)

0 commit comments

Comments
 (0)