Skip to content

Commit 762f3e0

Browse files
Adding tests for inplace qr of views
1 parent 8941d0a commit 762f3e0

File tree

1 file changed

+138
-0
lines changed

1 file changed

+138
-0
lines changed

test/cusolver/dense.jl

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,20 @@ l_range = (1:l) .+ (l_sub_start -1)
483483
@test Array(d_q) Array(q)
484484
@test Array(d_r) Array(r)
485485

486+
487+
A = rand(elty, n, m)
488+
d_A = CuArray(A)
489+
d_q, d_r = qr!(d_A)
490+
q, r = qr!(A)
491+
@test collect(d_q) Array(q)
492+
@test collect(d_r) Array(r)
493+
A_view = view(A, m_subrange, n_subrange)
494+
d_A_view = view(d_A, m_subrange, n_subrange)
495+
d_q, d_r = qr!(d_A_view)
496+
q, r = qr!(A_view)
497+
@test collect(d_q) Array(q)
498+
@test collect(d_r) Array(r)
499+
486500
A = rand(elty, n) # A and B are vectors
487501
d_A = CuArray(A)
488502
M = qr(A)
@@ -503,6 +517,26 @@ l_range = (1:l) .+ (l_sub_start -1)
503517
d_B = view(d_B_large, n_range)
504518
@test collect(d_M \ d_B) M \ B
505519

520+
A = rand(elty, n) # A and B are vectors
521+
d_A = CuArray(A)
522+
M = qr!(A)
523+
d_M = qr!(d_A)
524+
B = rand(elty, n)
525+
d_B = CuArray(B)
526+
@test collect(d_M \ d_B) M \ B
527+
A_view = view(A, n_subrange)
528+
d_A_view = view(d_A, n_subrange)
529+
M_view = qr!(A_view)
530+
d_M_view = qr!(d_A_view)
531+
B_view = view(B, n_subrange)
532+
d_B_view = view(d_B, n_subrange)
533+
@test collect(d_M_view \ d_B_view) M_view \ B_view
534+
B_large = rand(elty, n_large)
535+
B = view(B_large, n_range)
536+
d_B_large = CuArray(B_large)
537+
d_B = view(d_B_large, n_range)
538+
@test collect(d_M \ d_B) M \ B
539+
506540
A = rand(elty, m, n) # A is a matrix and B,C is a vector
507541
d_A = CuArray(A)
508542
M = qr(A)
@@ -555,6 +589,58 @@ l_range = (1:l) .+ (l_sub_start -1)
555589
@test collect(d_C' * d_M.R) (C' * M.R)
556590
@test collect(d_C' * d_M.R') (C' * M.R')
557591

592+
A = rand(elty, m, n) # A is a matrix and B,C is a vector
593+
d_A = CuArray(A)
594+
M = qr!(A)
595+
d_M = qr!(d_A)
596+
B = rand(elty, m)
597+
d_B = CuArray(B)
598+
C = rand(elty, n)
599+
d_C = CuArray(C)
600+
@test collect(d_M \ d_B) M \ B
601+
@test collect(d_M.Q * d_B) (M.Q * B)
602+
@test collect(d_M.Q' * d_B) (M.Q' * B)
603+
@test collect(d_B' * d_M.Q) (B' * M.Q)
604+
@test collect(d_B' * d_M.Q') (B' * M.Q')
605+
@test collect(d_M.R * d_C) (M.R * C)
606+
@test collect(d_M.R' * d_C) (M.R' * C)
607+
@test collect(d_C' * d_M.R) (C' * M.R)
608+
@test collect(d_C' * d_M.R') (C' * M.R')
609+
A_view = view(A, m_subrange, n_subrange)
610+
d_A_view = view(d_A, m_subrange, n_subrange)
611+
M_view = qr!(A_view)
612+
d_M_view = qr!(d_A_view)
613+
B_view = view(B, m_subrange)
614+
d_B_view = view(d_B, m_subrange)
615+
C_view = view(C, n_subrange)
616+
d_C_view = view(d_C, n_subrange)
617+
@test collect(d_M_view \ d_B_view) M_view \ B_view
618+
@test collect(d_M_view.Q * d_B_view) (M_view.Q * B_view)
619+
@test collect(d_M_view.Q' * d_B_view) (M_view.Q' * B_view)
620+
@test collect(d_B_view' * d_M_view.Q) (B_view' * M_view.Q)
621+
@test collect(d_B_view' * d_M_view.Q') (B_view' * M_view.Q')
622+
@test collect(d_M_view.R * d_C_view) (M_view.R * C_view)
623+
@test collect(d_M_view.R' * d_C_view) (M_view.R' * C_view)
624+
@test collect(d_C_view' * d_M_view.R) (C_view' * M_view.R)
625+
@test collect(d_C_view' * d_M_view.R') (C_view' * M_view.R')
626+
B_large = rand(elty, m_large)
627+
B = view(B_large, m_range)
628+
d_B_large = CuArray(B_large)
629+
d_B = view(d_B_large, m_range)
630+
C_large = rand(elty, n_large)
631+
C = view(C_large, n_range)
632+
d_C_large = CuArray(C_large)
633+
d_C = view(d_C_large, n_range)
634+
@test collect(d_M \ d_B) M \ B
635+
@test collect(d_M.Q * d_B) (M.Q * B)
636+
@test collect(d_M.Q' * d_B) (M.Q' * B)
637+
@test collect(d_B' * d_M.Q) (B' * M.Q)
638+
@test collect(d_B' * d_M.Q') (B' * M.Q')
639+
@test collect(d_M.R * d_C) (M.R * C)
640+
@test collect(d_M.R' * d_C) (M.R' * C)
641+
@test collect(d_C' * d_M.R) (C' * M.R)
642+
@test collect(d_C' * d_M.R') (C' * M.R')
643+
558644
A = rand(elty, m, n) # A and B,C are matrices
559645
d_A = CuArray(A)
560646
M = qr(A)
@@ -607,6 +693,58 @@ l_range = (1:l) .+ (l_sub_start -1)
607693
@test collect(d_C' * d_M.R) (C' * M.R)
608694
@test collect(d_C' * d_M.R') (C' * M.R')
609695

