Commit ab2555e

dsikka authored and jimpang committed
[Quantization] Remove FP4 emulation; Fall-back to marlin for device < 100 (vllm-project#19563)
1 parent 7044280 commit ab2555e

File tree

5 files changed: +79 −60 lines changed


tests/quantization/test_compressed_tensors.py

Lines changed: 7 additions & 1 deletion
@@ -667,7 +667,13 @@ def check_model(model):
         qkv_proj = layer.self_attn.qkv_proj
         assert isinstance(qkv_proj.quant_method,
                           CompressedTensorsLinearMethod)
-        assert isinstance(qkv_proj.scheme, scheme)
+        if isinstance(qkv_proj.scheme, scheme) or isinstance(
+                qkv_proj.scheme, CompressedTensorsW4A16Fp4
+        ) and not CompressedTensorsW4A4Fp4.cutlass_fp4_supported():
+            assert True
+        else:
+            raise AssertionError("FP4 Scheme Mismatch")
+
         assert qkv_proj.scheme.group_size == 16

     llm.apply_model(check_model)
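
The updated assertion accepts either the scheme the test was parametrized with or the marlin-backed W4A16 fallback on platforms without cutlass NVFP4 support. A minimal restatement of that check (hypothetical rewrite for illustration, not part of the commit):

    expected_scheme = isinstance(qkv_proj.scheme, scheme)
    fell_back_to_marlin = (
        isinstance(qkv_proj.scheme, CompressedTensorsW4A16Fp4)
        and not CompressedTensorsW4A4Fp4.cutlass_fp4_supported())
    # `and` binds tighter than `or`, so the committed condition is
    # equivalent to: expected_scheme or fell_back_to_marlin
    assert expected_scheme or fell_back_to_marlin, "FP4 Scheme Mismatch"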

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 8 additions & 1 deletion
@@ -374,7 +374,14 @@ def _get_scheme_from_parts(

         if is_activation_quantization_format(self.quant_format):
             if self._is_fp4a4_nvfp4(weight_quant, input_quant):
-                return CompressedTensorsW4A4Fp4()
+                if CompressedTensorsW4A4Fp4.cutlass_fp4_supported():
+                    return CompressedTensorsW4A4Fp4()
+                else:
+                    logger.warning_once(
+                        "Current platform does not support cutlass NVFP4."
+                        " Running CompressedTensorsW4A16Fp4.")
+                    return CompressedTensorsW4A16Fp4(
+                        has_input_global_scale=True)

             if self._is_fp8_w8a8(weight_quant, input_quant):
                 is_fp8_w8a8_supported = self._check_scheme_supported(
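
In effect, scheme selection for NVFP4 checkpoints now branches on the device instead of falling back to emulation. A condensed sketch of that dispatch, using only names from the diff above:

    # Sketch: W4A4 NVFP4 runs on the cutlass kernel where supported;
    # otherwise the checkpoint is served by the marlin-backed W4A16 scheme,
    # which keeps the input global scale so the W4A4 checkpoint still loads.
    if CompressedTensorsW4A4Fp4.cutlass_fp4_supported():
        scheme = CompressedTensorsW4A4Fp4()
    else:
        scheme = CompressedTensorsW4A16Fp4(has_input_global_scale=True)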

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py

Lines changed: 13 additions & 1 deletion
@@ -18,7 +18,8 @@

 class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):

-    def __init__(self):
+    def __init__(self, has_input_global_scale: bool = False):
+        self.has_input_global_scale = has_input_global_scale
         self.group_size = 16

     @classmethod
@@ -64,6 +65,13 @@ def create_weights(self, layer: torch.nn.Module,

         layer.register_parameter("weight_scale", weight_scale)

+        if self.has_input_global_scale:
+            input_global_scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes),
+                                 dtype=torch.float32),
+                weight_loader=weight_loader)
+            layer.register_parameter("input_global_scale", input_global_scale)
+
     def process_weights_after_loading(self, layer) -> None:
         # Process parameters for marlin repacking

@@ -77,6 +85,10 @@ def process_weights_after_loading(self, layer) -> None:
                                                 requires_grad=False)
         del layer.weight_global_scale

+        if self.has_input_global_scale:
+            layer.input_global_scale = torch.nn.Parameter(
+                layer.input_global_scale.data, requires_grad=False)
+
         prepare_fp4_layer_for_marlin(layer)

     def apply_weights(self,
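
The new constructor flag lets the W4A16 scheme absorb checkpoints produced for W4A4: when set, an extra per-tensor input_global_scale parameter is registered at weight-creation time and frozen into a plain torch.nn.Parameter before marlin repacking. A small illustrative check of the default (not part of the commit):

    # Genuine W4A16 checkpoints are unaffected by default...
    w4a16 = CompressedTensorsW4A16Fp4()
    assert w4a16.has_input_global_scale is False

    # ...and only the W4A4 fallback path opts in explicitly.
    fallback = CompressedTensorsW4A16Fp4(has_input_global_scale=True)
    assert fallback.has_input_global_scale is True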

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py

Lines changed: 22 additions & 57 deletions
@@ -9,8 +9,6 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
-    dequantize_to_dtype, ref_nvfp4_quant)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -21,53 +19,23 @@
 __all__ = ["CompressedTensorsW4A4Fp4"]


-def cutlass_fp4_supported() -> bool:
-    if not current_platform.is_cuda():
-        return False
-    capability_tuple = current_platform.get_device_capability()
-    capability = -1 if capability_tuple is None else capability_tuple.to_int()
-    return cutlass_scaled_mm_supports_fp4(capability)
-
-
 class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):

     def __init__(self):
         self.group_size = 16
-        self.cutlass_nvfp4_supported = cutlass_fp4_supported()
-        if not self.cutlass_nvfp4_supported:
-            logger.warning("Current platform does not support cutlass NVFP4."
-                           " Running emulations.")

     @classmethod
     def get_min_capability(cls) -> int:
-        # dont restrict as emulations
-        return 80
+        return 100

-    def run_nvfp4_emulations(self, x: torch.Tensor, layer):
-        x_m, x_k = x.shape
-        output_dtype = x.dtype
-
-        # quantize input to (FP4 and interleaved block scale)
-        x_fp4, x_blockscale = ref_nvfp4_quant(x, layer.input_global_scale,
-                                              self.group_size)
-
-        # dequantize input
-        x_fp4 = x_fp4.reshape(x_m, x_k // self.group_size, self.group_size)
-        x_blockscale = x_blockscale.unsqueeze(-1) / layer.input_global_scale
-        x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
-        del x_fp4, x_blockscale
-
-        # dequantize weight
-        w_fp4 = layer.weight.data.view(torch.uint8)
-        w_blockscale = layer.weight_scale_swizzled.data
-        w_global_scale = layer.weight_global_scale
-        w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
-                                   output_dtype, x.device, self.group_size)
-
-        # matmul
-        out = torch.matmul(x_dq, w_dq.t())
-        del w_dq, x_dq
-        return out
+    @classmethod
+    def cutlass_fp4_supported(cls) -> bool:
+        if not current_platform.is_cuda():
+            return False
+        capability_tuple = current_platform.get_device_capability()
+        capability = -1 if capability_tuple is None else capability_tuple.to_int()  # noqa: E501
+        return cutlass_scaled_mm_supports_fp4(capability)

     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: list[int],
@@ -152,27 +120,24 @@ def process_weights_after_loading(self, layer) -> None:
         # required by cutlass kernel; need Parameter, not ModelWeightParameter
         layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)

-        if self.cutlass_nvfp4_supported:
-            layer.alpha = Parameter(layer.input_global_scale *
-                                    layer.weight_global_scale,
-                                    requires_grad=False)
+        layer.alpha = Parameter(layer.input_global_scale *
+                                layer.weight_global_scale,
+                                requires_grad=False)

     def apply_weights(self,
                       layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:

-        if self.cutlass_nvfp4_supported:
-            output_dtype = x.dtype
-            output_shape = [x.shape[0], layer.weight.shape[0]]
+        output_dtype = x.dtype
+        output_shape = [x.shape[0], layer.weight.shape[0]]

-            # quantize BF16 or FP16 to (FP4 and interleaved block scale)
-            x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
+        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
+        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)

-            out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
-                                        layer.weight_scale_swizzled,
-                                        1 / layer.alpha, output_dtype)
-            if bias is not None:
-                out = out + bias
-            return out.view(*output_shape)
-        return self.run_nvfp4_emulations(x, layer)
+        out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
+                                    layer.weight_scale_swizzled,
+                                    1 / layer.alpha, output_dtype)
+        if bias is not None:
+            out = out + bias
+        return out.view(*output_shape)
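
With get_min_capability now returning 100, the cutlass W4A4 path is only selected on devices of compute capability 10.0 (Blackwell) or newer, and the capability probe is a classmethod so both the config layer and the tests can query it. A rough sketch of what it reports, assuming (as in vLLM) that DeviceCapability.to_int() maps (major, minor) to major * 10 + minor, so SM 10.0 reports 100:

    # Illustrative check only: an H100 (capability 9.0 -> 90) returns False
    # here and therefore takes the marlin W4A16 fallback, while a B200
    # (capability 10.0 -> 100) returns True and keeps the cutlass path.
    if CompressedTensorsW4A4Fp4.cutlass_fp4_supported():
        print("cutlass NVFP4 kernel available")
    else:
        print("falling back to CompressedTensorsW4A16Fp4 (marlin)")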

vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py

Lines changed: 29 additions & 0 deletions
@@ -102,3 +102,32 @@ def ref_nvfp4_quant(x, global_scale, block_size):
     clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
     # both outputs are float32
     return cast_to_fp4(clipped_x), scale.squeeze(-1)
+
+
+def run_nvfp4_emulations(x: torch.Tensor, input_global_scale: torch.Tensor,
+                         weight: torch.Tensor,
+                         weight_scale_swizzled: torch.Tensor,
+                         weight_global_scale: torch.Tensor):
+    group_size = 16
+    x_m, x_k = x.shape
+    output_dtype = x.dtype
+
+    # quantize input to (FP4 and interleaved block scale)
+    x_fp4, x_blockscale = ref_nvfp4_quant(x, input_global_scale, group_size)
+
+    # dequantize input
+    x_fp4 = x_fp4.reshape(x_m, x_k // group_size, group_size)
+    x_blockscale = x_blockscale.unsqueeze(-1) / input_global_scale
+    x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
+    del x_fp4, x_blockscale
+
+    # dequantize weight
+    w_fp4 = weight.data.view(torch.uint8)
+    w_dq = dequantize_to_dtype(w_fp4, weight_scale_swizzled.data,
+                               weight_global_scale, output_dtype, x.device,
+                               group_size)
+
+    # matmul
+    out = torch.matmul(x_dq, w_dq.t())
+    del w_dq, x_dq
+    return out
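
The emulation logic itself is unchanged in substance; it moves from the scheme into this utility module as a free function. For readers unfamiliar with the two-level NVFP4 scaling it emulates, here is a self-contained toy illustration (assumed shapes and fp32 stand-ins for the block scales, no real FP4 casting) that mirrors the dequantization formula used above:

    import torch

    group_size = 16
    x = torch.randn(4, 64, dtype=torch.float32)      # [M, K] activations
    x_groups = x.reshape(4, -1, group_size)          # [M, K/16, 16]

    # per-tensor global scale plus per-group block scales (fp32 stand-ins
    # for the FP8 block scales the real kernels use)
    global_scale = 6.0 / x.abs().max()
    block_scale = (x_groups.abs().amax(dim=-1, keepdim=True) *
                   global_scale / 6.0).clamp_min(1e-12)

    # "quantize": values land in the FP4 E2M1 range [-6, 6]
    x_q = torch.clamp(x_groups * global_scale / block_scale, -6.0, 6.0)

    # dequantize the same way run_nvfp4_emulations does:
    # x_dq = x_q * (block_scale / global_scale)
    x_dq = (x_q * block_scale / global_scale).reshape(4, 64)
    assert torch.allclose(x, x_dq, atol=1e-4)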
