@@ -59,18 +59,16 @@ program test_linalg
59
59
!
60
60
! outer product
61
61
!
62
- ! commented until outer_product compiles
63
- ! call test_outer_product_rsp
64
- ! call test_outer_product_rsp_k
65
- ! call test_outer_product_rdp
66
- ! call test_outer_product_rqp
67
-
68
- ! call test_outer_product_csp
69
- ! call test_outer_product_cdp
70
- ! call test_outer_product_cqp
71
-
72
- ! call test_outer_product_int8
73
- ! call test_outer_product_int16
62
+ call test_outer_product_rsp
63
+ call test_outer_product_rdp
64
+ call test_outer_product_rqp
65
+
66
+ call test_outer_product_csp
67
+ call test_outer_product_cdp
68
+ call test_outer_product_cqp
69
+
70
+ call test_outer_product_int8
71
+ call test_outer_product_int16
74
72
! call test_outer_product_int32
75
73
! call test_outer_product_int64
76
74
@@ -93,7 +91,7 @@ subroutine test_eye
93
91
cye = eye(7 )
94
92
call check(abs (trace(cye) - cmplx (7.0_sp ,0.0_sp ,kind= sp)) < sptol, &
95
93
msg= " abs(trace(cye) - cmplx(7.0_sp,0.0_sp,kind=sp)) < sptol failed." ,warn= warn)
96
- end subroutine
94
+ end subroutine test_eye
97
95
98
96
subroutine test_diag_rsp
99
97
integer , parameter :: n = 3
@@ -108,7 +106,7 @@ subroutine test_diag_rsp
108
106
109
107
call check(all (diag(3 * a) == 3 * v), &
110
108
msg= " all(diag(3*a) == 3*v) failed." ,warn= warn)
111
- end subroutine
109
+ end subroutine test_diag_rsp
112
110
113
111
subroutine test_diag_rsp_k
114
112
integer , parameter :: n = 4
@@ -136,7 +134,7 @@ subroutine test_diag_rsp_k
136
134
end do
137
135
call check(size (diag(a,n+1 )) == 0 , &
138
136
msg= " size(diag(a,n+1)) == 0 failed." ,warn= warn)
139
- end subroutine
137
+ end subroutine test_diag_rsp_k
140
138
141
139
subroutine test_diag_rdp
142
140
integer , parameter :: n = 3
@@ -151,7 +149,7 @@ subroutine test_diag_rdp
151
149
152
150
call check(all (diag(3 * a) == 3 * v), &
153
151
msg= " all(diag(3*a) == 3*v) failed." ,warn= warn)
154
- end subroutine
152
+ end subroutine test_diag_rdp
155
153
156
154
subroutine test_diag_rqp
157
155
integer , parameter :: n = 3
@@ -166,7 +164,7 @@ subroutine test_diag_rqp
166
164
167
165
call check(all (diag(3 * a) == 3 * v), &
168
166
msg= " all(diag(3*a) == 3*v) failed." , warn= warn)
169
- end subroutine
167
+ end subroutine test_diag_rqp
170
168
171
169
subroutine test_diag_csp
172
170
integer , parameter :: n = 3
@@ -183,7 +181,7 @@ subroutine test_diag_csp
183
181
msg= " all(abs(real(diag(a)) - [(i,i=1,n)]) < sptol)" , warn= warn)
184
182
call check(all (abs (aimag (diag(a)) - [(1 ,i= 1 ,n)]) < sptol), &
185
183
msg= " all(abs(aimag(diag(a)) - [(1,i=1,n)]) < sptol)" , warn= warn)
186
- end subroutine
184
+ end subroutine test_diag_csp
187
185
188
186
subroutine test_diag_cdp
189
187
integer , parameter :: n = 3
@@ -193,7 +191,7 @@ subroutine test_diag_cdp
193
191
a = diag([i_],- 2 ) + diag([i_],2 )
194
192
call check(a(3 ,1 ) == i_ .and. a(1 ,3 ) == i_, &
195
193
msg= " a(3,1) == i_ .and. a(1,3) == i_ failed." ,warn= warn)
196
- end subroutine
194
+ end subroutine test_diag_cdp
197
195
198
196
subroutine test_diag_cqp
199
197
integer , parameter :: n = 3
@@ -203,7 +201,7 @@ subroutine test_diag_cqp
203
201
a = diag([i_,i_],- 1 ) + diag([i_,i_],1 )
204
202
call check(all (diag(a,- 1 ) == i_) .and. all (diag(a,1 ) == i_), &
205
203
msg= " all(diag(a,-1) == i_) .and. all(diag(a,1) == i_) failed." ,warn= warn)
206
- end subroutine
204
+ end subroutine test_diag_cqp
207
205
208
206
subroutine test_diag_int8
209
207
integer , parameter :: n = 3
@@ -217,7 +215,7 @@ subroutine test_diag_int8
217
215
msg= " all(diag(a) == pack(a,mask)) failed." , warn= warn)
218
216
call check(all (diag(diag(a)) == merge (a,0_int8 ,mask)), &
219
217
msg= " all(diag(diag(a)) == merge(a,0_int8,mask)) failed." , warn= warn)
220
- end subroutine
218
+ end subroutine test_diag_int8
221
219
subroutine test_diag_int16
222
220
integer , parameter :: n = 4
223
221
integer (int16), allocatable :: a(:,:)
@@ -230,7 +228,7 @@ subroutine test_diag_int16
230
228
msg= " all(diag(a) == pack(a,mask))" , warn= warn)
231
229
call check(all (diag(diag(a)) == merge (a,0_int16 ,mask)), &
232
230
msg= " all(diag(diag(a)) == merge(a,0_int16,mask)) failed." , warn= warn)
233
- end subroutine
231
+ end subroutine test_diag_int16
234
232
subroutine test_diag_int32
235
233
integer , parameter :: n = 3
236
234
integer (int32) :: a(n,n)
@@ -244,7 +242,7 @@ subroutine test_diag_int32
244
242
msg= " all(diag([1,1],-1) == a) failed." , warn= warn)
245
243
call check(all (diag([1 ,1 ],1 ) == transpose (a)), &
246
244
msg= " all(diag([1,1],1) == transpose(a)) failed." , warn= warn)
247
- end subroutine
245
+ end subroutine test_diag_int32
248
246
subroutine test_diag_int64
249
247
integer , parameter :: n = 4
250
248
integer (int64) :: a(n,n), c(0 :2 * n-1 )
@@ -275,7 +273,7 @@ subroutine test_diag_int64
275
273
end do
276
274
call check(all (diag(a,- 2 ) == diag(a,2 )), &
277
275
msg= " all(diag(a,-2) == diag(a,2))" , warn= warn)
278
- end subroutine
276
+ end subroutine test_diag_int64
279
277
280
278
281
279
@@ -288,7 +286,7 @@ subroutine test_trace_rsp
288
286
a = reshape ([(i,i= 1 ,n** 2 )],[n,n])
289
287
call check(abs (trace(a) - sum (diag(a))) < sptol, &
290
288
msg= " abs(trace(a) - sum(diag(a))) < sptol failed." ,warn= warn)
291
- end subroutine
289
+ end subroutine test_trace_rsp
292
290
293
291
subroutine test_trace_rsp_nonsquare
294
292
integer , parameter :: n = 4
@@ -305,7 +303,7 @@ subroutine test_trace_rsp_nonsquare
305
303
306
304
call check(abs (trace(a) - ans) < sptol, &
307
305
msg= " abs(trace(a) - ans) < sptol failed." ,warn= warn)
308
- end subroutine
306
+ end subroutine test_trace_rsp_nonsquare
309
307
310
308
subroutine test_trace_rdp
311
309
integer , parameter :: n = 4
@@ -315,7 +313,7 @@ subroutine test_trace_rdp
315
313
a = reshape ([(i,i= 1 ,n** 2 )],[n,n])
316
314
call check(abs (trace(a) - sum (diag(a))) < dptol, &
317
315
msg= " abs(trace(a) - sum(diag(a))) < dptol failed." ,warn= warn)
318
- end subroutine
316
+ end subroutine test_trace_rdp
319
317
320
318
subroutine test_trace_rdp_nonsquare
321
319
integer , parameter :: n = 4
@@ -332,7 +330,7 @@ subroutine test_trace_rdp_nonsquare
332
330
333
331
call check(abs (trace(a) - ans) < dptol, &
334
332
msg= " abs(trace(a) - ans) < dptol failed." ,warn= warn)
335
- end subroutine
333
+ end subroutine test_trace_rdp_nonsquare
336
334
337
335
subroutine test_trace_rqp
338
336
integer , parameter :: n = 3
@@ -342,7 +340,7 @@ subroutine test_trace_rqp
342
340
a = reshape ([(i,i= 1 ,n** 2 )],[n,n])
343
341
call check(abs (trace(a) - sum (diag(a))) < qptol, &
344
342
msg= " abs(trace(a) - sum(diag(a))) < qptol failed." ,warn= warn)
345
- end subroutine
343
+ end subroutine test_trace_rqp
346
344
347
345
348
346
subroutine test_trace_csp
@@ -363,7 +361,7 @@ subroutine test_trace_csp
363
361
! tr(A + B) = tr(A) + tr(B)
364
362
call check(abs (trace(a+ b) - (trace(a) + trace(b))) < sptol, &
365
363
msg= " abs(trace(a+b) - (trace(a) + trace(b))) < sptol failed." ,warn= warn)
366
- end subroutine
364
+ end subroutine test_trace_csp
367
365
368
366
subroutine test_trace_cdp
369
367
integer , parameter :: n = 3
@@ -377,7 +375,7 @@ subroutine test_trace_cdp
377
375
378
376
call check(abs (trace(a) - ans) < dptol, &
379
377
msg= " abs(trace(a) - ans) < dptol failed." ,warn= warn)
380
- end subroutine
378
+ end subroutine test_trace_cdp
381
379
382
380
subroutine test_trace_cqp
383
381
integer , parameter :: n = 3
@@ -387,7 +385,7 @@ subroutine test_trace_cqp
387
385
a = 3 * eye(n) + 4 * eye(n)* i_ ! pythagorean triple
388
386
call check(abs (trace(a)) - 3 * 5.0_qp < qptol, &
389
387
msg= " abs(trace(a)) - 3*5.0_qp < qptol failed." ,warn= warn)
390
- end subroutine
388
+ end subroutine test_trace_cqp
391
389
392
390
393
391
subroutine test_trace_int8
@@ -398,7 +396,7 @@ subroutine test_trace_int8
398
396
a = reshape ([(i** 2 ,i= 1 ,n** 2 )],[n,n])
399
397
call check(trace(a) == (1 + 25 + 81 ), &
400
398
msg= " trace(a) == (1 + 25 + 81) failed." ,warn= warn)
401
- end subroutine
399
+ end subroutine test_trace_int8
402
400
403
401
subroutine test_trace_int16
404
402
integer , parameter :: n = 3
@@ -408,7 +406,7 @@ subroutine test_trace_int16
408
406
a = reshape ([(i** 3 ,i= 1 ,n** 2 )],[n,n])
409
407
call check(trace(a) == (1 + 125 + 729 ), &
410
408
msg= " trace(a) == (1 + 125 + 729) failed." ,warn= warn)
411
- end subroutine
409
+ end subroutine test_trace_int16
412
410
413
411
subroutine test_trace_int32
414
412
integer , parameter :: n = 3
@@ -418,7 +416,7 @@ subroutine test_trace_int32
418
416
a = reshape ([(i** 4 ,i= 1 ,n** 2 )],[n,n])
419
417
call check(trace(a) == (1 + 625 + 6561 ), &
420
418
msg= " trace(a) == (1 + 625 + 6561) failed." ,warn= warn)
421
- end subroutine
419
+ end subroutine test_trace_int32
422
420
423
421
subroutine test_trace_int64
424
422
integer , parameter :: n = 5
@@ -442,7 +440,129 @@ subroutine test_trace_int64
442
440
call check(trace(h) == sum (c(0 :nd:2 )), &
443
441
msg= " trace(h) == sum(c(0:nd:2)) failed." ,warn= warn)
444
442
445
- end subroutine
443
+ end subroutine test_trace_int64
444
+
445
+
446
+ subroutine test_outer_product_rsp
447
+ integer , parameter :: n = 2
448
+ real (sp) :: u(n), v(n), expected(n,n), diff(n,n)
449
+ write (* ,* ) " test_outer_product_rsp"
450
+ u = [1 .,2 .]
451
+ v = [1 .,3 .]
452
+ expected = reshape ([1 .,2 .,3 .,6 .],[n,n])
453
+ diff = expected - outer_product(u,v)
454
+ call check(all (abs (diff) < sptol), &
455
+ msg= " all(abs(diff) < sptol) failed." ,warn= warn)
456
+ end subroutine test_outer_product_rsp
457
+
458
+ subroutine test_outer_product_rdp
459
+ integer , parameter :: n = 2
460
+ real (dp) :: u(n), v(n), expected(n,n), diff(n,n)
461
+ write (* ,* ) " test_outer_product_rdp"
462
+ u = [1 .,2 .]
463
+ v = [1 .,3 .]
464
+ expected = reshape ([1 .,2 .,3 .,6 .],[n,n])
465
+ diff = expected - outer_product(u,v)
466
+ call check(all (abs (diff) < dptol), &
467
+ msg= " all(abs(diff) < dptol) failed." ,warn= warn)
468
+ end subroutine test_outer_product_rdp
469
+
470
+ subroutine test_outer_product_rqp
471
+ integer , parameter :: n = 2
472
+ real (qp) :: u(n), v(n), expected(n,n), diff(n,n)
473
+ write (* ,* ) " test_outer_product_rqp"
474
+ u = [1 .,2 .]
475
+ v = [1 .,3 .]
476
+ expected = reshape ([1 .,2 .,3 .,6 .],[n,n])
477
+ diff = expected - outer_product(u,v)
478
+ call check(all (abs (diff) < qptol), &
479
+ msg= " all(abs(diff) < qptol) failed." ,warn= warn)
480
+ end subroutine test_outer_product_rqp
481
+
482
+ subroutine test_outer_product_csp
483
+ integer , parameter :: n = 2
484
+ complex (sp) :: u(n), v(n), expected(n,n), diff(n,n)
485
+ write (* ,* ) " test_outer_product_csp"
486
+ u = [cmplx (1 .,1 .),cmplx (2 .,0 .)]
487
+ v = [cmplx (1 .,0 .),cmplx (3 .,1 .)]
488
+ expected = reshape ([cmplx (1 .,1 .),cmplx (2 .,0 .),cmplx (2 .,4 .),cmplx (6 .,2 .)],[n,n])
489
+ diff = expected - outer_product(u,v)
490
+ call check(all (abs (diff) < sptol), &
491
+ msg= " all(abs(diff) < sptol) failed." ,warn= warn)
492
+ end subroutine test_outer_product_csp
493
+
494
+ subroutine test_outer_product_cdp
495
+ integer , parameter :: n = 2
496
+ complex (dp) :: u(n), v(n), expected(n,n), diff(n,n)
497
+ write (* ,* ) " test_outer_product_cdp"
498
+ u = [cmplx (1 .,1 .),cmplx (2 .,0 .)]
499
+ v = [cmplx (1 .,0 .),cmplx (3 .,1 .)]
500
+ expected = reshape ([cmplx (1 .,1 .),cmplx (2 .,0 .),cmplx (2 .,4 .),cmplx (6 .,2 .)],[n,n])
501
+ diff = expected - outer_product(u,v)
502
+ call check(all (abs (diff) < dptol), &
503
+ msg= " all(abs(diff) < dptol) failed." ,warn= warn)
504
+ end subroutine test_outer_product_cdp
505
+
506
+ subroutine test_outer_product_cqp
507
+ integer , parameter :: n = 2
508
+ complex (qp) :: u(n), v(n), expected(n,n), diff(n,n)
509
+ write (* ,* ) " test_outer_product_cqp"
510
+ u = [cmplx (1 .,1 .),cmplx (2 .,0 .)]
511
+ v = [cmplx (1 .,0 .),cmplx (3 .,1 .)]
512
+ expected = reshape ([cmplx (1 .,1 .),cmplx (2 .,0 .),cmplx (2 .,4 .),cmplx (6 .,2 .)],[n,n])
513
+ diff = expected - outer_product(u,v)
514
+ call check(all (abs (diff) < qptol), &
515
+ msg= " all(abs(diff) < qptol) failed." ,warn= warn)
516
+ end subroutine test_outer_product_cqp
517
+
518
+ subroutine test_outer_product_int8
519
+ integer , parameter :: n = 2
520
+ integer (int8) :: u(n), v(n), expected(n,n), diff(n,n)
521
+ write (* ,* ) " test_outer_product_int8"
522
+ u = [1 ,2 ]
523
+ v = [1 ,3 ]
524
+ expected = reshape ([1 ,2 ,3 ,6 ],[n,n])
525
+ diff = expected - outer_product(u,v)
526
+ call check(all (abs (diff) == 0 ), &
527
+ msg= " all(abs(diff) == 0) failed." ,warn= warn)
528
+ end subroutine test_outer_product_int8
529
+
530
+ subroutine test_outer_product_int16
531
+ integer , parameter :: n = 2
532
+ integer (int16) :: u(n), v(n), expected(n,n), diff(n,n)
533
+ write (* ,* ) " test_outer_product_int16"
534
+ u = [1 ,2 ]
535
+ v = [1 ,3 ]
536
+ expected = reshape ([1 ,2 ,3 ,6 ],[n,n])
537
+ diff = expected - outer_product(u,v)
538
+ call check(all (abs (diff) == 0 ), &
539
+ msg= " all(abs(diff) == 0) failed." ,warn= warn)
540
+ end subroutine test_outer_product_int16
541
+
542
+ subroutine test_outer_product_int32
543
+ integer , parameter :: n = 2
544
+ integer (int32) :: u(n), v(n), expected(n,n), diff(n,n)
545
+ write (* ,* ) " test_outer_product_int32"
546
+ u = [1 ,2 ]
547
+ v = [1 ,3 ]
548
+ expected = reshape ([1 ,2 ,3 ,6 ],[n,n])
549
+ diff = expected - outer_product(u,v)
550
+ call check(all (abs (diff) == 0 ), &
551
+ msg= " all(abs(diff) == 0) failed." ,warn= warn)
552
+ end subroutine test_outer_product_int32
553
+
554
+ subroutine test_outer_product_int64
555
+ integer , parameter :: n = 2
556
+ integer (int64) :: u(n), v(n), expected(n,n), diff(n,n)
557
+ write (* ,* ) " test_outer_product_int64"
558
+ u = [1 ,2 ]
559
+ v = [1 ,3 ]
560
+ expected = reshape ([1 ,2 ,3 ,6 ],[n,n])
561
+ diff = expected - outer_product(u,v)
562
+ call check(all (abs (diff) == 0 ), &
563
+ msg= " all(abs(diff) == 0) failed." ,warn= warn)
564
+ end subroutine test_outer_product_int64
565
+
446
566
447
567
pure recursive function catalan_number(n) result(value)
448
568
integer , intent (in ) :: n
0 commit comments