Skip to content

Commit c7edfee

Browse files
authored
Fix generic matmatmul NaN handling. (#476)
1 parent 3994982 commit c7edfee

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

src/host/linalg.jl

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -324,8 +324,11 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
324324
return fill!(C, zero(R))
325325
end
326326

327+
add = MulAddMul(a, b)
328+
327329
gpu_call(C, A, B; name="matmatmul!") do ctx, C, A, B
328330
idx = @linearidx C
331+
assume.(size(C) .> 0)
329332
i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1
330333

331334
@inbounds if i <= size(A,1) && j <= size(B,2)
@@ -334,7 +337,7 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
334337
for k in 1:size(A,2)
335338
Ctmp += A[i, k]*B[k, j]
336339
end
337-
C[i,j] = Ctmp*a + C[i,j]*b
340+
C[i,j] = add(Ctmp, C[i,j])
338341
end
339342

340343
return
@@ -382,7 +385,7 @@ function LinearAlgebra.herk_wrapper!(C::AbstractGPUMatrix, tA::AbstractChar, A::
382385
end
383386
end
384387
end # VERSION
385-
388+
386389
function generic_rmul!(X::AbstractArray, s::Number)
387390
gpu_call(X, s; name="rmul!") do ctx, X, s
388391
i = @linearidx X

0 commit comments

Comments
 (0)