Skip to content

Commit 7e72182

Browse files
committed
Restrict method to pass tests
1 parent 576c23d commit 7e72182

File tree

1 file changed

+30
-3
lines changed

1 file changed

+30
-3
lines changed

src/host/linalg.jl

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ const TILE_DIM = 16
332332
# legacy method
333333
# Legacy entry point that receives the two scaling factors as plain numbers.
# It packs them into a `MulAddMul` and forwards to the `MulAddMul`-based method.
function generic_matmatmul!(
    C::AbstractArray, A::AbstractArray, B::AbstractArray, a::Number, b::Number
)
    return generic_matmatmul!(C, A, B, MulAddMul(a, b))
end
335-
function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::AbstractArray{S}, add::MulAddMul) where {T,S,R}
335+
function generic_matmatmul!(C::AbstractGPUMatrix{R}, A::AbstractGPUMatrix{T}, B::AbstractGPUMatrix{S}, add::MulAddMul) where {T,S,R}
336336
N = size(A,1)
337337
Q = size(A,2)
338338
M = size(B,2)
@@ -347,7 +347,7 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
347347
end
348348

349349
@kernel unsafe_indices=true function coalesced_matmul_kernel!(
350-
output, @Const(input1), @Const(input2), N, Q, M,
350+
output, input1, input2, N, Q, M,
351351
::Val{BANK} = Val(1),
352352
) where {BANK}
353353
grow, gcol = @index(Group, NTuple)
@@ -363,7 +363,6 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
363363
outval = @private R 1
364364
@inbounds outval[1] = -zero(R)
365365

366-
# @uniform N = size(output, 1)
367366
# number of tiles depends on inner dimension
368367
@uniform NUM_TILES = div(Q + TILE_DIM - 1, TILE_DIM)
369368

@@ -406,6 +405,34 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
406405
coalesced_matmul_kernel!(get_backend(C), (TILE_DIM, TILE_DIM))(C, A, B, N, Q, M;ndrange=map(x -> ceil(Int,x/TILE_DIM)*TILE_DIM, size(C)))
407406
C
408407
end
408+
# Generic fallback for arbitrary `AbstractArray` arguments: overwrites `C` with
# `add(A*B, C)` and returns `C`. The kernel is launched with `ndrange = size(C)`,
# and each work-item accumulates one output element.
#
# Throws `DimensionMismatch` when `size(A, 2) != size(B, 1)`, or when `C` does not
# have `size(A, 1)` rows and `size(B, 2)` columns.
function generic_matmatmul!(
    C::AbstractArray{R}, A::AbstractArray{T}, B::AbstractArray{S}, add::MulAddMul
) where {T,S,R}
    if size(A, 2) != size(B, 1)
        throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
    end
    if size(C, 1) != size(A, 1) || size(C, 2) != size(B, 2)
        throw(DimensionMismatch("result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))"))
    end
    if isempty(A) || isempty(B)
        # With an empty operand the product A*B contributes only zeros, so the
        # result is just C scaled by beta. Unconditionally zeroing C would drop
        # the `beta*C` term that `add` is supposed to apply.
        if iszero(add.beta)
            # Write exact zeros rather than multiplying, so NaN/Inf already
            # stored in C cannot survive a zero beta.
            fill!(C, zero(R))
        else
            C .= C .* add.beta
        end
        return C
    end

    @kernel function matmatmul_kernel!(C, A, B)
        # NOTE(review): presumably a compiler hint that every dimension of C is
        # positive (guaranteed by the emptiness check above) -- confirm.
        assume.(size(C) .> 0)
        idx = @index(Global, Linear)
        # The trailing 1 supplies the column index when C is one-dimensional.
        i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1

        @inbounds if i <= size(A, 1) && j <= size(B, 2)
            # Zero of the type produced by a product-plus-product, used to pick
            # an accumulator type wide enough for both R and the products.
            z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j])
            Cij = convert(promote_type(R, typeof(z2)), z2)
            for k in 1:size(A, 2)
                Cij += A[i, k]*B[k, j]
            end
            # `add` combines the fresh product with the previous contents of C.
            C[i, j] = add(Cij, C[i, j])
        end
    end
    matmatmul_kernel!(get_backend(C))(C, A, B; ndrange = size(C))
    return C
end
409436

410437
@static if VERSION < v"1.12.0-"
411438
function LinearAlgebra.generic_matvecmul!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, _add::MulAddMul = MulAddMul())

0 commit comments

Comments
 (0)