Commit 31d5c17

ProExpertProg and mgoin authored
[Perf][fp8] Use CustomOp abstraction for fp8 quant for better perf (#19830)
Signed-off-by: Luka Govedic <lgovedic@redhat.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
1 parent 35514b6 commit 31d5c17

File tree: 18 files changed, +368 -104 lines changed
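In short, call sites that previously invoked ops.scaled_fp8_quant directly now construct a QuantFP8 CustomOp, parameterized by static vs. dynamic quantization and a GroupShape, and call that instead. A minimal before/after sketch, pieced together only from the call sites in this diff (tensor shape and device are illustrative, not from the commit):

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# Before: direct kernel call, behavior selected by per-call flags.
q_old, scale_old = ops.scaled_fp8_quant(x, use_per_token_if_dynamic=True)

# After: a QuantFP8 op configured once (dynamic, per-token here); the CustomOp
# machinery then dispatches to the best available implementation.
quant_fp8 = QuantFP8(False, GroupShape.PER_TOKEN)  # static=False
q_new, scale_new = quant_fp8(x)

In vLLM itself the QuantFP8 instance is created during layer or scheme initialization (see the compressed-tensors changes below), where the active vLLM config is already set; the new benchmark that follows constructs it under set_current_vllm_config for the same reason.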
Lines changed: 98 additions & 0 deletions

@@ -0,0 +1,98 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import itertools
+from typing import Callable
+
+import torch
+
+from vllm import _custom_ops as ops
+from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
+from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
+from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
+from vllm.triton_utils import triton
+
+
+# TODO(luka): use standalone_compile utility
+def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int):
+    def inner(*args):
+        torch._dynamo.mark_dynamic(args[arg_index], dim_index)
+        return fn(*args)
+
+    return inner
+
+
+torch._dynamo.config.recompile_limit = 8888
+compilation_config = CompilationConfig(custom_ops=["none"])
+with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
+    torch_per_token_quant_fp8 = torch.compile(
+        QuantFP8(False, GroupShape.PER_TOKEN),
+        fullgraph=True,
+        dynamic=False,  # recompile for different shapes
+    )
+
+    # First dim is explicitly dynamic to simulate vLLM usage
+    torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0)
+
+
+def cuda_per_token_quant_fp8(
+    input: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    return ops.scaled_fp8_quant(input)
+
+
+def calculate_diff(batch_size: int, seq_len: int):
+    """Calculate difference between Triton and CUDA implementations."""
+    device = torch.device("cuda")
+    x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device)
+
+    torch_out, torch_scale = torch_per_token_quant_fp8(x)
+    cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)
+
+    if torch.allclose(
+        cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5
+    ) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5):
+        print("✅ All implementations match")
+    else:
+        print("❌ Implementations differ")
+
+
+batch_size_range = [1, 16, 32, 64, 128]
+seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
+
+configs = list(itertools.product(batch_size_range, seq_len_range))
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["batch_size", "seq_len"],
+        x_vals=configs,
+        line_arg="provider",
+        line_vals=["torch", "cuda"],
+        line_names=["Torch", "CUDA"],
+        styles=[("blue", "-"), ("green", "-")],
+        ylabel="us",
+        plot_name="per-token-dynamic-quant-fp8-performance",
+        args={},
+    )
+)
+def benchmark_quantization(batch_size, seq_len, provider):
+    dtype = torch.float16
+    device = torch.device("cuda")
+
+    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
+
+    quantiles = [0.5, 0.2, 0.8]
+
+    if provider == "torch":
+        fn = lambda: torch_per_token_quant_fp8(x.clone())
+    elif provider == "cuda":
+        fn = lambda: cuda_per_token_quant_fp8(x.clone())
+
+    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
+
+    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+
+
+if __name__ == "__main__":
+    calculate_diff(batch_size=4, seq_len=4096)
+    benchmark_quantization.run(print_data=True)

tests/compile/test_fusion.py

Lines changed: 7 additions & 4 deletions

@@ -44,7 +44,9 @@ def __init__(self, hidden_size: int, eps: float, static: bool,
         ]
         self.fp8_linear = Fp8LinearOp(
             cutlass_fp8_supported=cutlass_fp8_enabled,
-            use_per_token_if_dynamic=True)
+            act_quant_static=static,
+            act_quant_group_shape=group_shape,
+        )

     def forward(self, x):
         resid = torch.sqrt(x)
@@ -91,9 +93,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
     maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths

     vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
-    vllm_config.compilation_config.pass_config = \
-        PassConfig(enable_fusion=True, enable_noop=True)
+        level=CompilationLevel.PIECEWISE,
+        custom_ops=["+rms_norm", "+quant_fp8"],
+        pass_config=PassConfig(enable_fusion=True, enable_noop=True),
+    ))
     with vllm.config.set_current_vllm_config(vllm_config):
         # Reshape pass is needed for the fusion pass to work
         noop_pass = NoOpEliminationPass(vllm_config)
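The tests reflect the new Fp8LinearOp interface: instead of a use_per_token_if_dynamic flag, activation quantization is described explicitly. A hedged sketch of the two configurations exercised in this commit (argument names taken from these diffs; other constructor defaults are not shown):

from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp

# Static per-tensor activation quantization, as in test_silu_mul_quant_fusion.py.
static_linear = Fp8LinearOp(
    cutlass_fp8_supported=True,
    act_quant_static=True,
    act_quant_group_shape=GroupShape.PER_TENSOR,
)

# Dynamic per-token activation quantization, one of the parameterizations
# test_fusion.py passes through its static / group_shape test parameters.
dynamic_linear = Fp8LinearOp(
    cutlass_fp8_supported=True,
    act_quant_static=False,
    act_quant_group_shape=GroupShape.PER_TOKEN,
)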

tests/compile/test_fusion_attn.py

Lines changed: 2 additions & 0 deletions

@@ -50,6 +50,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
         # DYNAMO_ONCE does not properly propagate shapes.
         level=CompilationLevel.DYNAMO_AS_IS,
         backend="tests.compile.test_fusion_attn.backend_unfused",
+        custom_ops=["+quant_fp8"],
     )
     vllm_config = VllmConfig(compilation_config=compile_config)
     backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
@@ -73,6 +74,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
         # DYNAMO_ONCE does not properly propagate shapes.
         level=CompilationLevel.DYNAMO_AS_IS,
         backend="tests.compile.test_fusion_attn.backend",
+        custom_ops=["+quant_fp8"],
     )
     vllm_config = VllmConfig(compilation_config=compile_config)

tests/compile/test_silu_mul_quant_fusion.py

Lines changed: 30 additions & 7 deletions

@@ -4,33 +4,56 @@
 import torch

 import vllm.envs as envs
-from vllm._custom_ops import scaled_fp8_quant
 from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
+from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.config import CompilationConfig, PassConfig, VllmConfig
 from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    CUTLASS_FP8_SUPPORTED, Fp8LinearOp)
+from vllm.platforms import current_platform

 from .backend import TestBackend


 class TestModel(torch.nn.Module):

-    def __init__(self, *args, **kwargs):
+    def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args,
+                 **kwargs):
         super().__init__(*args, **kwargs)
         self.silu_and_mul = SiluAndMul()
+        self.wscale = torch.rand(1, dtype=torch.float32)
         self.scale = torch.rand(1, dtype=torch.float32)

+        self.w = (torch.rand(
+            hidden_size,
+            hidden_size).to(dtype=current_platform.fp8_dtype()).t())
+
+        self.fp8_linear = Fp8LinearOp(
+            cutlass_fp8_supported=cutlass_fp8_enabled,
+            act_quant_static=True,
+            act_quant_group_shape=GroupShape.PER_TENSOR,
+        )
+
     def forward(self, x):
         y = self.silu_and_mul(x)
-        x2 = scaled_fp8_quant(y, self.scale)
+        x2 = self.fp8_linear.apply(y,
+                                   self.w,
+                                   self.wscale,
+                                   input_scale=self.wscale)
         return x2


 @pytest.mark.parametrize("num_tokens", [256])
 @pytest.mark.parametrize("hidden_size", [64])
+@pytest.mark.parametrize("cutlass_fp8_enabled",
+                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                     reason="Only test on CUDA and ROCm")
-def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
+def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
+                                   cutlass_fp8_enabled):
     torch.set_default_device("cuda")
     torch.set_default_dtype(torch.float16)

@@ -40,11 +63,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
         pass_config=PassConfig(enable_fusion=True, enable_noop=True))
     fusion_pass = ActivationQuantFusionPass(config)

-    backend = TestBackend(fusion_pass)
-    model = TestModel()
+    backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
+    model = TestModel(hidden_size, cutlass_fp8_enabled)

     # First dimension dynamic
-    x = torch.rand(num_tokens, hidden_size)
+    x = torch.rand(num_tokens, hidden_size * 2)
     torch._dynamo.mark_dynamic(x, 0)

     result = model(x)

vllm/attention/backends/abstract.py

Lines changed: 4 additions & 2 deletions

@@ -9,6 +9,8 @@

 import torch

+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape)
 from vllm.multimodal import MultiModalPlaceholderMap

 if TYPE_CHECKING:
@@ -289,7 +291,7 @@ def forward(
         raise NotImplementedError

     def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
-                                     group_shape: tuple[int, int]):
+                                     group_shape: GroupShape):
         """
         Does this attention implementation support fused output quantization.
         This is used by the AttnFusionPass to only fuse output quantization
@@ -298,7 +300,7 @@ def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
         TODO(luka) merge parameters into QuantDescriptor
         :param dtype: quantized dtype
         :param static: static or dynamic quantization
-        :param group_shape: quant group shape. (-1, -1) for per-tensor.
+        :param group_shape: quant group shape.
         :return: is fusion supported for this type of quantization
         """
         return False
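fused_output_quant_supported now receives a GroupShape instead of a raw tuple. A sketch of how a backend might report support, modeled on the ROCm change in the next file (MyAttentionImpl is a hypothetical class used only for illustration):

import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform


class MyAttentionImpl:  # hypothetical backend, for illustration only
    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                     group_shape: GroupShape):
        # Fuse only static per-tensor fp8 output quantization; note the
        # comparison against the GroupShape alias rather than (-1, -1).
        return (dtype == current_platform.fp8_dtype() and static
                and group_shape == GroupShape.PER_TENSOR)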

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 4 additions & 2 deletions

@@ -19,6 +19,8 @@
                                                    PagedAttentionMetadata)
 from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape)
 from vllm.platforms import current_platform
 from vllm.platforms.rocm import use_rocm_custom_paged_attention

@@ -598,10 +600,10 @@ def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
                          head_dim))

     def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
-                                     group_shape: tuple[int, int]):
+                                     group_shape: GroupShape):
         if self.use_triton_flash_attn:
             return dtype == current_platform.fp8_dtype(
-            ) and static and group_shape == (-1, -1)  # per-tensor
+            ) and static and group_shape == GroupShape.PER_TENSOR

         # Only supported in the Triton backend
         return False

vllm/compilation/fusion.py

Lines changed: 3 additions & 22 deletions

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Callable, ClassVar, NamedTuple, Optional
+from typing import Callable, NamedTuple, Optional

 import torch
 import torch._inductor.pattern_matcher as pm
@@ -11,6 +11,8 @@

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape)
 from vllm.platforms import current_platform

 from .fx_utils import find_getitem_maybe
@@ -33,27 +35,6 @@ def empty_fp32(*args, **kwargs):
 RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default


-# Use proxy as NamedTuple direct subclasses cannot have static members
-class _GroupShape(NamedTuple):
-    row: int
-    col: int
-
-
-class GroupShape(_GroupShape):
-    """
-    This class describes the quantization group shape.
-    It includes static members for common shapes (per-tensor, per-token).
-    """
-
-    # Aliases for common quantization group shapes
-    PER_TENSOR: ClassVar['GroupShape']
-    PER_TOKEN: ClassVar['GroupShape']
-
-
-GroupShape.PER_TENSOR = GroupShape(-1, -1)
-GroupShape.PER_TOKEN = GroupShape(1, -1)
-
-
 class QuantKey(NamedTuple):
     """
     Named tuple for identifying the type of quantization.
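The GroupShape definition removed here is not dropped: judging by the imports added throughout this commit, it now lives in vllm.model_executor.layers.quantization.utils.quant_utils with the same aliases. A small sanity sketch of the values carried over from the removed code:

from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

# The aliases keep the removed definition's values:
# per-tensor is (-1, -1), per-token is (1, -1).
assert GroupShape.PER_TENSOR == GroupShape(-1, -1)
assert GroupShape.PER_TOKEN == GroupShape(1, -1)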

vllm/model_executor/layers/fused_moe/utils.py

Lines changed: 2 additions & 0 deletions

@@ -111,6 +111,8 @@ def _fp8_quantize(
     is provided, the output will be blocked.
     """
     if block_shape is None:
+        # TODO(luka): use QuantFP8 custom op
+        # https://github.com/vllm-project/vllm/issues/20711
         A, A_scale = ops.scaled_fp8_quant(
             A, A_scale, use_per_token_if_dynamic=per_act_token)
     else:

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py

Lines changed: 16 additions & 8 deletions

@@ -15,6 +15,9 @@
                                                QKVParallelLinear)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     convert_to_channelwise, sparse_cutlass_supported)
 from vllm.model_executor.parameter import (BasevLLMParameter,
@@ -24,6 +27,8 @@

 __all__ = ["CompressedTensors24"]

+from vllm.platforms import current_platform
+

 class CompressedTensors24(CompressedTensorsScheme):

@@ -45,6 +50,12 @@ def __init__(
             and self.model_compressor.sparsity_config.format
             == CompressionFormat.sparse_24_bitmask.value)

+        if quantized and input_quant is not None and \
+            self._get_quant_dtype() == current_platform.fp8_dtype():
+            static = not input_quant.dynamic
+            g_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
+            self.quant_fp8 = QuantFP8(static, g_shape)
+
     @classmethod
     def get_min_capability(cls) -> int:
         # Only cutlass 3.x kernels are implemented so far
@@ -232,21 +243,15 @@ def apply_weights(
         :return: The output tensor of the layer
         """
         if self.quantized:
-            scale = None
-            if hasattr(layer, "input_scale"):
-                scale = layer.input_scale
+            scale = getattr(layer, 'input_scale', None)

             if self.weights_dtype == torch.int8:
                 ops_output = ops.scaled_int8_quant(x, scale=scale)
                 q_input = ops_output[0]
                 input_scale = ops_output[1]
             else:
                 assert self.weights_dtype == torch.float8_e4m3fn
-                if scale is not None:
-                    q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale)
-                else:
-                    q_input, input_scale = ops.scaled_fp8_quant(
-                        x, use_per_token_if_dynamic=True)
+                q_input, input_scale = self.quant_fp8(x, scale=scale)

         else:
             # Not quantized, nothing to do with the input_scales, use as is
@@ -269,7 +274,10 @@
     def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
         if not self.quantized:
             return params_dtype
+        return self._get_quant_dtype()

+    def _get_quant_dtype(self) -> torch.dtype:
+        assert self.quantized
         assert self.weight_quant is not None
         assert self.input_quant is not None

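With the QuantFP8 instance created in __init__, the fp8 activation-quantization branch of apply_weights collapses to a single call that covers both the static (scale provided) and dynamic (per-token scale computed) paths. A hedged sketch of those two call patterns, using only the constructor and call signatures visible in this commit (the scale value and tensor shape are illustrative; in vLLM the op is built during layer initialization with the active vLLM config set):

import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")

# Static checkpoints carry an input scale, so the op is built as static/per-tensor
# and the known scale is passed through.
static_quant = QuantFP8(True, GroupShape.PER_TENSOR)
known_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")  # illustrative value
q, s = static_quant(x, scale=known_scale)

# Dynamic checkpoints have no input scale; the per-token variant computes one.
dynamic_quant = QuantFP8(False, GroupShape.PER_TOKEN)
q, s = dynamic_quant(x)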