Commit d0cd3be

add more generic kernel for fp8 blockwise scaling
1 parent 0e00df3 commit d0cd3be

File tree

6 files changed: +392 -76 lines changed

benchmarks/benchmark_blockwise_scaled_linear_triton.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 from triton.testing import do_bench

 from torchao.float8.float8_utils import compute_error
-from torchao.prototype.blockwise_fp8.blockwise_quantization import (
+from torchao.prototype.blockwise_fp8.kernels import (
     blockwise_fp8_gemm,
     fp8_blockwise_act_quant,
     fp8_blockwise_weight_quant,
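
For reference, the renamed kernels module keeps the same call pattern this benchmark relies on. A minimal sketch of how these imports compose (hypothetical shapes and an assumed e4m3 dtype; the real benchmark sweeps many configurations):

import torch
from triton.testing import do_bench

from torchao.float8.float8_utils import compute_error
from torchao.prototype.blockwise_fp8.kernels import (
    blockwise_fp8_gemm,
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_quant,
)

# Hypothetical problem size; A is (M, K) activations, B is (N, K) weights.
M, N, K = 16, 4096, 4096
A = torch.randn(M, K, device="cuda")
B = torch.randn(N, K, device="cuda")

# Blockwise-quantize both operands, then run the fp8 GEMM.
A_q, A_s = fp8_blockwise_act_quant(A, dtype=torch.float8_e4m3fn)
B_q, B_s = fp8_blockwise_weight_quant(B, dtype=torch.float8_e4m3fn)
C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)

# Accuracy (SQNR in dB) against the high-precision reference, and latency in ms.
print("SQNR:", compute_error(A @ B.T, C_q).item())
print("latency (ms):", do_bench(lambda: blockwise_fp8_gemm(A_q, A_s, B_q, B_s)))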

test/prototype/test_blockwise_triton.py

Lines changed: 0 additions & 72 deletions
This file was deleted.

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from packaging import version
from torchao.prototype.blockwise_fp8.kernels import (
    torch_blockwise_scale_narrow_tiles,
    torch_blockwise_scale_square_tiles,
    triton_quantize_fp8_block,
)
from torchao.testing.utils import skip_if_rocm

triton = pytest.importorskip("triton", reason="Triton required to run this test")

from torchao.prototype.blockwise_fp8.kernels import (
    blockwise_fp8_gemm,
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_dequant,
    fp8_blockwise_weight_quant,
)
from torchao.utils import is_sm_at_least_89

BLOCKWISE_SIZE_MNK = [
    (2, 512, 128),
    (3, 2048, 2048),
    (4, 3584, 640),
    (13, 8704, 8576),
    (26, 18944, 1664),
    (67, 6656, 1408),
]


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("_, N, K", BLOCKWISE_SIZE_MNK)
@pytest.mark.parametrize(
    "dtype",
    [torch.float8_e4m3fn, torch.float8_e5m2]
    if is_sm_at_least_89()
    else [torch.float8_e5m2],
)
def test_blockwise_quant_dequant(_, N, K, dtype):
    x = torch.randn(N, K).cuda()
    qx, s = fp8_blockwise_weight_quant(x, dtype=dtype)
    x_reconstructed = fp8_blockwise_weight_dequant(qx, s)
    error = torch.norm(x - x_reconstructed) / torch.norm(x)
    print(f"Relative Error: {error.item():.6f}")

    assert error < 0.1, "Quant-Dequant error is too high"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    version.parse(triton.__version__) < version.parse("3.3.0"),
    reason="Triton version < 3.3.0, test skipped",
)
@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK)
@pytest.mark.parametrize(
    "dtype",
    [torch.float8_e4m3fn, torch.float8_e5m2]
    if is_sm_at_least_89()
    else [torch.float8_e5m2],
)
def test_blockwise_fp8_gemm(M, N, K, dtype):
    A = torch.randn(M, K).cuda()
    B = torch.randn(N, K).cuda()
    C = A @ B.T
    A_q, A_s = fp8_blockwise_act_quant(A, dtype=dtype)
    B_q, B_s = fp8_blockwise_weight_quant(B, dtype=dtype)
    C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)
    error = torch.norm(C - C_q) / torch.norm(C)
    print(f"Relative Error: {error.item():.6f}")

    assert error < 0.1, "Quantize gemm error is too high"


@skip_if_rocm("ROCm enablement in progress")
@pytest.mark.parametrize("tile_size", [128, 256])
def test_triton_quantize_fp8_narrow_tiles(tile_size: int):
    device = "cuda"
    M, K = 256, 256
    x = torch.randn(M, K, device=device)

    # Get the quantized tensor and scales using triton implementation
    # Use block_m=1 to match the narrow tiles (1 x tile_size) in the reference implementation
    triton_fp8, triton_scale = triton_quantize_fp8_block(
        x, block_m=1, block_k=tile_size
    )

    # Get the quantized tensor and scales using reference implementation
    ref_fp8, ref_scale = torch_blockwise_scale_narrow_tiles(x, tile_size=tile_size)

    # Convert both to float32 for comparison
    triton_fp32 = triton_fp8.to(torch.float32)
    ref_fp32 = ref_fp8.to(torch.float32)

    # Check that the quantized tensors are close
    # Note: We use a relatively high tolerance because the implementations might have
    # slight differences in how they handle edge cases, rounding, etc.
    assert torch.allclose(triton_fp32, ref_fp32, rtol=1e-2, atol=1e-2), (
        f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}"
    )

    # Check that the scales are close
    # Note: The scales might be stored differently (reciprocal vs. direct), so we need to
    # be careful about how we compare them

    # In triton_quantize_fp8_block, scales are stored as reciprocals (1/scale)
    # In torch_blockwise_scale_narrow_tiles, scales are stored directly
    # So we need to take the reciprocal of one of them for comparison

    # Reshape triton_scale to match ref_scale shape for comparison
    triton_scale_reshaped = triton_scale.reshape(M, -1)

    # Compare reciprocal of triton_scale with ref_scale
    assert torch.allclose(
        1.0 / triton_scale_reshaped, ref_scale, rtol=1e-2, atol=1e-2
    ), (
        f"Scales differ: max diff = {(1.0 / triton_scale_reshaped - ref_scale).abs().max().item()}"
    )


@skip_if_rocm("ROCm enablement in progress")
@pytest.mark.parametrize("tile_size", [128, 256])
def test_triton_quantize_fp8_square_tiles(tile_size: int):
    device = "cuda"
    # Make sure dimensions are multiples of tile_size for clean comparison
    M = tile_size * 2
    K = tile_size * 2
    x = torch.randn(M, K, device=device)

    # Get the quantized tensor and scales using triton implementation
    triton_fp8, triton_scale = triton_quantize_fp8_block(
        x, block_m=tile_size, block_k=tile_size
    )

    # Get the quantized tensor and scales using reference implementation
    ref_fp8, ref_scale = torch_blockwise_scale_square_tiles(x, tile_size=tile_size)

    # Convert both to float32 for comparison
    triton_fp32 = triton_fp8.to(torch.float32)
    ref_fp32 = ref_fp8.to(torch.float32)

    # Check that the quantized tensors are close
    assert torch.allclose(triton_fp32, ref_fp32, rtol=1e-2, atol=1e-2), (
        f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}"
    )

    # Check that the scales are close
    # In triton_quantize_fp8_block, scales are stored as reciprocals (1/scale)
    # In torch_blockwise_scale_square_tiles, scales are stored directly

    # Compare reciprocal of triton_scale with ref_scale
    assert torch.allclose(1.0 / triton_scale, ref_scale, rtol=1e-2, atol=1e-2), (
        f"Scales differ: max diff = {(1.0 / triton_scale - ref_scale).abs().max().item()}"
    )
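
The torch_blockwise_scale_* helpers above act as eager references for triton_quantize_fp8_block, and the test comments spell out the one convention difference: the triton kernel stores reciprocal scales, the references store scales directly. Purely as an illustration of that reference shape (not the library's code; the function name below is made up), square-tile blockwise fp8 quantization can be written in plain PyTorch as:

import torch

def blockwise_fp8_quant_square_tiles(x: torch.Tensor, tile_size: int = 128):
    # Assumes x is 2D and both dimensions are multiples of tile_size.
    M, K = x.shape
    fp8_max = torch.finfo(torch.float8_e4m3fn).max

    # View x as (M//T, T, K//T, T) tiles and take one amax per tile.
    tiles = x.reshape(M // tile_size, tile_size, K // tile_size, tile_size)
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)

    # Scale each tile so its max magnitude maps to the fp8 max, then cast.
    scale = fp8_max / amax
    q = (tiles * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)

    # Return quantized data in the original layout plus the per-tile
    # dequantization scale (amax / fp8_max), stored directly, not as a reciprocal.
    return q.reshape(M, K), (amax / fp8_max).squeeze(1).squeeze(-1)

Comparing such a reference against a kernel that returns reciprocal scales then amounts to checking 1.0 / kernel_scale against the direct scale, which is exactly what the tests above do.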

torchao/prototype/blockwise_fp8/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 from .blockwise_linear import BlockwiseQuantLinear
-from .blockwise_quantization import (
+from .kernels import (
     blockwise_fp8_gemm,
     fp8_blockwise_act_quant,
     fp8_blockwise_weight_dequant,
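
Because the package __init__ re-exports the same names, callers importing from the package root are unaffected by the blockwise_quantization -> kernels rename. A small usage sketch (arbitrary shape, assuming an fp8-capable GPU):

import torch
from torchao.prototype.blockwise_fp8 import (
    fp8_blockwise_weight_dequant,
    fp8_blockwise_weight_quant,
)

# Round-trip a weight through blockwise fp8 quantization.
w = torch.randn(2048, 2048, device="cuda")
w_q, w_s = fp8_blockwise_weight_quant(w, dtype=torch.float8_e4m3fn)
w_hat = fp8_blockwise_weight_dequant(w_q, w_s)
print("relative error:", (torch.norm(w - w_hat) / torch.norm(w)).item())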

torchao/prototype/blockwise_fp8/blockwise_linear.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 import torch
 from torch import nn

-from torchao.prototype.blockwise_fp8.blockwise_quantization import (
+from torchao.prototype.blockwise_fp8.kernels import (
     blockwise_fp8_gemm,
     fp8_blockwise_act_quant,
 )
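
The two kernels imported here are what a blockwise-quantized linear layer composes in its forward path: quantize the activation on the fly, then run the fp8 GEMM against a pre-quantized weight. A rough sketch of that composition (not BlockwiseQuantLinear's actual code; shapes and dtype are assumptions):

import torch
from torchao.prototype.blockwise_fp8.kernels import (
    blockwise_fp8_gemm,
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_quant,
)

# Pre-quantize a (out_features, in_features) weight once, ahead of time.
weight = torch.randn(4096, 1024, device="cuda")
w_q, w_s = fp8_blockwise_weight_quant(weight, dtype=torch.float8_e4m3fn)

def fp8_linear(x: torch.Tensor) -> torch.Tensor:
    # Quantize the activation blockwise per call, then run the fp8 GEMM.
    x_q, x_s = fp8_blockwise_act_quant(x, dtype=torch.float8_e4m3fn)
    return blockwise_fp8_gemm(x_q, x_s, w_q, w_s)

y = fp8_linear(torch.randn(16, 1024, device="cuda"))  # -> shape (16, 4096)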
