Skip to content

Commit 6b0b46e

Browse files
committed
Added tests and confirmed build
1 parent 2a82186 commit 6b0b46e

File tree

1 file changed

+156
-36
lines changed

1 file changed

+156
-36
lines changed

src/tests/linalg/test_linalg.f90

Lines changed: 156 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -59,18 +59,16 @@ program test_linalg
5959
!
6060
! outer product
6161
!
62-
!commented until outer_product compiles
63-
!call test_outer_product_rsp
64-
!call test_outer_product_rsp_k
65-
!call test_outer_product_rdp
66-
!call test_outer_product_rqp
67-
68-
!call test_outer_product_csp
69-
!call test_outer_product_cdp
70-
!call test_outer_product_cqp
71-
72-
!call test_outer_product_int8
73-
!call test_outer_product_int16
62+
call test_outer_product_rsp
63+
call test_outer_product_rdp
64+
call test_outer_product_rqp
65+
66+
call test_outer_product_csp
67+
call test_outer_product_cdp
68+
call test_outer_product_cqp
69+
70+
call test_outer_product_int8
71+
call test_outer_product_int16
7472
!call test_outer_product_int32
7573
!call test_outer_product_int64
7674

@@ -93,7 +91,7 @@ subroutine test_eye
9391
cye = eye(7)
9492
call check(abs(trace(cye) - cmplx(7.0_sp,0.0_sp,kind=sp)) < sptol, &
9593
msg="abs(trace(cye) - cmplx(7.0_sp,0.0_sp,kind=sp)) < sptol failed.",warn=warn)
96-
end subroutine
94+
end subroutine test_eye
9795

9896
subroutine test_diag_rsp
9997
integer, parameter :: n = 3
@@ -108,7 +106,7 @@ subroutine test_diag_rsp
108106

109107
call check(all(diag(3*a) == 3*v), &
110108
msg="all(diag(3*a) == 3*v) failed.",warn=warn)
111-
end subroutine
109+
end subroutine test_diag_rsp
112110

113111
subroutine test_diag_rsp_k
114112
integer, parameter :: n = 4
@@ -136,7 +134,7 @@ subroutine test_diag_rsp_k
136134
end do
137135
call check(size(diag(a,n+1)) == 0, &
138136
msg="size(diag(a,n+1)) == 0 failed.",warn=warn)
139-
end subroutine
137+
end subroutine test_diag_rsp_k
140138

141139
subroutine test_diag_rdp
142140
integer, parameter :: n = 3
@@ -151,7 +149,7 @@ subroutine test_diag_rdp
151149

152150
call check(all(diag(3*a) == 3*v), &
153151
msg="all(diag(3*a) == 3*v) failed.",warn=warn)
154-
end subroutine
152+
end subroutine test_diag_rdp
155153

156154
subroutine test_diag_rqp
157155
integer, parameter :: n = 3
@@ -166,7 +164,7 @@ subroutine test_diag_rqp
166164

167165
call check(all(diag(3*a) == 3*v), &
168166
msg="all(diag(3*a) == 3*v) failed.", warn=warn)
169-
end subroutine
167+
end subroutine test_diag_rqp
170168

171169
subroutine test_diag_csp
172170
integer, parameter :: n = 3
@@ -183,7 +181,7 @@ subroutine test_diag_csp
183181
msg="all(abs(real(diag(a)) - [(i,i=1,n)]) < sptol)", warn=warn)
184182
call check(all(abs(aimag(diag(a)) - [(1,i=1,n)]) < sptol), &
185183
msg="all(abs(aimag(diag(a)) - [(1,i=1,n)]) < sptol)", warn=warn)
186-
end subroutine
184+
end subroutine test_diag_csp
187185

188186
subroutine test_diag_cdp
189187
integer, parameter :: n = 3
@@ -193,7 +191,7 @@ subroutine test_diag_cdp
193191
a = diag([i_],-2) + diag([i_],2)
194192
call check(a(3,1) == i_ .and. a(1,3) == i_, &
195193
msg="a(3,1) == i_ .and. a(1,3) == i_ failed.",warn=warn)
196-
end subroutine
194+
end subroutine test_diag_cdp
197195

198196
subroutine test_diag_cqp
199197
integer, parameter :: n = 3
@@ -203,7 +201,7 @@ subroutine test_diag_cqp
203201
a = diag([i_,i_],-1) + diag([i_,i_],1)
204202
call check(all(diag(a,-1) == i_) .and. all(diag(a,1) == i_), &
205203
msg="all(diag(a,-1) == i_) .and. all(diag(a,1) == i_) failed.",warn=warn)
206-
end subroutine
204+
end subroutine test_diag_cqp
207205