696+
A = rand(elty, m, n) # A and B,C are matrices
697+
d_A = CuArray(A)
698+
M = qr!(A)
699+
d_M = qr!(d_A)
700+
B = rand(elty, m, l) #different second dimension to verify whether dimensions agree
701+
d_B = CuArray(B)
702+
C = rand(elty, n, l) #different second dimension to verify whether dimensions agree
703+
d_C = CuArray(C)
704+
@test collect(d_M \ d_B) (M \ B)
705+
@test collect(d_M.Q * d_B) (M.Q * B)
706+
@test collect(d_M.Q' * d_B) (M.Q' * B)
707+
@test collect(d_B' * d_M.Q) (B' * M.Q)
708+
@test collect(d_B' * d_M.Q') (B' * M.Q')
709+
@test collect(d_M.R * d_C) (M.R * C)
710+
@test collect(d_M.R' * d_C) (M.R' * C)
711+
@test collect(d_C' * d_M.R) (C' * M.R)
712+
@test collect(d_C' * d_M.R') (C' * M.R')
713+
A_view = view(A, m_subrange, n_subrange)
714+
d_A_view = view(d_A, m_subrange, n_subrange)
715+
M_view = qr!(A_view)
716+
d_M_view = qr!(d_A_view)
717+
B_view = view(B, m_subrange, l_subrange)
718+
d_B_view = view(d_B, m_subrange, l_subrange)
719+
C_view = view(C, n_subrange, l_subrange)
720+
d_C_view = view(d_C, n_subrange, l_subrange)
721+
@test collect(d_M_view \ d_B_view) M_view \ B_view
722+
@test collect(d_M_view.Q * d_B_view) (M_view.Q * B_view)
723+
@test collect(d_M_view.Q' * d_B_view) (M_view.Q' * B_view)
724+
@test collect(d_B_view' * d_M_view.Q) (B_view' * M_view.Q)
725+
@test collect(d_B_view' * d_M_view.Q') (B_view' * M_view.Q')
726+
@test collect(d_M_view.R * d_C_view) (M_view.R * C_view)
727+
@test collect(d_M_view.R' * d_C_view) (M_view.R' * C_view)
728+
@test collect(d_C_view' * d_M_view.R) (C_view' * M_view.R)
729+
@test collect(d_C_view' * d_M_view.R') (C_view' * M_view.R')
730+
B_large = rand(elty, m_large, l_large)
731+
B = view(B_large, m_range, l_range)
732+
d_B_large = CuArray(B_large)
733+
d_B = view(d_B_large, m_range, l_range)
734+
C_large = rand(elty, n_large, l_large)
735+
C = view(C_large, n_range, l_range)
736+
d_C_large = CuArray(C_large)
737+
d_C = view(d_C_large, n_range, l_range)
738+
@test collect(d_M \ d_B) M \ B
739+
@test collect(d_M.Q * d_B) (M.Q * B)
740+
@test collect(d_M.Q' * d_B) (M.Q' * B)
741+
@test collect(d_B' * d_M.Q) (B' * M.Q)
742+
@test collect(d_B' * d_M.Q') (B' * M.Q')
743+
@test collect(d_M.R * d_C) (M.R * C)
744+
@test collect(d_M.R' * d_C) (M.R' * C)
745+
@test collect(d_C' * d_M.R) (C' * M.R)
746+
@test collect(d_C' * d_M.R') (C' * M.R')
747+
610748
end
611749

612750
@testset "potrsBatched!" begin

0 commit comments

Comments
 (0)