@@ -7,6 +7,7 @@ module test_intrinsics
7
7
use testdrive, only : new_unittest, unittest_type, error_type, check, skip_test
8
8
use stdlib_kinds, only: sp, dp, xdp, qp, int8, int16, int32, int64
9
9
use stdlib_intrinsics
10
+ use stdlib_linalg_state, only: linalg_state_type, LINALG_VALUE_ERROR, operator(==)
10
11
use stdlib_math, only: swap
11
12
implicit none
12
13
@@ -19,7 +20,8 @@ subroutine collect_suite(testsuite)
19
20
20
21
testsuite = [ &
21
22
new_unittest('sum', test_sum), &
22
- new_unittest('dot_product', test_dot_product) &
23
+ new_unittest('dot_product', test_dot_product), &
24
+ new_unittest('matmul', test_matmul) &
23
25
]
24
26
end subroutine
25
27
@@ -249,6 +251,45 @@ subroutine test_dot_product(error)
249
251
#:endfor
250
252
251
253
end subroutine
254
+
255
+ subroutine test_matmul(error)
256
+ type(error_type), allocatable, intent(out) :: error
257
+ type(linalg_state_type) :: linerr
258
+ real :: a(2, 3), b(3, 4), c(3, 2), d(2, 2)
259
+
260
+ d = stdlib_matmul(a, b, c, err=linerr)
261
+ call check(error, linerr == LINALG_VALUE_ERROR, "incompatible matrices are considered compatible")
262
+ if (allocated(error)) return
263
+
264
+ #:for k, t, s in R_KINDS_TYPES
265
+ block
266
+ ${t}$ :: x(10,20), y(20,30), z(30,10), r(10,10), r1(10,10)
267
+ call random_number(x)
268
+ call random_number(y)
269
+ call random_number(z)
270
+
271
+ r = stdlib_matmul(x, y, z) ! the optimal ordering would be (x(yz))
272
+ r1 = matmul(matmul(x, y), z) ! the opposite order to induce a difference
273
+
274
+ call check(error, all(abs(r-r1) <= epsilon(0._${k}$) * 300), "real, ${k}$, 3 args: error too large")
275
+ if (allocated(error)) return
276
+ end block
277
+
278
+ block
279
+ ${t}$ :: x(10,20), y(20,30), z(30,10), w(10, 20), r(10,20), r1(10,20)
280
+ call random_number(x)
281
+ call random_number(y)
282
+ call random_number(z)
283
+ call random_number(w)
284
+
285
+ r = stdlib_matmul(x, y, z, w) ! the optimal order would be ((x(yz))w)
286
+ r1 = matmul(matmul(x, y), matmul(z, w))
287
+
288
+ call check(error, all(abs(r-r1) <= epsilon(0._${k}$) * 1500), "real, ${k}$, 4 args: error too large")
289
+ if (allocated(error)) return
290
+ end block
291
+ #:endfor
292
+ end subroutine test_matmul
252
293
253
294
end module test_intrinsics
254
295
@@ -276,4 +317,4 @@ program tester
276
317
write(error_unit, '(i0, 1x, a)') stat, "test(s) failed!"
277
318
error stop
278
319
end if
279
- end program
320
+ end program
0 commit comments