@@ -401,6 +401,8 @@ k = 13
401
401
end
402
402
end
403
403
404
# Capture the current math mode and precision so that code which switches them
# (via `CUDA.math_mode!`) can put the original settings back afterwards.
starting_mode = CUDA.math_mode()
starting_precision = CUDA.math_precision()
404
406
@testset " mixed-precision matmul" begin
405
407
m,k,n = 4 ,4 ,4
406
408
cudaTypes = (Float16, Complex{Float16}, BFloat16, Complex{BFloat16}, Float32, Complex{Float32},
@@ -432,6 +434,38 @@ k = 13
432
434
@test C ≈ Array (dC) rtol= rtol
433
435
end
434
436
end
437
# Repeat the Float32 / ComplexF32 matrix products with fast-math mode enabled,
# once per reduced precision. The `finally` clause puts back the mode and
# precision held in `starting_mode` / `starting_precision`, even when a test
# errors part-way through.
try
    for precision in (:Float16, :BFloat16, :TensorFloat32),
        (AT, CT) in ((Float32, Float32), (ComplexF32, ComplexF32))

        CUDA.math_mode!(CUDA.FAST_MATH; precision=precision)
        BT = AT # gemmEx requires identical A and B types

        # we only test combinations of types that are supported by gemmEx
        isnothing(CUBLAS.gemmExComputeType(AT, BT, CT, m, k, n)) && continue

        # host-side reference product
        A = rand(AT, m, k)
        B = rand(BT, k, n)
        C = similar(B, CT)
        mul!(C, A, B)

        # same product on CuArrays
        dA = CuArray(A)
        dB = CuArray(B)
        dC = similar(dB, CT)
        mul!(dC, dA, dB)

        rtol = Base.rtoldefault(AT, BT, 0)
        @test C ≈ Array(dC) rtol=rtol
    end

    # an unrecognised reduced precision must be rejected when the compute type
    # is looked up
    CUDA.math_mode!(CUDA.FAST_MATH; precision=:Bad)
    @test_throws ArgumentError("Unknown reduced precision type Bad") CUBLAS.gemmExComputeType(Float32, Float32, Float32, m, k, n)
finally
    CUDA.math_mode!(starting_mode; precision=starting_precision)
end
435
469
436
470
# also test an unsupported combination (falling back to GPUArrays)
437
471
if VERSION < v " 1.11-" # JuliaGPU/CUDA.jl#2441
0 commit comments