Skip to content

Commit 86b255e

Browse files
authored
Bugfix for batched gemv (#2481)
Fix and add test for when the matrix is transposed.
1 parent 69f3a76 commit 86b255e

File tree

2 files changed

+32
-10
lines changed

2 files changed

+32
-10
lines changed

lib/cublas/wrappers.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -424,15 +424,16 @@ for (fname, fname_64, eltyin, eltyout) in (
424424
if length(A) != length(x) || length(A) != length(y)
425425
throw(DimensionMismatch("Lengths of inputs must be the same"))
426426
end
427+
m = size(A[1], 1)
428+
n = size(A[1], 2)
427429
for (i, (As,xs,ys)) in enumerate(zip(A,x,y))
428-
m,n = size(As)
430+
if size(As) != (m, n)
431+
throw(DimensionMismatch("A[$i] has different dimension from A[1]. Dimensions between A's should be identical."))
432+
end
429433
if length(xs) != (trans == 'N' ? n : m) || length(ys) != (trans == 'N' ? m : n)
430434
throw(DimensionMismatch("Input $i: A has dimension $(size(As)), x has dimension $(size(xs)), y has dimension $(size(ys))"))
431435
end
432436
end
433-
434-
m = size(A[1], trans == 'N' ? 1 : 2)
435-
n = size(A[1], trans == 'N' ? 2 : 1)
436437
lda = max(1,stride(A[1],2))
437438
incx = stride(x[1],1)
438439
incy = stride(y[1],1)
@@ -470,9 +471,9 @@ for (fname, fname_64, eltyin, eltyout) in (
470471
if size(A, 3) != size(x, 2) || size(A, 3) != size(y, 2)
471472
throw(DimensionMismatch("Batch sizes must be equal for all inputs"))
472473
end
473-
m = size(A, trans == 'N' ? 1 : 2)
474-
n = size(A, trans == 'N' ? 2 : 1)
475-
if m != size(y, 1) || n != size(x, 1)
474+
m = size(A, 1)
475+
n = size(A, 2)
476+
if size(y, 1) != (trans == 'N' ? m : n) || size(x, 1) != (trans == 'N' ? n : m)
476477
throw(DimensionMismatch("A has dimension $(size(A)), x has dimension $(size(x)), y has dimension $(size(y))"))
477478
end
478479

test/libraries/cublas/level2.jl

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ k = 13
2424
@test testf(*, transpose(rand(elty, m, n)), rand(elty, m))
2525
@test testf(*, rand(elty, m, n)', rand(elty, m))
2626
x = rand(elty, m)
27-
A = rand(elty, m, m + 1 )
28-
y = rand(elty, m)
27+
A = rand(elty, m, m + 1)
28+
y = rand(elty, n)
2929
dx = CuArray(x)
3030
dA = CuArray(A)
3131
dy = CuArray(y)
@@ -44,6 +44,10 @@ k = 13
4444
dy = CUBLAS.gemv('N', dA, dx)
4545
hy = collect(dy)
4646
@test hy A * x
47+
dy = CuArray(y)
48+
dx = CUBLAS.gemv(elty <: Real ? 'T' : 'C', alpha, dA, dy)
49+
hx = collect(dx)
50+
@test hx alpha * A' * y
4751
end
4852

4953
if CUBLAS.version() >= v"11.9"
@@ -72,6 +76,16 @@ k = 13
7276
y[i] = alpha * A[i] * x[i] + beta * y[i]
7377
@test y[i] hy
7478
end
79+
dy = CuArray{elty, 1}[]
80+
for i=1:length(A)
81+
push!(dy, CuArray(y[i]))
82+
end
83+
CUBLAS.gemv_batched!(elty <: Real ? 'T' : 'C', alpha, dA, dy, beta, dx)
84+
for i in 1:length(A)
85+
hx = collect(dx[i])
86+
x[i] = alpha * A[i]' * y[i] + beta * x[i]
87+
@test x[i] hx
88+
end
7589
end
7690
end
7791

@@ -92,11 +106,18 @@ k = 13
92106
dbad = CuArray(bad)
93107
@test_throws DimensionMismatch CUBLAS.gemv_strided_batched!('N', alpha, dA, dx, beta, dbad)
94108
CUBLAS.gemv_strided_batched!('N', alpha, dA, dx, beta, dy)
95-
for i=1:size(A, 3)
109+
for i in 1:size(A, 3)
96110
hy = collect(dy[:, i])
97111
y[:, i] = alpha * A[:, :, i] * x[:, i] + beta * y[:, i]
98112
@test y[:, i] hy
99113
end
114+
dy = CuArray(y)
115+
CUBLAS.gemv_strided_batched!(elty <: Real ? 'T' : 'C', alpha, dA, dy, beta, dx)
116+
for i in 1:size(A, 3)
117+
hx = collect(dx[:, i])
118+
x[:, i] = alpha * A[:, :, i]' * y[:, i] + beta * x[:, i]
119+
@test x[:, i] hx
120+
end
100121
end
101122
end
102123

0 commit comments

Comments
 (0)