@@ -332,7 +332,7 @@ const TILE_DIM = 16
332
332
# legacy method
333
333
generic_matmatmul! (C:: AbstractArray , A:: AbstractArray , B:: AbstractArray , a:: Number , b:: Number ) =
334
334
generic_matmatmul! (C, A, B, MulAddMul (a, b))
335
- function generic_matmatmul! (C:: AbstractArray {R} , A:: AbstractArray {T} , B:: AbstractArray {S} , add:: MulAddMul ) where {T,S,R}
335
+ function generic_matmatmul! (C:: AbstractGPUMatrix {R} , A:: AbstractGPUMatrix {T} , B:: AbstractGPUMatrix {S} , add:: MulAddMul ) where {T,S,R}
336
336
N = size (A,1 )
337
337
Q = size (A,2 )
338
338
M = size (B,2 )
@@ -347,7 +347,7 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
347
347
end
348
348
349
349
@kernel unsafe_indices= true function coalesced_matmul_kernel! (
350
- output, @Const ( input1), @Const ( input2) , N, Q, M,
350
+ output, input1, input2, N, Q, M,
351
351
:: Val{BANK} = Val (1 ),
352
352
) where {BANK}
353
353
grow, gcol = @index (Group, NTuple)
@@ -363,7 +363,6 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
363
363
outval = @private R 1
364
364
@inbounds outval[1 ] = - zero (R)
365
365
366
- # @uniform N = size(output, 1)
367
366
# number of tiles depends on inner dimension
368
367
@uniform NUM_TILES = div (Q + TILE_DIM - 1 , TILE_DIM)
369
368
@@ -406,6 +405,34 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
406
405
coalesced_matmul_kernel! (get_backend (C), (TILE_DIM, TILE_DIM))(C, A, B, N, Q, M;ndrange= map (x -> ceil (Int,x/ TILE_DIM)* TILE_DIM, size (C)))
407
406
C
408
407
end
408
+ function generic_matmatmul! (C:: AbstractArray{R} , A:: AbstractArray{T} , B:: AbstractArray{S} , add:: MulAddMul ) where {T,S,R}
409
+ if size (A,2 ) != size (B,1 )
410
+ throw (DimensionMismatch (" matrix A has dimensions $(size (A)) , matrix B has dimensions $(size (B)) " ))
411
+ end
412
+ if size (C,1 ) != size (A,1 ) || size (C,2 ) != size (B,2 )
413
+ throw (DimensionMismatch (" result C has dimensions $(size (C)) , needs $((size (A,1 ),size (B,2 ))) " ))
414
+ end
415
+ if isempty (A) || isempty (B)
416
+ return fill! (C, zero (R))
417
+ end
418
+
419
+ @kernel function matmatmul_kernel! (C, A, B)
420
+ assume .(size (C) .> 0 )
421
+ idx = @index (Global, Linear)
422
+ i, j = @inbounds Tuple (CartesianIndices (C)[idx])... , 1
423
+
424
+ @inbounds if i <= size (A,1 ) && j <= size (B,2 )
425
+ z2 = zero (A[i, 1 ]* B[1 , j] + A[i, 1 ]* B[1 , j])
426
+ Cij = convert (promote_type (R, typeof (z2)), z2)
427
+ for k in 1 : size (A,2 )
428
+ Cij += A[i, k]* B[k, j]
429
+ end
430
+ C[i,j] = add (Cij, C[i,j])
431
+ end
432
+ end
433
+ matmatmul_kernel! (get_backend (C))(C, A, B; ndrange = size (C))
434
+ C
435
+ end
409
436
410
437
@static if VERSION < v " 1.12.0-"
411
438
function LinearAlgebra. generic_matvecmul! (C:: AbstractGPUVector , tA:: AbstractChar , A:: AbstractGPUMatrix , B:: AbstractGPUVector , _add:: MulAddMul = MulAddMul ())
0 commit comments