208206
subroutine test_diag_int8
209207
integer, parameter :: n = 3
@@ -217,7 +215,7 @@ subroutine test_diag_int8
217215
msg="all(diag(a) == pack(a,mask)) failed.", warn=warn)
218216
call check(all(diag(diag(a)) == merge(a,0_int8,mask)), &
219217
msg="all(diag(diag(a)) == merge(a,0_int8,mask)) failed.", warn=warn)
220-
end subroutine
218+
end subroutine test_diag_int8
221219
subroutine test_diag_int16
222220
integer, parameter :: n = 4
223221
integer(int16), allocatable :: a(:,:)
@@ -230,7 +228,7 @@ subroutine test_diag_int16
230228
msg="all(diag(a) == pack(a,mask))", warn=warn)
231229
call check(all(diag(diag(a)) == merge(a,0_int16,mask)), &
232230
msg="all(diag(diag(a)) == merge(a,0_int16,mask)) failed.", warn=warn)
233-
end subroutine
231+
end subroutine test_diag_int16
234232
subroutine test_diag_int32
235233
integer, parameter :: n = 3
236234
integer(int32) :: a(n,n)
@@ -244,7 +242,7 @@ subroutine test_diag_int32
244242
msg="all(diag([1,1],-1) == a) failed.", warn=warn)
245243
call check(all(diag([1,1],1) == transpose(a)), &
246244
msg="all(diag([1,1],1) == transpose(a)) failed.", warn=warn)
247-
end subroutine
245+
end subroutine test_diag_int32
248246
subroutine test_diag_int64
249247
integer, parameter :: n = 4
250248
integer(int64) :: a(n,n), c(0:2*n-1)
@@ -275,7 +273,7 @@ subroutine test_diag_int64
275273
end do
276274
call check(all(diag(a,-2) == diag(a,2)), &
277275
msg="all(diag(a,-2) == diag(a,2))", warn=warn)
278-
end subroutine
276+
end subroutine test_diag_int64
279277

280278

281279

@@ -288,7 +286,7 @@ subroutine test_trace_rsp
288286
a = reshape([(i,i=1,n**2)],[n,n])
289287
call check(abs(trace(a) - sum(diag(a))) < sptol, &
290288
msg="abs(trace(a) - sum(diag(a))) < sptol failed.",warn=warn)
291-
end subroutine
289+
end subroutine test_trace_rsp
292290

293291
subroutine test_trace_rsp_nonsquare
294292
integer, parameter :: n = 4
@@ -305,7 +303,7 @@ subroutine test_trace_rsp_nonsquare
305303

306304
call check(abs(trace(a) - ans) < sptol, &
307305
msg="abs(trace(a) - ans) < sptol failed.",warn=warn)
308-
end subroutine
306+
end subroutine test_trace_rsp_nonsquare
309307

310308
subroutine test_trace_rdp
311309
integer, parameter :: n = 4
@@ -315,7 +313,7 @@ subroutine test_trace_rdp
315313
a = reshape([(i,i=1,n**2)],[n,n])
316314
call check(abs(trace(a) - sum(diag(a))) < dptol, &
317315
msg="abs(trace(a) - sum(diag(a))) < dptol failed.",warn=warn)
318-
end subroutine
316+
end subroutine test_trace_rdp
319317

320318
subroutine test_trace_rdp_nonsquare
321319
integer, parameter :: n = 4
@@ -332,7 +330,7 @@ subroutine test_trace_rdp_nonsquare
332330

333331
call check(abs(trace(a) - ans) < dptol, &
334332
msg="abs(trace(a) - ans) < dptol failed.",warn=warn)
335-
end subroutine
333+
end subroutine test_trace_rdp_nonsquare
336334

337335
subroutine test_trace_rqp
338336
integer, parameter :: n = 3
@@ -342,7 +340,7 @@ subroutine test_trace_rqp
342340
a = reshape([(i,i=1,n**2)],[n,n])
343341
call check(abs(trace(a) - sum(diag(a))) < qptol, &
344342
msg="abs(trace(a) - sum(diag(a))) < qptol failed.",warn=warn)
345-
end subroutine
343+
end subroutine test_trace_rqp
346344

347345

348346
subroutine test_trace_csp
@@ -363,7 +361,7 @@ subroutine test_trace_csp
363361
! tr(A + B) = tr(A) + tr(B)
364362
call check(abs(trace(a+b) - (trace(a) + trace(b))) < sptol, &
365363
msg="abs(trace(a+b) - (trace(a) + trace(b))) < sptol failed.",warn=warn)
366-
end subroutine
364+
end subroutine test_trace_csp
367365

