Skip to content

Commit ac1657e

Browse files
authored
Add tests for gemmEx in fast math mode (#2660)
1 parent b324a8c commit ac1657e

File tree

1 file changed

+34
-0
lines changed
  • test/libraries/cublas/level3

1 file changed

+34
-0
lines changed

test/libraries/cublas/level3/gemm.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,8 @@ k = 13
401401
end
402402
end
403403

404+
starting_mode = CUDA.math_mode()
405+
starting_precision = CUDA.math_precision()
404406
@testset "mixed-precision matmul" begin
405407
m,k,n = 4,4,4
406408
cudaTypes = (Float16, Complex{Float16}, BFloat16, Complex{BFloat16}, Float32, Complex{Float32},
@@ -432,6 +434,38 @@ k = 13
432434
@test C Array(dC) rtol=rtol
433435
end
434436
end
437+
try
438+
# test in fast math mode too
439+
for precision in (:Float16, :BFloat16, :TensorFloat32), (AT, CT) in ((Float32, Float32), (ComplexF32, ComplexF32))
440+
CUDA.math_mode!(CUDA.FAST_MATH; precision=precision)
441+
BT = AT # gemmEx requires identical A and B types
442+
443+
# we only test combinations of types that are supported by gemmEx
444+
if CUBLAS.gemmExComputeType(AT, BT, CT, m,k,n) !== nothing
445+
A = AT <: BFloat16 ? AT.(rand(m,k)) : rand(AT, m,k)
446+
B = BT <: BFloat16 ? BT.(rand(k,n)) : rand(BT, k,n)
447+
C = similar(B, CT)
448+
mul!(C, A, B)
449+
450+
# Base can't do Int8*Int8 without losing accuracy
451+
if (AT == Int8 && BT == Int8) || (AT == Complex{Int8} && BT == Complex{Int8})
452+
C = CT.(A) * CT.(B)
453+
end
454+
455+
dA = CuArray(A)
456+
dB = CuArray(B)
457+
dC = similar(dB, CT)
458+
mul!(dC, dA, dB)
459+
460+
rtol = Base.rtoldefault(AT, BT, 0)
461+
@test C Array(dC) rtol=rtol
462+
end
463+
end
464+
CUDA.math_mode!(CUDA.FAST_MATH; precision = :Bad)
465+
@test_throws ArgumentError("Unknown reduced precision type Bad") CUBLAS.gemmExComputeType(Float32, Float32, Float32, m, k, n)
466+
finally
467+
CUDA.math_mode!(starting_mode; precision = starting_precision)
468+
end
435469

436470
# also test an unsupported combination (falling back to GPUArrays)
437471
if VERSION < v"1.11-" # JuliaGPU/CUDA.jl#2441

0 commit comments

Comments
 (0)