@@ -325,37 +325,85 @@ function LinearAlgebra.ldiv!(B::AbstractGPUVecOrMat,
325
325
B
326
326
end
327
327
328
+ # XXX : figure out how to do dynamically
329
+ const TILE_DIM = 16
328
330
329
331
# # matrix multiplication
330
332
# legacy method
331
333
generic_matmatmul! (C:: AbstractArray , A:: AbstractArray , B:: AbstractArray , a:: Number , b:: Number ) =
332
334
generic_matmatmul! (C, A, B, MulAddMul (a, b))
333
335
function generic_matmatmul! (C:: AbstractArray{R} , A:: AbstractArray{T} , B:: AbstractArray{S} , add:: MulAddMul ) where {T,S,R}
334
- if size (A,2 ) != size (B,1 )
336
+ N = size (A,1 )
337
+ Q = size (A,2 )
338
+ M = size (B,2 )
339
+ if Q != size (B,1 )
335
340
throw (DimensionMismatch (" matrix A has dimensions $(size (A)) , matrix B has dimensions $(size (B)) " ))
336
341
end
337
- if size (C,1 ) != size (A, 1 ) || size (C,2 ) != size (B, 2 )
338
- throw (DimensionMismatch (" result C has dimensions $(size (C)) , needs $((size (A, 1 ), size (B, 2 ) )) " ))
342
+ if size (C,1 ) != N || size (C,2 ) != M
343
+ throw (DimensionMismatch (" result C has dimensions $(size (C)) , needs $((N,M )) " ))
339
344
end
340
345
if isempty (A) || isempty (B)
341
346
return fill! (C, zero (R))
342
347
end
343
348
344
- @kernel function matmatmul_kernel! (C, A, B)
345
- assume .(size (C) .> 0 )
346
- idx = @index (Global, Linear)
347
- i, j = @inbounds Tuple (CartesianIndices (C)[idx])... , 1
349
+ @kernel unsafe_indices= true function coalesced_matmul_kernel! (
350
+ output, @Const (input1), @Const (input2), N, Q, M,
351
+ :: Val{BANK} = Val (1 ),
352
+ ) where {BANK}
353
+ grow, gcol = @index (Group, NTuple)
354
+ tile_row, tile_col = @index (Local, NTuple)
348
355
349
- @inbounds if i <= size (A,1 ) && j <= size (B,2 )
350
- z2 = zero (A[i, 1 ]* B[1 , j] + A[i, 1 ]* B[1 , j])
351
- Cij = convert (promote_type (R, typeof (z2)), z2)
352
- for k in 1 : size (A,2 )
353
- Cij += A[i, k]* B[k, j]
356
+ # TILE_DIM = @uniform @groupsize()[1]
357
+
358
+ # +1 to avoid bank conflicts on shared memory
359
+ tile1 = @localmem (R, (TILE_DIM + BANK, TILE_DIM))
360
+ tile2 = @localmem (R, (TILE_DIM + BANK, TILE_DIM))
361
+
362
+ # private variable for tile output
363
+ outval = @private R 1
364
+ @inbounds outval[1 ] = - zero (R)
365
+
366
+ # @uniform N = size(output, 1)
367
+ # number of tiles depends on inner dimension
368
+ @uniform NUM_TILES = div (Q + TILE_DIM - 1 , TILE_DIM)
369
+
370
+ I = (grow - 1 ) * TILE_DIM + tile_row
371
+ J = (gcol - 1 ) * TILE_DIM + tile_col
372
+
373
+ # loop over all tiles needed for this calculation
374
+ for t in 0 : (NUM_TILES - 1 )
375
+ # load inputs into tiles, with bounds checking for non-square matrices
376
+ if I <= N && t * TILE_DIM + tile_col <= Q
377
+ @inbounds tile1[tile_row, tile_col] = input1[I, t * TILE_DIM + tile_col]
378
+ else
379
+ @inbounds tile1[tile_row, tile_col] = zero (R)
380
+ end
381
+ if J <= M && t * TILE_DIM + tile_row <= Q
382
+ @inbounds tile2[tile_row, tile_col] = input2[t * TILE_DIM + tile_row, J]
383
+ else
384
+ @inbounds tile2[tile_row, tile_col] = zero (R)
354
385
end
355
- C[i,j] = add (Cij, C[i,j])
386
+
387
+ # wait for all tiles to be loaded
388
+ @synchronize
389
+
390
+ # calculate value of spot in output, use temporary value to allow for vectorization
391
+ out = zero (R)
392
+ @simd for k in 1 : TILE_DIM
393
+ @inbounds out += tile1[tile_row, k] * tile2[k, tile_col]
394
+ end
395
+ outval[1 ] += out
396
+
397
+ @synchronize
398
+ end
399
+
400
+ # save if inbounds
401
+ if I <= N && J <= M
402
+ @inbounds output[I, J] = add (outval[1 ], output[I, J])
356
403
end
357
404
end
358
- matmatmul_kernel! (get_backend (C))(C, A, B; ndrange = size (C))
405
+
406
+ coalesced_matmul_kernel! (get_backend (C), (TILE_DIM, TILE_DIM))(C, A, B, N, Q, M;ndrange= map (x -> ceil (Int,x/ TILE_DIM)* TILE_DIM, size (C)))
359
407
C
360
408
end
361
409
@@ -744,7 +792,7 @@ function LinearAlgebra.kron!(z::AbstractGPUVector{T1}, x::AbstractGPUVector{T2},
744
792
745
793
@kernel function kron_kernel! (z, @Const (x), @Const (y))
746
794
i, j = @index (Global, NTuple)
747
-
795
+
748
796
@inbounds z[(i - 1 ) * length (y) + j] = x[i] * y[j]
749
797
end
750
798
@@ -777,13 +825,13 @@ for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in
777
825
778
826
ta = $ transa (T1)
779
827
tb = $ transb (T2)
780
-
828
+
781
829
@kernel function kron_kernel! (C, @Const (A), @Const (B))
782
830
ai, aj = @index (Global, NTuple) # Indices in the result matrix
783
-
831
+
784
832
# lb1, lb2 = size(B) # Dimensions of B
785
833
lb1, lb2 = tb == ' N' ? size (B) : reverse (size (B))
786
-
834
+
787
835
# Map global indices (ai, aj) to submatrices of the Kronecker product
788
836
i_a = (ai - 1 ) ÷ lb1 + 1 # Corresponding row index in A
789
837
i_b = (ai - 1 ) % lb1 + 1 # Corresponding row index in B
@@ -797,12 +845,12 @@ for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in
797
845
C[ai, aj] = a_ij * b_ij
798
846
end
799
847
end
800
-
848
+
801
849
backend = KernelAbstractions. get_backend (C)
802
850
kernel = kron_kernel! (backend)
803
-
851
+
804
852
kernel (C, $ (unwrapa (:A )), $ (unwrapb (:B )), ndrange= (size (C, 1 ), size (C, 2 )))
805
-
853
+
806
854
return C
807
855
end
808
856
0 commit comments