368366
subroutine test_trace_cdp
369367
integer, parameter :: n = 3
@@ -377,7 +375,7 @@ subroutine test_trace_cdp
377375

378376
call check(abs(trace(a) - ans) < dptol, &
379377
msg="abs(trace(a) - ans) < dptol failed.",warn=warn)
380-
end subroutine
378+
end subroutine test_trace_cdp
381379

382380
subroutine test_trace_cqp
383381
integer, parameter :: n = 3
@@ -387,7 +385,7 @@ subroutine test_trace_cqp
387385
a = 3*eye(n) + 4*eye(n)*i_ ! pythagorean triple
388386
call check(abs(trace(a)) - 3*5.0_qp < qptol, &
389387
msg="abs(trace(a)) - 3*5.0_qp < qptol failed.",warn=warn)
390-
end subroutine
388+
end subroutine test_trace_cqp
391389

392390

393391
subroutine test_trace_int8
@@ -398,7 +396,7 @@ subroutine test_trace_int8
398396
a = reshape([(i**2,i=1,n**2)],[n,n])
399397
call check(trace(a) == (1 + 25 + 81), &
400398
msg="trace(a) == (1 + 25 + 81) failed.",warn=warn)
401-
end subroutine
399+
end subroutine test_trace_int8
402400

403401
subroutine test_trace_int16
404402
integer, parameter :: n = 3
@@ -408,7 +406,7 @@ subroutine test_trace_int16
408406
a = reshape([(i**3,i=1,n**2)],[n,n])
409407
call check(trace(a) == (1 + 125 + 729), &
410408
msg="trace(a) == (1 + 125 + 729) failed.",warn=warn)
411-
end subroutine
409+
end subroutine test_trace_int16
412410

413411
subroutine test_trace_int32
414412
integer, parameter :: n = 3
@@ -418,7 +416,7 @@ subroutine test_trace_int32
418416
a = reshape([(i**4,i=1,n**2)],[n,n])
419417
call check(trace(a) == (1 + 625 + 6561), &
420418
msg="trace(a) == (1 + 625 + 6561) failed.",warn=warn)
421-
end subroutine
419+
end subroutine test_trace_int32
422420

423421
subroutine test_trace_int64
424422
integer, parameter :: n = 5
@@ -442,7 +440,129 @@ subroutine test_trace_int64
442440
call check(trace(h) == sum(c(0:nd:2)), &
443441
msg="trace(h) == sum(c(0:nd:2)) failed.",warn=warn)
444442

