
Commit 1a439d6

Makes fallback float8 1x128 by 128x128 gemm output bfloat16
Summary: For now, we just care about bf16 output. We can add fp32 and a flag to control it later, if needed.

Test Plan:
```
pytest test/quantization/quantize_/workflows/float8/test_float8_tensor.py -s -k fp8_linear_variants -x
```

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: b3c443c
ghstack-comment-id: 3469836810
Pull-Request: #3265
1 parent 16de414 commit 1a439d6

2 files changed: +3, −2 lines

test/kernel/test_blockwise_triton.py (1 addition, 0 deletions)

```diff
@@ -66,6 +66,7 @@ def test_blockwise_fp8_gemm(M, N, K, dtype):
     A_q, A_s = fp8_blockwise_act_quant(A, dtype=dtype)
     B_q, B_s = fp8_blockwise_weight_quant(B, dtype=dtype)
     C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)
+    assert C_q.dtype == torch.bfloat16, "unsupported"
     error = torch.linalg.vector_norm(C - C_q) / torch.linalg.vector_norm(C)
     print(f"Relative Error: {error.item():.6f}")

```
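For context, here is a minimal end-to-end sketch of the path this test exercises. The import path, device, shapes, and fp8 dtype are assumptions for illustration (the real test parametrizes M, N, K and the dtype); the point is that the gemm output is now bfloat16 regardless of `torch.get_default_dtype()`.

```python
import torch

# Assumption: the helpers are importable from the module touched by this commit.
from torchao.kernel.blockwise_quantization import (
    blockwise_fp8_gemm,
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_quant,
)

# Illustrative shapes; dims are multiples of the 128 block size.
M, N, K = 256, 512, 1024
A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
B = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
C_ref = A @ B.T  # bf16 reference matmul

# Activations are quantized per 1x128 block, weights per 128x128 block.
A_q, A_s = fp8_blockwise_act_quant(A, dtype=torch.float8_e4m3fn)
B_q, B_s = fp8_blockwise_weight_quant(B, dtype=torch.float8_e4m3fn)

# After this commit the fallback gemm returns bfloat16, independent of
# torch.get_default_dtype().
C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)
assert C_q.dtype == torch.bfloat16

error = torch.linalg.vector_norm(C_ref - C_q) / torch.linalg.vector_norm(C_ref)
print(f"Relative Error: {error.item():.6f}")
```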

torchao/kernel/blockwise_quantization.py (2 additions, 2 deletions)

```diff
@@ -92,7 +92,7 @@ def blockwise_fp8_gemm(
     M = a.numel() // K
     N = b.size(0)
     M_BUCKET = math.ceil(math.log2(M))
-    c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
+    c = a.new_empty(*a.size()[:-1], N, dtype=torch.bfloat16)
     grid = lambda META: (
         triton.cdiv(M, META["BLOCK_SIZE_M"]),
         triton.cdiv(N, META["BLOCK_SIZE_N"]),
@@ -105,7 +105,7 @@ def blockwise_fp8_gemm(
 @blockwise_fp8_gemm.register_fake
 def _(a, a_s, b, b_s, block_size=128):
     N = b.size(0)
-    c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
+    c = a.new_empty(*a.size()[:-1], N, dtype=torch.bfloat16)
     return c

 @triton.jit
```
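The summary notes that fp32 output and a flag to control it could be added later. Below is a hypothetical sketch of what such a flag could look like for the output allocation; the `out_dtype` parameter name and the `_alloc_gemm_output` helper are assumptions, not part of this commit. Whatever form it takes, the real kernel and its `register_fake` meta implementation would need to agree on the chosen dtype so traced or compiled graphs see consistent output metadata.

```python
import torch

# Hypothetical sketch of the "fp32 + flag" follow-up mentioned in the summary.
# `out_dtype` and `_alloc_gemm_output` are assumed names, NOT part of this commit.
def _alloc_gemm_output(
    a: torch.Tensor, N: int, out_dtype: torch.dtype = torch.bfloat16
) -> torch.Tensor:
    # Mirrors the allocation in blockwise_fp8_gemm: same leading dims as `a`,
    # last dim N, but with the output dtype chosen by the caller instead of
    # being hard-coded to bfloat16.
    return a.new_empty(*a.size()[:-1], N, dtype=out_dtype)

a = torch.empty(4, 128, dtype=torch.float8_e4m3fn)
print(_alloc_gemm_output(a, 16).dtype)                           # torch.bfloat16
print(_alloc_gemm_output(a, 16, out_dtype=torch.float32).dtype)  # torch.float32
```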
