24
24
@test testf (* , transpose (rand (elty, m, n)), rand (elty, m))
25
25
@test testf (* , rand (elty, m, n)' , rand (elty, m))
26
26
x = rand (elty, m)
27
- A = rand (elty, m, m + 1 )
28
- y = rand (elty, m )
27
+ A = rand (elty, m, m + 1 )
28
+ y = rand (elty, n )
29
29
dx = CuArray (x)
30
30
dA = CuArray (A)
31
31
dy = CuArray (y)
@@ -44,6 +44,10 @@ k = 13
44
44
dy = CUBLAS. gemv (' N' , dA, dx)
45
45
hy = collect (dy)
46
46
@test hy ≈ A * x
47
+ dy = CuArray (y)
48
+ dx = CUBLAS. gemv (elty <: Real ? ' T' : ' C' , alpha, dA, dy)
49
+ hx = collect (dx)
50
+ @test hx ≈ alpha * A' * y
47
51
end
48
52
49
53
if CUBLAS. version () >= v " 11.9"
@@ -72,6 +76,16 @@ k = 13
72
76
y[i] = alpha * A[i] * x[i] + beta * y[i]
73
77
@test y[i] ≈ hy
74
78
end
79
+ dy = CuArray{elty, 1 }[]
80
+ for i= 1 : length (A)
81
+ push! (dy, CuArray (y[i]))
82
+ end
83
+ CUBLAS. gemv_batched! (elty <: Real ? ' T' : ' C' , alpha, dA, dy, beta, dx)
84
+ for i in 1 : length (A)
85
+ hx = collect (dx[i])
86
+ x[i] = alpha * A[i]' * y[i] + beta * x[i]
87
+ @test x[i] ≈ hx
88
+ end
75
89
end
76
90
end
77
91
@@ -92,11 +106,18 @@ k = 13
92
106
dbad = CuArray (bad)
93
107
@test_throws DimensionMismatch CUBLAS. gemv_strided_batched! (' N' , alpha, dA, dx, beta, dbad)
94
108
CUBLAS. gemv_strided_batched! (' N' , alpha, dA, dx, beta, dy)
95
- for i= 1 : size (A, 3 )
109
+ for i in 1 : size (A, 3 )
96
110
hy = collect (dy[:, i])
97
111
y[:, i] = alpha * A[:, :, i] * x[:, i] + beta * y[:, i]
98
112
@test y[:, i] ≈ hy
99
113
end
114
+ dy = CuArray (y)
115
+ CUBLAS. gemv_strided_batched! (elty <: Real ? ' T' : ' C' , alpha, dA, dy, beta, dx)
116
+ for i in 1 : size (A, 3 )
117
+ hx = collect (dx[:, i])
118
+ x[:, i] = alpha * A[:, :, i]' * y[:, i] + beta * x[:, i]
119
+ @test x[:, i] ≈ hx
120
+ end
100
121
end
101
122
end
102
123
0 commit comments