From 8072051a59d93339b77a441952487c81857f9044 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 1 May 2025 03:02:09 +0000 Subject: [PATCH 01/12] update Signed-off-by: Dipika Sikka --- .../layers/quantization/modelopt.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 8f1c8487746..664bcbbca11 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -73,11 +73,11 @@ def dequantize_to_dtype(tensor_fp4, assert tensor_fp4.dtype == torch.uint8 m, packed_k = tensor_fp4.shape k = packed_k * 2 - tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) + tensor_f32 = break_fp4_bytes(tensor_fp4, torch.float32) tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) tensor_sf = tensor_sf.view(torch.float8_e4m3fn) tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) - tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale + tensor_sf_dtype = tensor_sf.to(torch.float32) * global_scale # scale the tensor out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) @@ -488,7 +488,7 @@ def apply( # print(f"{layer.weight.shape=}") output_shape = [x_m, w_n] block_size = 16 - + """ # quantize input to (FP4 and interleaved block scale) # x_global_scale = layer.input_scale x_global_scale = 1 / layer.input_scale @@ -503,6 +503,7 @@ def apply( x_blockscale = x_blockscale.unsqueeze(-1) / x_global_scale x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype) del x_fp4, x_blockscale + """ # dequantize weight w_fp4 = layer.weight.data.view(torch.uint8) @@ -512,13 +513,11 @@ def apply( # print(f"{w_blockscale.shape=}") # print(f"{w_global_scale.shape=}") w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale, - output_dtype, x.device, - block_size).to(output_dtype) - # print(f"{w_dq.shape=}") + output_dtype, x.device, block_size) # matmul - out = torch.matmul(x_dq, w_dq.t()) - del x_dq, w_dq + out = torch.matmul(x, w_dq.t()) + # del x_dq, w_dq # print(f"{out.shape=}") if bias is not None: From 3fc300ef4fcb4ce7a855781b592e161d101f2760 Mon Sep 17 00:00:00 2001 From: Dipika Date: Thu, 1 May 2025 13:56:32 -0400 Subject: [PATCH 02/12] move emulation utils into a shared files; add ct nvfp4 scheme --- .../schemes/compressed_tensors_w4a4_nvfp4.py | 88 +++++++++++++++ .../layers/quantization/modelopt.py | 103 +---------------- .../utils/nvfp4_emulation_utils.py | 104 ++++++++++++++++++ 3 files changed, 196 insertions(+), 99 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py create mode 100644 vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py new file mode 100644 index 00000000000..e4dde24bb07 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, List, Optional + +import torch + +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from torch.nn.parameter import Parameter +from vllm.model_executor.parameter import (ModelWeightParameter, + 
PerTensorScaleParamete, GroupQuantScaleParameter) +from vllm.platforms import current_platform +from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( + dequantize_to_dtype +) + +__all__ = ["CompressedTensorsW4A4Fp4"] + + +class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): + def __init__(self): + self.group_size = 16 + + @classmethod + def get_min_capability(cls) -> int: + # dont restrict as emulations + return 80 + + def process_weights_after_loading(self, layer) -> None: + weight_global_scale = layer.weight_global_scale.max().to(torch.float32) + layer.weight_global_scale = Parameter(weight_global_scale, requires_grad=False) + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + # Weight + weight = ModelWeightParameter( + data=torch.empty( + # 2 fp4 items are packed in the input dimension + layer.output_size_per_partition, + layer.input_size_per_partition // 2, + dtype=torch.uint8), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight_packed", weight) + + # Input Scale + input_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("input_scale", input_scale) + + # Global Weight Scale + weight_global_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("weight_global_scale", weight_global_scale) + + # Per Group Weight Scale + weight_scale = GroupQuantScaleParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition // group_size, + dtype=weight_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + layer.register_parameter("weight_scale", weight_scale) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + + w_fp4 = layer.weight_packed.data + w_blockscale = layer.weight_scale + w_global_scale = layer.weight_global_scale + w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale, + x.dtype, x.device, self.group_size) + + return F.linear(x, w_dq) + diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 664bcbbca11..83ae1a31908 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -20,109 +20,14 @@ from vllm.model_executor.parameter import (ModelWeightParameter, PerTensorScaleParameter) from vllm.platforms import current_platform -from vllm.scalar_type import scalar_types +from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( + ref_nvfp4_quant, dequantize_to_dtype) logger = init_logger(__name__) QUANT_ALGOS = ["FP8", "NVFP4"] KV_CACHE_QUANT_ALGOS = ["FP8"] -FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max() - -kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.], - dtype=torch.float32) - - -def break_fp4_bytes(a, dtype): - assert a.dtype == torch.uint8 - m, n = a.shape - # Vectorized nibble processing - a_flat = a.flatten() - high = (a_flat & 0xF0) >> 4 # Upper nibbles - low = a_flat & 0x0F # Lower nibbles - # Combine nibbles for batch processing - combined = torch.stack((low, high), dim=1).flatten() - # Vectorized sign and magnitude extraction - signs = (combined & 
0x08).to(torch.bool) # Sign bits - abs_vals = (combined & 0x07).to(torch.long) - # Device-aware lookup and sign application - kE2M1 = kE2M1ToFloat.to(device=a.device) - values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) - # Reshape to final form - return values.reshape(m, n * 2).to(dtype=dtype) - - -def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): - m_tiles = (m + 128 - 1) // 128 - f = block_size * 4 - k_tiles = (k + f - 1) // f - tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) - tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) - out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) - return out[0:m, 0:k] - - -def dequantize_to_dtype(tensor_fp4, - tensor_sf, - global_scale, - dtype, - device, - block_size=16): - """Dequantize the fp4 tensor back to high precision.""" - # Two fp4 values are packed into one uint8. - assert tensor_fp4.dtype == torch.uint8 - m, packed_k = tensor_fp4.shape - k = packed_k * 2 - tensor_f32 = break_fp4_bytes(tensor_fp4, torch.float32) - tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) - tensor_sf = tensor_sf.view(torch.float8_e4m3fn) - tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) - tensor_sf_dtype = tensor_sf.to(torch.float32) * global_scale - - # scale the tensor - out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) - return out.to(dtype) - - -def cast_to_fp4(x): - sign = torch.sign(x) - x = torch.abs(x) - x[(x >= 0.0) & (x <= 0.25)] = 0.0 - x[(x > 0.25) & (x < 0.75)] = 0.5 - x[(x >= 0.75) & (x <= 1.25)] = 1.0 - x[(x > 1.25) & (x < 1.75)] = 1.5 - x[(x >= 1.75) & (x <= 2.5)] = 2.0 - x[(x > 2.5) & (x < 3.5)] = 3.0 - x[(x >= 3.5) & (x <= 5.0)] = 4.0 - x[x > 5.0] = 6.0 - return x * sign - - -def get_reciprocal(x): - if isinstance(x, torch.Tensor): - return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x) - elif isinstance(x, (float, int)): - return 0.0 if x == 0 else 1.0 / x - else: - raise TypeError("Input must be a float, int, or a torch.Tensor.") - - -def ref_nvfp4_quant(x, global_scale, block_size): - assert global_scale.dtype == torch.float32 - assert x.ndim == 2 - m, n = x.shape - x = torch.reshape(x, (m, n // block_size, block_size)) - vec_max = torch.max(torch.abs(x), dim=-1, - keepdim=True)[0].to(torch.float32) - scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX)) - scale = scale.to(torch.float8_e4m3fn).to(torch.float32) - output_scale = get_reciprocal(scale * get_reciprocal(global_scale)) - - scaled_x = x.to(torch.float32) * output_scale - clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n) - # both outputs are float32 - return cast_to_fp4(clipped_x), scale.squeeze(-1) - class ModelOptFp8Config(QuantizationConfig): """Config class for ModelOpt FP8.""" @@ -289,7 +194,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: @classmethod def get_min_capability(cls) -> int: - return 89 + return 80 @classmethod def get_config_filenames(cls) -> List[str]: @@ -508,7 +413,7 @@ def apply( # dequantize weight w_fp4 = layer.weight.data.view(torch.uint8) w_blockscale = layer.weight_scale_swizzled.data - w_global_scale = layer.weight_scale_2 + w_global_scale = 1 / layer.weight_scale_2 # print(f"{w_fp4.shape=}") # print(f"{w_blockscale.shape=}") # print(f"{w_global_scale.shape=}") diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py new file mode 100644 index 00000000000..c7600b2b745 --- /dev/null +++ 
b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py @@ -0,0 +1,104 @@ +import torch +from vllm.scalar_type import scalar_types + +__all__ = [ + "break_fp4_bytes", + "dequantize_to_dtype", + "ref_nvfp4_quant" +] + +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max() + +kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.], + dtype=torch.float32) + + +def break_fp4_bytes(a, dtype): + assert a.dtype == torch.uint8 + m, n = a.shape + # Vectorized nibble processing + a_flat = a.flatten() + high = (a_flat & 0xF0) >> 4 # Upper nibbles + low = a_flat & 0x0F # Lower nibbles + # Combine nibbles for batch processing + combined = torch.stack((low, high), dim=1).flatten() + # Vectorized sign and magnitude extraction + signs = (combined & 0x08).to(torch.bool) # Sign bits + abs_vals = (combined & 0x07).to(torch.long) + # Device-aware lookup and sign application + kE2M1 = kE2M1ToFloat.to(device=a.device) + values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) + # Reshape to final form + return values.reshape(m, n * 2).to(dtype=dtype) + + +def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): + m_tiles = (m + 128 - 1) // 128 + f = block_size * 4 + k_tiles = (k + f - 1) // f + tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) + tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) + out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) + return out[0:m, 0:k] + + +def dequantize_to_dtype(tensor_fp4, + tensor_sf, + global_scale, + dtype, + device, + block_size=16): + """Dequantize the fp4 tensor back to high precision.""" + # Two fp4 values are packed into one uint8. + assert tensor_fp4.dtype == torch.uint8 + m, packed_k = tensor_fp4.shape + k = packed_k * 2 + tensor_f32 = break_fp4_bytes(tensor_fp4, torch.float32) + tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) + tensor_sf = tensor_sf.view(torch.float8_e4m3fn) + tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) + tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale + + # scale the tensor + out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) + return out.to(dtype) + + +def cast_to_fp4(x): + sign = torch.sign(x) + x = torch.abs(x) + x[(x >= 0.0) & (x <= 0.25)] = 0.0 + x[(x > 0.25) & (x < 0.75)] = 0.5 + x[(x >= 0.75) & (x <= 1.25)] = 1.0 + x[(x > 1.25) & (x < 1.75)] = 1.5 + x[(x >= 1.75) & (x <= 2.5)] = 2.0 + x[(x > 2.5) & (x < 3.5)] = 3.0 + x[(x >= 3.5) & (x <= 5.0)] = 4.0 + x[x > 5.0] = 6.0 + return x * sign + + +def get_reciprocal(x): + if isinstance(x, torch.Tensor): + return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x) + elif isinstance(x, (float, int)): + return 0.0 if x == 0 else 1.0 / x + else: + raise TypeError("Input must be a float, int, or a torch.Tensor.") + + +def ref_nvfp4_quant(x, global_scale, block_size): + assert global_scale.dtype == torch.float32 + assert x.ndim == 2 + m, n = x.shape + x = torch.reshape(x, (m, n // block_size, block_size)) + vec_max = torch.max(torch.abs(x), dim=-1, + keepdim=True)[0].to(torch.float32) + scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX)) + scale = scale.to(torch.float8_e4m3fn).to(torch.float32) + output_scale = get_reciprocal(scale * get_reciprocal(global_scale)) + + scaled_x = x.to(torch.float32) * output_scale + clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n) + # both outputs are float32 + return cast_to_fp4(clipped_x), scale.squeeze(-1) \ No newline at end of file From c99589aa1dccce97d8594f957ae70a61d44e4bf5 Mon Sep 17 00:00:00 2001 
From: Dipika Date: Thu, 1 May 2025 15:39:49 -0400 Subject: [PATCH 03/12] update --- .../compressed_tensors/compressed_tensors.py | 11 ++++++++- .../compressed_tensors/schemes/__init__.py | 2 +- .../schemes/compressed_tensors_w4a4_nvfp4.py | 23 ++++++++----------- .../utils/nvfp4_emulation_utils.py | 4 ++-- 4 files changed, 22 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 5be6b22c7b2..008dd885591 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -24,7 +24,7 @@ W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24, CompressedTensorsScheme, CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, - CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) + CompressedTensorsW8A16Fp8, CompressedTensorsWNA16, CompressedTensorsW4A4Fp4) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( find_matched_target, is_activation_quantization_format, should_ignore_layer) @@ -299,6 +299,13 @@ def _is_fp8_w8a16(self, weight_quant: BaseModel, # All conditions satisfied. return True + def _is_fp4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel): + is_group_quant = weight_quant.strategy == QuantizationStrategy.GROUP.value + is_group_size_16 = weight_quant.group_size == 16 + is_float_type = weight_quant.type == QuantizationType.FLOAT + is_4_bits = weight_quant.num_bits == 8 + return is_group_quant and is_float_type and is_4_bits + def _is_wNa16_group_channel(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: input_quant_none = input_quant is None @@ -313,6 +320,8 @@ def _get_scheme_from_parts( self, weight_quant: BaseModel, input_quant: BaseModel) -> "CompressedTensorsScheme": + if self._is_fp4_nvfp4(weight_quant, input_quant): + return CompressedTensorsW4A4Fp4() # Detect If Mixed Precision if self._is_wNa16_group_channel(weight_quant, input_quant): if (self.quant_format == CompressionFormat.marlin_24.value diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index b26c74f2484..d8d3c44c110 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -16,5 +16,5 @@ "CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24", "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8", "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS", - "CompressedTensors24" + "CompressedTensors24", "CompressedTensorsW4A4Fp4" ] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index e4dde24bb07..30aa97a77fe 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -8,11 +8,12 @@ CompressedTensorsScheme) from torch.nn.parameter import Parameter from vllm.model_executor.parameter import (ModelWeightParameter, - PerTensorScaleParamete, GroupQuantScaleParameter) + 
PerTensorScaleParameter, GroupQuantScaleParameter) from vllm.platforms import current_platform from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( dequantize_to_dtype ) +import torch.nn.functional as F __all__ = ["CompressedTensorsW4A4Fp4"] @@ -25,11 +26,7 @@ def __init__(self): def get_min_capability(cls) -> int: # dont restrict as emulations return 80 - - def process_weights_after_loading(self, layer) -> None: - weight_global_scale = layer.weight_global_scale.max().to(torch.float32) - layer.weight_global_scale = Parameter(weight_global_scale, requires_grad=False) - + def create_weights(self, layer: torch.nn.Module, output_partition_sizes: List[int], input_size_per_partition: int, @@ -48,12 +45,6 @@ def create_weights(self, layer: torch.nn.Module, weight_loader=weight_loader) layer.register_parameter("weight_packed", weight) - # Input Scale - input_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) - layer.register_parameter("input_scale", input_scale) - # Global Weight Scale weight_global_scale = PerTensorScaleParameter(data=torch.empty( len(output_partition_sizes), dtype=torch.float32), @@ -63,8 +54,8 @@ def create_weights(self, layer: torch.nn.Module, # Per Group Weight Scale weight_scale = GroupQuantScaleParameter(data=torch.empty( output_size_per_partition, - input_size_per_partition // group_size, - dtype=weight_dtype, + input_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn, ), input_dim=1, output_dim=0, @@ -72,6 +63,10 @@ def create_weights(self, layer: torch.nn.Module, layer.register_parameter("weight_scale", weight_scale) + def process_weights_after_loading(self, layer) -> None: + weight_global_scale = layer.weight_global_scale.max().to(torch.float32) + layer.weight_global_scale = Parameter(weight_global_scale, requires_grad=False) + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py index c7600b2b745..0493f6f2192 100644 --- a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py @@ -55,8 +55,8 @@ def dequantize_to_dtype(tensor_fp4, k = packed_k * 2 tensor_f32 = break_fp4_bytes(tensor_fp4, torch.float32) tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) - tensor_sf = tensor_sf.view(torch.float8_e4m3fn) - tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) + #tensor_sf = tensor_sf.view(torch.float8_e4m3fn) + #tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale # scale the tensor From 4104a49a6b7d11685324a7b66a712ebe8fe326b4 Mon Sep 17 00:00:00 2001 From: Dipika Date: Thu, 1 May 2025 16:30:43 -0400 Subject: [PATCH 04/12] fix condition, clean-up code --- .../quantization/compressed_tensors/compressed_tensors.py | 2 +- .../quantization/compressed_tensors/schemes/__init__.py | 1 + .../schemes/compressed_tensors_w4a4_nvfp4.py | 6 +++--- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 008dd885591..f6dc0d0d7d8 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ 
b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -303,7 +303,7 @@ def _is_fp4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel): is_group_quant = weight_quant.strategy == QuantizationStrategy.GROUP.value is_group_size_16 = weight_quant.group_size == 16 is_float_type = weight_quant.type == QuantizationType.FLOAT - is_4_bits = weight_quant.num_bits == 8 + is_4_bits = weight_quant.num_bits == 4 return is_group_quant and is_float_type and is_4_bits def _is_wNa16_group_channel(self, weight_quant: BaseModel, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index d8d3c44c110..67368cf0662 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -8,6 +8,7 @@ from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8 from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS, CompressedTensorsWNA16) +from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4 from .compressed_tensors_24 import CompressedTensors24 # isort: skip diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 30aa97a77fe..849db8b3ba0 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -37,8 +37,8 @@ def create_weights(self, layer: torch.nn.Module, weight = ModelWeightParameter( data=torch.empty( # 2 fp4 items are packed in the input dimension - layer.output_size_per_partition, - layer.input_size_per_partition // 2, + sum(output_partition_sizes), + input_size_per_partition // 2, dtype=torch.uint8), input_dim=1, output_dim=0, @@ -53,7 +53,7 @@ def create_weights(self, layer: torch.nn.Module, # Per Group Weight Scale weight_scale = GroupQuantScaleParameter(data=torch.empty( - output_size_per_partition, + sum(output_partition_sizes), input_size_per_partition // self.group_size, dtype=torch.float8_e4m3fn, ), From 6e45f9980fb757720d2f55ec356c64f9db54411b Mon Sep 17 00:00:00 2001 From: Dipika Date: Thu, 1 May 2025 16:30:54 -0400 Subject: [PATCH 05/12] update --- run_fp4.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 run_fp4.py diff --git a/run_fp4.py b/run_fp4.py new file mode 100644 index 00000000000..0e3639904d5 --- /dev/null +++ b/run_fp4.py @@ -0,0 +1,18 @@ +import numpy +import torch + +from vllm import LLM, SamplingParams + +prompts = ["The Swiss Alps are", "The president of the USA is", "The Boston Bruins are"] + +# Create a sampling params object for greedy sampling +sampling_params = SamplingParams(temperature=0.80, top_p=0.95, max_tokens=40, min_tokens=10) +llm = LLM('nm-testing/llama2.c-stories110M-FP4', enforce_eager=True) + + +# Print the outputs. 
+output = llm.generate(prompts, sampling_params) +for o in output: + print(o.outputs[0].text) + print("\n") + From 1940c9454a6aad8cf37772afe3d31adf50fdfdac Mon Sep 17 00:00:00 2001 From: Dipika Date: Thu, 1 May 2025 17:09:25 -0400 Subject: [PATCH 06/12] swizzle --- .../schemes/compressed_tensors_w4a4_nvfp4.py | 35 +++++++++++++++++-- .../utils/nvfp4_emulation_utils.py | 4 +-- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 849db8b3ba0..75d6b363cac 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -63,10 +63,39 @@ def create_weights(self, layer: torch.nn.Module, layer.register_parameter("weight_scale", weight_scale) + def swizzle_blockscale(self, scale: torch.tensor): + assert (scale.dtype == torch.float8_e4m3fn) + # Pad and blockwise interleave weight_scale + scale_ndim = scale.ndim + if scale.ndim == 2: + scale = scale.unsqueeze(0) + assert scale.ndim == 3 + B, M, K = scale.shape + round_up_multiple = lambda x, m: (x + m - 1) // m * m + M_padded = round_up_multiple(M, 128) + K_padded = round_up_multiple(K, 4) + padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) + padded_scale[:B, :M, :K] = scale + batches, rows, cols = padded_scale.shape + assert rows % 128 == 0 + assert cols % 4 == 0 + padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, + cols // 4, 4) + swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) + swizzled_scale = swizzled_scale.contiguous().cuda() + return (swizzled_scale.reshape(M, K) + if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) + def process_weights_after_loading(self, layer) -> None: weight_global_scale = layer.weight_global_scale.max().to(torch.float32) layer.weight_global_scale = Parameter(weight_global_scale, requires_grad=False) + # Note: a post weight loading step but not required for the emulation + swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale) + + layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, + requires_grad=False) + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, @@ -74,10 +103,12 @@ def apply_weights(self, w_fp4 = layer.weight_packed.data - w_blockscale = layer.weight_scale w_global_scale = layer.weight_global_scale + w_blockscale = layer.weight_scale_swizzled.data w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale, x.dtype, x.device, self.group_size) - return F.linear(x, w_dq) + out = F.linear(x, w_dq) + del w_dq + return out diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py index 0493f6f2192..c7600b2b745 100644 --- a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py @@ -55,8 +55,8 @@ def dequantize_to_dtype(tensor_fp4, k = packed_k * 2 tensor_f32 = break_fp4_bytes(tensor_fp4, torch.float32) tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) - #tensor_sf = tensor_sf.view(torch.float8_e4m3fn) - #tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) + tensor_sf = tensor_sf.view(torch.float8_e4m3fn) + tensor_sf = 
convert_swizzled_to_linear(tensor_sf, m, k, block_size) tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale # scale the tensor From 6412e5bad08d8ce8d339f3d08fc10dae06ff1e8b Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 8 May 2025 19:02:29 +0000 Subject: [PATCH 07/12] add code to requantize with the max or expand the global scale to have the same shape as the local scale --- run_fp4.py | 8 ++- .../schemes/compressed_tensors_w4a4_nvfp4.py | 36 +++++++++--- .../utils/nvfp4_emulation_utils.py | 56 ++++++++++++++++++- 3 files changed, 88 insertions(+), 12 deletions(-) diff --git a/run_fp4.py b/run_fp4.py index 0e3639904d5..bccf767948d 100644 --- a/run_fp4.py +++ b/run_fp4.py @@ -6,10 +6,12 @@ prompts = ["The Swiss Alps are", "The president of the USA is", "The Boston Bruins are"] # Create a sampling params object for greedy sampling -sampling_params = SamplingParams(temperature=0.80, top_p=0.95, max_tokens=40, min_tokens=10) -llm = LLM('nm-testing/llama2.c-stories110M-FP4', enforce_eager=True) - +sampling_params = SamplingParams(max_tokens=40, min_tokens=10) +#llm = LLM('nm-testing/TinyLlama-1.1B-Chat-v1.0-FP4', enforce_eager=True) +llm = LLM("/home/dsikka/llm-compressor/examples/quantization_w4a16_fp4/Llama-3.1-8B-Instruct-FP4", enforce_eager=True) +#llm = LLM("nvidia/Llama-3.3-70B-Instruct-FP4", quantization='nvfp4', max_model_len=2048, enforce_eager=True) +#llm = LLM("nm-testing/llama2.c-stories110M-FP4", enforce_eager=True) # Print the outputs. output = llm.generate(prompts, sampling_params) for o in output: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 75d6b363cac..39c89720957 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -11,7 +11,7 @@ PerTensorScaleParameter, GroupQuantScaleParameter) from vllm.platforms import current_platform from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( - dequantize_to_dtype + dequantize_to_dtype, dequantize_unfused, requantize_with_max, expand_global_scale ) import torch.nn.functional as F @@ -34,6 +34,8 @@ def create_weights(self, layer: torch.nn.Module, **kwargs): # Weight + self.output_partition_sizes = output_partition_sizes + self.params_dtype = params_dtype weight = ModelWeightParameter( data=torch.empty( # 2 fp4 items are packed in the input dimension @@ -87,12 +89,34 @@ def swizzle_blockscale(self, scale: torch.tensor): if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) def process_weights_after_loading(self, layer) -> None: - weight_global_scale = layer.weight_global_scale.max().to(torch.float32) - layer.weight_global_scale = Parameter(weight_global_scale, requires_grad=False) + """ + end = 0 + all_dequant = [] + + + # dequant using global scale + for i in range(len(layer.weight_global_scale)): + global_scale = layer.weight_global_scale[i] + size = self.output_partition_sizes[i] + start = end + end = start + size + + dequant_weight = dequantize_unfused(layer.weight_packed[start:end, :], layer.weight_scale[start:end, :], global_scale, + self.params_dtype, layer.weight_packed.device, self.group_size) + + all_dequant.append(dequant_weight) + + del layer.weight_packed + layer.weight_global_scale = Parameter(layer.weight_global_scale.max().to(torch.float32), 
requires_grad=False) + # requant with max + layer.weight = requantize_with_max(torch.cat(all_dequant, axis=0), layer.weight_global_scale, layer.weight_scale) + """ + + # expand global_scale to have the same shape as the weight scales + layer.weight_global_scale = Parameter(expand_global_scale(layer.weight_scale, self.output_partition_sizes, layer.weight_global_scale), requires_grad=False) # Note: a post weight loading step but not required for the emulation swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale) - layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, requires_grad=False) @@ -101,13 +125,11 @@ def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - w_fp4 = layer.weight_packed.data w_global_scale = layer.weight_global_scale w_blockscale = layer.weight_scale_swizzled.data w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale, - x.dtype, x.device, self.group_size) - + x.dtype, x.device, self.group_size) out = F.linear(x, w_dq) del w_dq return out diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py index c7600b2b745..9dafa9752cf 100644 --- a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py @@ -4,7 +4,10 @@ __all__ = [ "break_fp4_bytes", "dequantize_to_dtype", - "ref_nvfp4_quant" + "ref_nvfp4_quant", + "dequantize_unfused", + "requantize_with_max", + "expand_global_scale" ] FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max() @@ -53,8 +56,10 @@ def dequantize_to_dtype(tensor_fp4, assert tensor_fp4.dtype == torch.uint8 m, packed_k = tensor_fp4.shape k = packed_k * 2 + #m, k = tensor_fp4.shape tensor_f32 = break_fp4_bytes(tensor_fp4, torch.float32) tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) + #tensor_f32 = tensor_fp4.reshape(m, k // block_size, block_size) tensor_sf = tensor_sf.view(torch.float8_e4m3fn) tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale @@ -64,6 +69,26 @@ def dequantize_to_dtype(tensor_fp4, return out.to(dtype) +def dequantize_unfused(tensor_fp4, + tensor_sf, + global_scale, + dtype, + device, + block_size=16): + + assert tensor_fp4.dtype == torch.uint8 + m, packed_k = tensor_fp4.shape + k = packed_k * 2 + tensor_f32 = break_fp4_bytes(tensor_fp4, torch.float32) + tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) + tensor_sf = tensor_sf.view(torch.float8_e4m3fn) + tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale + + # scale the tensor + out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) + return out.to(dtype) + + def cast_to_fp4(x): sign = torch.sign(x) x = torch.abs(x) @@ -87,6 +112,18 @@ def get_reciprocal(x): raise TypeError("Input must be a float, int, or a torch.Tensor.") +def requantize_with_max(x, global_scale, scale): + assert global_scale.dtype == torch.float32 + assert x.ndim == 2 + m, n = x.shape + block_size = 16 + x = torch.reshape(x, (m, n // block_size, block_size)) + scale = scale.to(global_scale.dtype) / global_scale + scaled_x = x.to(torch.float32) / scale.unsqueeze(-1) + clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n) + return cast_to_fp4(clipped_x) + + def ref_nvfp4_quant(x, global_scale, block_size): assert global_scale.dtype == torch.float32 assert x.ndim == 2 @@ -101,4 +138,19 @@ def ref_nvfp4_quant(x, global_scale, 
block_size): scaled_x = x.to(torch.float32) * output_scale clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n) # both outputs are float32 - return cast_to_fp4(clipped_x), scale.squeeze(-1) \ No newline at end of file + return cast_to_fp4(clipped_x), scale.squeeze(-1) + + +def expand_global_scale(local_scale, sizes, global_scale): + output_scale = torch.zeros(local_scale.shape, dtype=global_scale.dtype).to(global_scale.device) + end = 0 + for i in range(len(sizes)): + size = sizes[i] + current_global_scale = global_scale[i] + start = end + end = start + size + output_scale[start:end, :] = current_global_scale + + print(output_scale) + print(global_scale) + return output_scale \ No newline at end of file From 4f11acbc19691eb5dbaf0a516aa666fe1f125b7a Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Fri, 9 May 2025 16:31:48 +0000 Subject: [PATCH 08/12] update --- run_fp4.py | 20 +++--- .../schemes/compressed_tensors_w4a4_nvfp4.py | 63 ++++++------------- .../utils/nvfp4_emulation_utils.py | 32 ++++------ 3 files changed, 44 insertions(+), 71 deletions(-) diff --git a/run_fp4.py b/run_fp4.py index bccf767948d..25df88748aa 100644 --- a/run_fp4.py +++ b/run_fp4.py @@ -1,20 +1,24 @@ -import numpy -import torch +# SPDX-License-Identifier: Apache-2.0 from vllm import LLM, SamplingParams -prompts = ["The Swiss Alps are", "The president of the USA is", "The Boston Bruins are"] +prompts = [ + "The Swiss Alps are", + "The president of the USA is", + "The Boston Bruins are" +] # Create a sampling params object for greedy sampling -sampling_params = SamplingParams(max_tokens=40, min_tokens=10) -#llm = LLM('nm-testing/TinyLlama-1.1B-Chat-v1.0-FP4', enforce_eager=True) -llm = LLM("/home/dsikka/llm-compressor/examples/quantization_w4a16_fp4/Llama-3.1-8B-Instruct-FP4", enforce_eager=True) -#llm = LLM("nvidia/Llama-3.3-70B-Instruct-FP4", quantization='nvfp4', max_model_len=2048, enforce_eager=True) +sampling_params = SamplingParams(temperature=0.90, max_tokens=40, min_tokens=10) +#llm = LLM('/home/dsikka/llm-compressor/examples/quantization_w4a16_fp4/TinyLlama-1.1B-Chat-v1.0-FP4', enforce_eager=True) +llm = LLM( + "/home/dsikka/llm-compressor/examples/quantization_w4a16_fp4/Llama-3.1-8B-Instruct-FP4", + enforce_eager=True) +#llm = LLM("nvidia/Llama-3.3-70B-Instruct-FP4", quantization='nvfp4', max_model_len=2048, enforce_eager=True) #llm = LLM("nm-testing/llama2.c-stories110M-FP4", enforce_eager=True) # Print the outputs. 
output = llm.generate(prompts, sampling_params) for o in output: print(o.outputs[0].text) print("\n") - diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 39c89720957..6007e71c60e 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -3,36 +3,36 @@ from typing import Callable, List, Optional import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) -from torch.nn.parameter import Parameter -from vllm.model_executor.parameter import (ModelWeightParameter, - PerTensorScaleParameter, GroupQuantScaleParameter) -from vllm.platforms import current_platform from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( - dequantize_to_dtype, dequantize_unfused, requantize_with_max, expand_global_scale -) -import torch.nn.functional as F + dequantize_to_dtype, expand_global_scale) +from vllm.model_executor.parameter import (GroupQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) __all__ = ["CompressedTensorsW4A4Fp4"] class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): + def __init__(self): - self.group_size = 16 + self.group_size = 16 @classmethod def get_min_capability(cls) -> int: # dont restrict as emulations return 80 - + def create_weights(self, layer: torch.nn.Module, output_partition_sizes: List[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): - + # Weight self.output_partition_sizes = output_partition_sizes self.params_dtype = params_dtype @@ -48,9 +48,9 @@ def create_weights(self, layer: torch.nn.Module, layer.register_parameter("weight_packed", weight) # Global Weight Scale - weight_global_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + weight_global_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) layer.register_parameter("weight_global_scale", weight_global_scale) # Per Group Weight Scale @@ -59,9 +59,9 @@ def create_weights(self, layer: torch.nn.Module, input_size_per_partition // self.group_size, dtype=torch.float8_e4m3fn, ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + input_dim=1, + output_dim=0, + weight_loader=weight_loader) layer.register_parameter("weight_scale", weight_scale) @@ -89,32 +89,8 @@ def swizzle_blockscale(self, scale: torch.tensor): if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) def process_weights_after_loading(self, layer) -> None: - """ - end = 0 - all_dequant = [] - - - # dequant using global scale - for i in range(len(layer.weight_global_scale)): - global_scale = layer.weight_global_scale[i] - size = self.output_partition_sizes[i] - start = end - end = start + size - - dequant_weight = dequantize_unfused(layer.weight_packed[start:end, :], layer.weight_scale[start:end, :], global_scale, - self.params_dtype, layer.weight_packed.device, self.group_size) - - all_dequant.append(dequant_weight) - - del layer.weight_packed + print(layer.weight_global_scale) layer.weight_global_scale = 
Parameter(layer.weight_global_scale.max().to(torch.float32), requires_grad=False) - # requant with max - layer.weight = requantize_with_max(torch.cat(all_dequant, axis=0), layer.weight_global_scale, layer.weight_scale) - """ - - # expand global_scale to have the same shape as the weight scales - layer.weight_global_scale = Parameter(expand_global_scale(layer.weight_scale, self.output_partition_sizes, layer.weight_global_scale), requires_grad=False) - # Note: a post weight loading step but not required for the emulation swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale) layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, @@ -129,8 +105,7 @@ def apply_weights(self, w_global_scale = layer.weight_global_scale w_blockscale = layer.weight_scale_swizzled.data w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale, - x.dtype, x.device, self.group_size) + x.dtype, x.device, self.group_size) out = F.linear(x, w_dq) - del w_dq + del w_dq return out - diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py index 9dafa9752cf..518841b6f63 100644 --- a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py @@ -1,13 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 import torch + from vllm.scalar_type import scalar_types __all__ = [ - "break_fp4_bytes", - "dequantize_to_dtype", - "ref_nvfp4_quant", - "dequantize_unfused", - "requantize_with_max", - "expand_global_scale" + "break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant", + "dequantize_unfused", "requantize_with_max", "expand_global_scale" ] FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max() @@ -56,10 +54,8 @@ def dequantize_to_dtype(tensor_fp4, assert tensor_fp4.dtype == torch.uint8 m, packed_k = tensor_fp4.shape k = packed_k * 2 - #m, k = tensor_fp4.shape tensor_f32 = break_fp4_bytes(tensor_fp4, torch.float32) tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) - #tensor_f32 = tensor_fp4.reshape(m, k // block_size, block_size) tensor_sf = tensor_sf.view(torch.float8_e4m3fn) tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale @@ -68,13 +64,12 @@ def dequantize_to_dtype(tensor_fp4, out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) return out.to(dtype) - def dequantize_unfused(tensor_fp4, - tensor_sf, - global_scale, - dtype, - device, - block_size=16): + tensor_sf, + global_scale, + dtype, + device, + block_size=16): assert tensor_fp4.dtype == torch.uint8 m, packed_k = tensor_fp4.shape @@ -142,15 +137,14 @@ def ref_nvfp4_quant(x, global_scale, block_size): def expand_global_scale(local_scale, sizes, global_scale): - output_scale = torch.zeros(local_scale.shape, dtype=global_scale.dtype).to(global_scale.device) + output_scale = torch.zeros(local_scale.shape, dtype=global_scale.dtype).to( + global_scale.device) end = 0 for i in range(len(sizes)): size = sizes[i] current_global_scale = global_scale[i] - start = end + start = end end = start + size output_scale[start:end, :] = current_global_scale - print(output_scale) - print(global_scale) - return output_scale \ No newline at end of file + return output_scale From 802b6afa27ee531f3ce8029bcf6813e1bc63a45c Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Fri, 9 May 2025 16:32:41 +0000 Subject: [PATCH 09/12] update --- run_fp4.py | 7 ++++--- 
.../schemes/compressed_tensors_w4a4_nvfp4.py | 6 ++++-- .../layers/quantization/utils/nvfp4_emulation_utils.py | 1 + 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/run_fp4.py b/run_fp4.py index 25df88748aa..cd72afbb1be 100644 --- a/run_fp4.py +++ b/run_fp4.py @@ -3,13 +3,14 @@ from vllm import LLM, SamplingParams prompts = [ - "The Swiss Alps are", - "The president of the USA is", + "The Swiss Alps are", "The president of the USA is", "The Boston Bruins are" ] # Create a sampling params object for greedy sampling -sampling_params = SamplingParams(temperature=0.90, max_tokens=40, min_tokens=10) +sampling_params = SamplingParams(temperature=0.90, + max_tokens=40, + min_tokens=10) #llm = LLM('/home/dsikka/llm-compressor/examples/quantization_w4a16_fp4/TinyLlama-1.1B-Chat-v1.0-FP4', enforce_eager=True) llm = LLM( "/home/dsikka/llm-compressor/examples/quantization_w4a16_fp4/Llama-3.1-8B-Instruct-FP4", diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 6007e71c60e..812c4430d68 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( - dequantize_to_dtype, expand_global_scale) + dequantize_to_dtype) from vllm.model_executor.parameter import (GroupQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) @@ -90,7 +90,9 @@ def swizzle_blockscale(self, scale: torch.tensor): def process_weights_after_loading(self, layer) -> None: print(layer.weight_global_scale) - layer.weight_global_scale = Parameter(layer.weight_global_scale.max().to(torch.float32), requires_grad=False) + layer.weight_global_scale = Parameter( + layer.weight_global_scale.max().to(torch.float32), + requires_grad=False) # Note: a post weight loading step but not required for the emulation swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale) layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py index 518841b6f63..cab0ea204c0 100644 --- a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py @@ -64,6 +64,7 @@ def dequantize_to_dtype(tensor_fp4, out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) return out.to(dtype) + def dequantize_unfused(tensor_fp4, tensor_sf, global_scale, From 8a38a88b247cc1f6e4cec2fbc339938f776dab03 Mon Sep 17 00:00:00 2001 From: Dipika Date: Wed, 14 May 2025 20:30:47 -0400 Subject: [PATCH 10/12] update emulation --- .../layers/quantization/modelopt.py | 33 ++++++------------- .../utils/nvfp4_emulation_utils.py | 1 + 2 files changed, 11 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 83ae1a31908..284c17cef5e 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -22,6 +22,9 @@ from vllm.platforms import 
current_platform from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( ref_nvfp4_quant, dequantize_to_dtype) +from math import ceil +import torch._dynamo +torch._dynamo.config.suppress_errors = True logger = init_logger(__name__) @@ -388,43 +391,27 @@ def apply( # for input only the contracting dimension has a constraint. x_m, x_k = x.shape - w_n, w_k = layer.weight.shape - # print(f"{x.shape=}") - # print(f"{layer.weight.shape=}") - output_shape = [x_m, w_n] - block_size = 16 - """ + block_size = group_size = 16 + # quantize input to (FP4 and interleaved block scale) - # x_global_scale = layer.input_scale x_global_scale = 1 / layer.input_scale - # x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant) x_fp4, x_blockscale = ref_nvfp4_quant(x, x_global_scale, block_size) - # x_blockscale = self.swizzle_blockscale(x_blockscale) - # print(f"{x_fp4.shape=}") - # print(f"{x_blockscale.shape=}") # dequantize input x_fp4 = x_fp4.reshape(x_m, x_k // block_size, block_size) x_blockscale = x_blockscale.unsqueeze(-1) / x_global_scale x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype) del x_fp4, x_blockscale - """ - + # dequantize weight w_fp4 = layer.weight.data.view(torch.uint8) w_blockscale = layer.weight_scale_swizzled.data w_global_scale = 1 / layer.weight_scale_2 - # print(f"{w_fp4.shape=}") - # print(f"{w_blockscale.shape=}") - # print(f"{w_global_scale.shape=}") w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale, output_dtype, x.device, block_size) # matmul - out = torch.matmul(x, w_dq.t()) - # del x_dq, w_dq - # print(f"{out.shape=}") - - if bias is not None: - out = out + bias - return out.view(*output_shape) + out = torch.matmul(x_dq, w_dq.t()) + del w_dq, x_dq + return out + diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py index cab0ea204c0..4a3f1df06ca 100644 --- a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py @@ -128,6 +128,7 @@ def ref_nvfp4_quant(x, global_scale, block_size): vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32) scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX)) + scale = torch.clamp(scale, max=448, min=-448) scale = scale.to(torch.float8_e4m3fn).to(torch.float32) output_scale = get_reciprocal(scale * get_reciprocal(global_scale)) From f087703c64700f97b700b5d0ee4a5846bea8f2b2 Mon Sep 17 00:00:00 2001 From: Dipika Date: Wed, 14 May 2025 20:35:25 -0400 Subject: [PATCH 11/12] update emulation --- .../layers/quantization/modelopt.py | 13 ++++++------- .../utils/nvfp4_emulation_utils.py | 19 +------------------ 2 files changed, 7 insertions(+), 25 deletions(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 284c17cef5e..9329985e5b3 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Union import torch +import torch._dynamo from torch.nn import Module from torch.nn.parameter import Parameter @@ -13,6 +14,8 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( + 
dequantize_to_dtype, ref_nvfp4_quant) from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -20,10 +23,7 @@ from vllm.model_executor.parameter import (ModelWeightParameter, PerTensorScaleParameter) from vllm.platforms import current_platform -from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( - ref_nvfp4_quant, dequantize_to_dtype) -from math import ceil -import torch._dynamo + torch._dynamo.config.suppress_errors = True logger = init_logger(__name__) @@ -392,7 +392,7 @@ def apply( # for input only the contracting dimension has a constraint. x_m, x_k = x.shape block_size = group_size = 16 - + # quantize input to (FP4 and interleaved block scale) x_global_scale = 1 / layer.input_scale x_fp4, x_blockscale = ref_nvfp4_quant(x, x_global_scale, block_size) @@ -402,7 +402,7 @@ def apply( x_blockscale = x_blockscale.unsqueeze(-1) / x_global_scale x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype) del x_fp4, x_blockscale - + # dequantize weight w_fp4 = layer.weight.data.view(torch.uint8) w_blockscale = layer.weight_scale_swizzled.data @@ -414,4 +414,3 @@ def apply( out = torch.matmul(x_dq, w_dq.t()) del w_dq, x_dq return out - diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py index 4a3f1df06ca..5ffe8fc8412 100644 --- a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py @@ -3,10 +3,7 @@ from vllm.scalar_type import scalar_types -__all__ = [ - "break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant", - "dequantize_unfused", "requantize_with_max", "expand_global_scale" -] +__all__ = ["break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant"] FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max() @@ -136,17 +133,3 @@ def ref_nvfp4_quant(x, global_scale, block_size): clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n) # both outputs are float32 return cast_to_fp4(clipped_x), scale.squeeze(-1) - - -def expand_global_scale(local_scale, sizes, global_scale): - output_scale = torch.zeros(local_scale.shape, dtype=global_scale.dtype).to( - global_scale.device) - end = 0 - for i in range(len(sizes)): - size = sizes[i] - current_global_scale = global_scale[i] - start = end - end = start + size - output_scale[start:end, :] = current_global_scale - - return output_scale From f679c6164a1510151aec38483e3f879fa15bebf0 Mon Sep 17 00:00:00 2001 From: Dipika Date: Wed, 14 May 2025 20:36:08 -0400 Subject: [PATCH 12/12] add script --- run_fp4.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/run_fp4.py b/run_fp4.py index cd72afbb1be..db40c5d0df9 100644 --- a/run_fp4.py +++ b/run_fp4.py @@ -1,4 +1,6 @@ -# SPDX-License-Identifier: Apache-2.0 + +import numpy +import torch from vllm import LLM, SamplingParams @@ -8,17 +10,14 @@ ] # Create a sampling params object for greedy sampling -sampling_params = SamplingParams(temperature=0.90, - max_tokens=40, - min_tokens=10) -#llm = LLM('/home/dsikka/llm-compressor/examples/quantization_w4a16_fp4/TinyLlama-1.1B-Chat-v1.0-FP4', enforce_eager=True) -llm = LLM( - "/home/dsikka/llm-compressor/examples/quantization_w4a16_fp4/Llama-3.1-8B-Instruct-FP4", - enforce_eager=True) - -#llm = LLM("nvidia/Llama-3.3-70B-Instruct-FP4", quantization='nvfp4', max_model_len=2048, enforce_eager=True) 
-#llm = LLM("nm-testing/llama2.c-stories110M-FP4", enforce_eager=True) +sampling_params = SamplingParams(temperature=0.80, top_p=0.95, max_tokens=40, min_tokens=10) +#llm = LLM('nm-testing/Llama-3.1-8B-Instruct-FP4-Weight') +#llm = LLM("/home/dsikka/llm-compressor/examples/quantization_w4a16_fp4/TinyLlama-1.1B-Chat-v1.0-FP4") +#llm = LLM("/home/dsikka/llm-compressor/examples/quantization_w4a16_fp4/Llama-3.1-8B-Instruct-NVFP4A16") +#llm = LLM("/home/dsikka/llm-compressor/examples/quantization_w4a16_fp4/Llama-3.1-8B-Instruct-NVFP4A16-MSE") +#llm = LLM("nm-testing/Llama-3.3-70B-Instruct-NVFP4A16", max_model_len=4096) # Print the outputs. +llm = LLM("nvidia/Llama-3.3-70B-Instruct-FP4", max_model_len=4096, quantization="nvfp4", enforce_eager=True) output = llm.generate(prompts, sampling_params) for o in output: print(o.outputs[0].text)