diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
index 2cad3f2763..0beb261969 100644
--- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
+++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -1180,6 +1180,7 @@ def matmul_fp8_row(
     tma_persistent: bool = True,
     no_use_persistent: Optional[bool] = None,
     use_warp_specialization: bool = False,
+    out_dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
     """
     Performs matmul on [M, K] and [N, K] fp8 matrices with row-wise scalings [M], [N].
@@ -1195,6 +1196,7 @@ def matmul_fp8_row(
         allow_tf32 (bool): Whether to use TF32 for tensor core.
         fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
         tma_persistent (bool): Whether to use TMA persistent kernel impl.
+        out_dtype (torch.dtype): Output type of the function.
 
     Returns:
         torch.Tensor: [M, N] Output tensor a @ b / (a_scale[:, None] * b_scale[None, :])
@@ -1217,7 +1219,7 @@ def matmul_fp8_row(
         assert a.dtype in (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
     assert b.dtype == pt_fp8_dtype
     M, N, K, m_key, n_key, k_key, c, c_dtype_triton, dot_out_dtype_triton, device = (
-        prep_matmul(a, b, dot_out_dtype)
+        prep_matmul(a, b, dot_out_dtype, out_dtype)
     )
 
     # Skip scaling (a_scale is None) can only be applied in certain cases.
@@ -1622,6 +1624,7 @@ def matmul_fp8_row_meta(
     tma_persistent: bool = True,
     no_use_persistent: Optional[bool] = None,
     use_warp_specialization: bool = False,
+    out_dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
     """Shape function for torch compile."""
     M, K = a.shape
@@ -2185,6 +2188,7 @@ def prep_matmul(
     a: Union[TensorWrapper, torch.Tensor],
     b: Union[TensorWrapper, torch.Tensor],
     dot_out_dtype: Optional[torch.dtype],
+    out_dtype: Optional[torch.dtype] = None,
 ) -> Tuple[
     int, int, int, int, int, int, torch.Tensor, tl.dtype, tl.dtype, torch.device
 ]:
@@ -2239,10 +2243,18 @@ def prep_matmul(
         tl.float8e5,
         tl.float8e4b8,
     ]
-    c_dtype = torch.bfloat16
-    c_dtype_triton = tl.bfloat16
-    c = torch.empty((M, N), device=device, dtype=c_dtype)
+    if out_dtype is None:
+        # Default value torch.bfloat16 is to accommodate
+        # legacy usage of the function.
+        out_dtype = torch.bfloat16
+    assert isinstance(
+        out_dtype, torch.dtype
+    ), f"out_dtype type {type(out_dtype)} must be a torch.dtype"
+    c_dtype_triton = map_dtype_to_triton(out_dtype)
+
+    c = torch.empty((M, N), device=device, dtype=out_dtype)
+
     if dot_out_dtype is None:
         dot_out_dtype_triton = tl.float32
     else:
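
Illustrative usage sketch (not part of the patch): the change threads a new optional out_dtype argument from matmul_fp8_row through prep_matmul, so the caller can request an output dtype other than the previous hard-coded bfloat16. The snippet below assumes a CUDA/ROCm device with fp8 support and that quantize_fp8_row from the same module is used to produce the row-quantized inputs and scales; argument order and return values are based on the current fbgemm_gpu Triton helpers and should be checked against the installed version.

    import torch
    from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
        matmul_fp8_row,
        quantize_fp8_row,
    )

    a = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(64, 256, device="cuda", dtype=torch.bfloat16)

    # Row-wise fp8 quantization: returns the fp8 tensor and per-row scales.
    aq, a_scale = quantize_fp8_row(a)
    bq, b_scale = quantize_fp8_row(b)

    # Before this change the output was always bfloat16; with out_dtype the
    # caller can request a different output type (bfloat16 remains the default).
    out_fp16 = matmul_fp8_row(aq, bq, a_scale, b_scale, out_dtype=torch.float16)
    assert out_fp16.dtype == torch.float16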