445-
end subroutine
443+
end subroutine test_trace_int64
444+
445+
446+
subroutine test_outer_product_rsp
447+
integer, parameter :: n = 2
448+
real(sp) :: u(n), v(n), expected(n,n), diff(n,n)
449+
write(*,*) "test_outer_product_rsp"
450+
u = [1.,2.]
451+
v = [1.,3.]
452+
expected = reshape([1.,2.,3.,6.],[n,n])
453+
diff = expected - outer_product(u,v)
454+
call check(all(abs(diff) < sptol), &
455+
msg="all(abs(diff) < sptol) failed.",warn=warn)
456+
end subroutine test_outer_product_rsp
457+
458+
subroutine test_outer_product_rdp
459+
integer, parameter :: n = 2
460+
real(dp) :: u(n), v(n), expected(n,n), diff(n,n)
461+
write(*,*) "test_outer_product_rdp"
462+
u = [1.,2.]
463+
v = [1.,3.]
464+
expected = reshape([1.,2.,3.,6.],[n,n])
465+
diff = expected - outer_product(u,v)
466+
call check(all(abs(diff) < dptol), &
467+
msg="all(abs(diff) < dptol) failed.",warn=warn)
468+
end subroutine test_outer_product_rdp
469+
470+
subroutine test_outer_product_rqp
471+
integer, parameter :: n = 2
472+
real(qp) :: u(n), v(n), expected(n,n), diff(n,n)
473+
write(*,*) "test_outer_product_rqp"
474+
u = [1.,2.]
475+
v = [1.,3.]
476+
expected = reshape([1.,2.,3.,6.],[n,n])
477+
diff = expected - outer_product(u,v)
478+
call check(all(abs(diff) < qptol), &
479+
msg="all(abs(diff) < qptol) failed.",warn=warn)
480+
end subroutine test_outer_product_rqp
481+
482+
subroutine test_outer_product_csp
483+
integer, parameter :: n = 2
484+
complex(sp) :: u(n), v(n), expected(n,n), diff(n,n)
485+
write(*,*) "test_outer_product_csp"
486+
u = [cmplx(1.,1.),cmplx(2.,0.)]
487+
v = [cmplx(1.,0.),cmplx(3.,1.)]
488+
expected = reshape([cmplx(1.,1.),cmplx(2.,0.),cmplx(2.,4.),cmplx(6.,2.)],[n,n])
489+
diff = expected - outer_product(u,v)
490+
call check(all(abs(diff) < sptol), &
491+
msg="all(abs(diff) < sptol) failed.",warn=warn)
492+
end subroutine test_outer_product_csp
493+
494+
subroutine test_outer_product_cdp
495+
integer, parameter :: n = 2
496+
complex(dp) :: u(n), v(n), expected(n,n), diff(n,n)
497+
write(*,*) "test_outer_product_cdp"
498+
u = [cmplx(1.,1.),cmplx(2.,0.)]
499+
v = [cmplx(1.,0.),cmplx(3.,1.)]
500+
expected = reshape([cmplx(1.,1.),cmplx(2.,0.),cmplx(2.,4.),cmplx(6.,2.)],[n,n])
501+
diff = expected - outer_product(u,v)
502+
call check(all(abs(diff) < dptol), &
503+
msg="all(abs(diff) < dptol) failed.",warn=warn)
504+
end subroutine test_outer_product_cdp
505+
506+
subroutine test_outer_product_cqp
507+
integer, parameter :: n = 2
508+
complex(qp) :: u(n), v(n), expected(n,n), diff(n,n)
509+
write(*,*) "test_outer_product_cqp"
510+
u = [cmplx(1.,1.),cmplx(2.,0.)]
511+
v = [cmplx(1.,0.),cmplx(3.,1.)]
512+
expected = reshape([cmplx(1.,1.),cmplx(2.,0.),cmplx(2.,4.),cmplx(6.,2.)],[n,n])
513+
diff = expected - outer_product(u,v)
514+
call check(all(abs(diff) < qptol), &
515+
msg="all(abs(diff) < qptol) failed.",warn=warn)
516+
end subroutine test_outer_product_cqp
517+
518+
subroutine test_outer_product_int8
519+
integer, parameter :: n = 2
520+
integer(int8) :: u(n), v(n), expected(n,n), diff(n,n)
521+
write(*,*) "test_outer_product_int8"
522+
u = [1,2]
523+
v = [1,3]
524+
expected = reshape([1,2,3,6],[n,n])
525+
diff = expected - outer_product(u,v)
526+
call check(all(abs(diff) == 0), &
527+
msg="all(abs(diff) == 0) failed.",warn=warn)
528+
end subroutine test_outer_product_int8
529+
530+
subroutine test_outer_product_int16
531+
integer, parameter :: n = 2
532+
integer(int16) :: u(n), v(n), expected(n,n), diff(n,n)
533+
write(*,*) "test_outer_product_int16"
534+
u = [1,2]
535+
v = [1,3]
536+
expected = reshape([1,2,3,6],[n,n])
537+
diff = expected - outer_product(u,v)
538+
call check(all(abs(diff) == 0), &
539+
msg="all(abs(diff) == 0) failed.",warn=warn)
540+
end subroutine test_outer_product_int16
541+
542+
subroutine test_outer_product_int32
543+
integer, parameter :: n = 2
544+
integer(int32) :: u(n), v(n), expected(n,n), diff(n,n)
545+
write(*,*) "test_outer_product_int32"
546+
u = [1,2]
547+
v = [1,3]
548+
expected = reshape([1,2,3,6],[n,n])
549+
diff = expected - outer_product(u,v)
550+
call check(all(abs(diff) == 0), &
551+
msg="all(abs(diff) == 0) failed.",warn=warn)
552+
end subroutine test_outer_product_int32
553+
554+
subroutine test_outer_product_int64
555+
integer, parameter :: n = 2
556+
integer(int64) :: u(n), v(n), expected(n,n), diff(n,n)
557+
write(*,*) "test_outer_product_int64"
558+
u = [1,2]
559+
v = [1,3]
560+
expected = reshape([1,2,3,6],[n,n])
561+
diff = expected - outer_product(u,v)
562+
call check(all(abs(diff) == 0), &
563+
msg="all(abs(diff) == 0) failed.",warn=warn)
564+
end subroutine test_outer_product_int64
565+
446566

447567
pure recursive function catalan_number(n) result(value)
448568
integer, intent(in) :: n

0 commit comments

Comments
 (0)