diff --git a/run_fp4.py b/run_fp4.py
new file mode 100644
index 00000000000..db40c5d0df9
--- /dev/null
+++ b/run_fp4.py
@@ -0,0 +1,24 @@
+
+import numpy
+import torch
+
+from vllm import LLM, SamplingParams
+
+prompts = [
+    "The Swiss Alps are", "The president of the USA is",
+    "The Boston Bruins are"
+]
+
+# Create a sampling params object for sampled decoding.
+sampling_params = SamplingParams(temperature=0.80, top_p=0.95, max_tokens=40, min_tokens=10)
+#llm = LLM('nm-testing/Llama-3.1-8B-Instruct-FP4-Weight')
+#llm = LLM("/home/dsikka/llm-compressor/examples/quantization_w4a16_fp4/TinyLlama-1.1B-Chat-v1.0-FP4")
+#llm = LLM("/home/dsikka/llm-compressor/examples/quantization_w4a16_fp4/Llama-3.1-8B-Instruct-NVFP4A16")
+#llm = LLM("/home/dsikka/llm-compressor/examples/quantization_w4a16_fp4/Llama-3.1-8B-Instruct-NVFP4A16-MSE")
+#llm = LLM("nm-testing/Llama-3.3-70B-Instruct-NVFP4A16", max_model_len=4096)
+# Load the FP4 checkpoint through the emulated nvfp4 path and print the outputs.
+llm = LLM("nvidia/Llama-3.3-70B-Instruct-FP4", max_model_len=4096, quantization="nvfp4", enforce_eager=True)
+output = llm.generate(prompts, sampling_params)
+for o in output:
+    print(o.outputs[0].text)
+    print("\n")
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 5be6b22c7b2..f6dc0d0d7d8
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -24,7 +24,7 @@
     W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
     CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
     CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
-    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
+    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16, CompressedTensorsW4A4Fp4)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     find_matched_target, is_activation_quantization_format,
     should_ignore_layer)
@@ -299,6 +299,13 @@ def _is_fp8_w8a16(self, weight_quant: BaseModel,
         # All conditions satisfied.
         return True
 
+    def _is_fp4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):
+        is_group_quant = weight_quant.strategy == QuantizationStrategy.GROUP.value
+        is_group_size_16 = weight_quant.group_size == 16
+        is_float_type = weight_quant.type == QuantizationType.FLOAT
+        is_4_bits = weight_quant.num_bits == 4
+        return is_group_quant and is_group_size_16 and is_float_type and is_4_bits
+
     def _is_wNa16_group_channel(self, weight_quant: BaseModel,
                                 input_quant: BaseModel) -> bool:
         input_quant_none = input_quant is None
@@ -313,6 +320,8 @@ def _get_scheme_from_parts(
             self, weight_quant: BaseModel,
             input_quant: BaseModel) -> "CompressedTensorsScheme":
 
+        if self._is_fp4_nvfp4(weight_quant, input_quant):
+            return CompressedTensorsW4A4Fp4()
         # Detect If Mixed Precision
         if self._is_wNa16_group_channel(weight_quant, input_quant):
             if (self.quant_format == CompressionFormat.marlin_24.value
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
index b26c74f2484..67368cf0662
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
@@ -8,6 +8,7 @@
 from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
 from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
                                        CompressedTensorsWNA16)
+from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
 
 from .compressed_tensors_24 import CompressedTensors24  # isort: skip
 
@@ -16,5 +17,5 @@
     "CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24",
     "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
     "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
-    "CompressedTensors24"
+    "CompressedTensors24", "CompressedTensorsW4A4Fp4"
 ]
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
new file mode 100644
index 00000000000..812c4430d68
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
@@ -0,0 +1,113 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Callable, List, Optional
+
+import torch
+import torch.nn.functional as F
+from torch.nn.parameter import Parameter
+
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
+    dequantize_to_dtype)
+from vllm.model_executor.parameter import (GroupQuantScaleParameter,
+                                           ModelWeightParameter,
+                                           PerTensorScaleParameter)
+
+__all__ = ["CompressedTensorsW4A4Fp4"]
+
+
+class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
+
+    def __init__(self):
+        self.group_size = 16
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # Don't restrict capability; this scheme runs as a dequantize emulation.
+        return 80
+
+    def create_weights(self, layer: torch.nn.Module,
+                       output_partition_sizes: List[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+
+        # Weight
+        self.output_partition_sizes = output_partition_sizes
+        self.params_dtype = params_dtype
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                # 2 fp4 items are packed in the input dimension
+                sum(output_partition_sizes),
+                input_size_per_partition // 2,
+                dtype=torch.uint8),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader)
layer.register_parameter("weight_packed", weight) + + # Global Weight Scale + weight_global_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("weight_global_scale", weight_global_scale) + + # Per Group Weight Scale + weight_scale = GroupQuantScaleParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + layer.register_parameter("weight_scale", weight_scale) + + def swizzle_blockscale(self, scale: torch.tensor): + assert (scale.dtype == torch.float8_e4m3fn) + # Pad and blockwise interleave weight_scale + scale_ndim = scale.ndim + if scale.ndim == 2: + scale = scale.unsqueeze(0) + assert scale.ndim == 3 + B, M, K = scale.shape + round_up_multiple = lambda x, m: (x + m - 1) // m * m + M_padded = round_up_multiple(M, 128) + K_padded = round_up_multiple(K, 4) + padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) + padded_scale[:B, :M, :K] = scale + batches, rows, cols = padded_scale.shape + assert rows % 128 == 0 + assert cols % 4 == 0 + padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, + cols // 4, 4) + swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) + swizzled_scale = swizzled_scale.contiguous().cuda() + return (swizzled_scale.reshape(M, K) + if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) + + def process_weights_after_loading(self, layer) -> None: + print(layer.weight_global_scale) + layer.weight_global_scale = Parameter( + layer.weight_global_scale.max().to(torch.float32), + requires_grad=False) + # Note: a post weight loading step but not required for the emulation + swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale) + layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, + requires_grad=False) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + w_fp4 = layer.weight_packed.data + w_global_scale = layer.weight_global_scale + w_blockscale = layer.weight_scale_swizzled.data + w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale, + x.dtype, x.device, self.group_size) + out = F.linear(x, w_dq) + del w_dq + return out diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 8f1c8487746..9329985e5b3 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Union import torch +import torch._dynamo from torch.nn import Module from torch.nn.parameter import Parameter @@ -13,6 +14,8 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( + dequantize_to_dtype, ref_nvfp4_quant) from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -20,109 +23,14 @@ from vllm.model_executor.parameter import (ModelWeightParameter, PerTensorScaleParameter) from vllm.platforms import current_platform -from vllm.scalar_type import scalar_types + +torch._dynamo.config.suppress_errors = True 
 
 logger = init_logger(__name__)
 
 QUANT_ALGOS = ["FP8", "NVFP4"]
 KV_CACHE_QUANT_ALGOS = ["FP8"]
 
-FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max()
-
-kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
-                            dtype=torch.float32)
-
-
-def break_fp4_bytes(a, dtype):
-    assert a.dtype == torch.uint8
-    m, n = a.shape
-    # Vectorized nibble processing
-    a_flat = a.flatten()
-    high = (a_flat & 0xF0) >> 4  # Upper nibbles
-    low = a_flat & 0x0F  # Lower nibbles
-    # Combine nibbles for batch processing
-    combined = torch.stack((low, high), dim=1).flatten()
-    # Vectorized sign and magnitude extraction
-    signs = (combined & 0x08).to(torch.bool)  # Sign bits
-    abs_vals = (combined & 0x07).to(torch.long)
-    # Device-aware lookup and sign application
-    kE2M1 = kE2M1ToFloat.to(device=a.device)
-    values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
-    # Reshape to final form
-    return values.reshape(m, n * 2).to(dtype=dtype)
-
-
-def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
-    m_tiles = (m + 128 - 1) // 128
-    f = block_size * 4
-    k_tiles = (k + f - 1) // f
-    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
-    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
-    out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
-    return out[0:m, 0:k]
-
-
-def dequantize_to_dtype(tensor_fp4,
-                        tensor_sf,
-                        global_scale,
-                        dtype,
-                        device,
-                        block_size=16):
-    """Dequantize the fp4 tensor back to high precision."""
-    # Two fp4 values are packed into one uint8.
-    assert tensor_fp4.dtype == torch.uint8
-    m, packed_k = tensor_fp4.shape
-    k = packed_k * 2
-    tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
-    tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
-    tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
-    tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
-    tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
-
-    # scale the tensor
-    out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
-    return out.to(dtype)
-
-
-def cast_to_fp4(x):
-    sign = torch.sign(x)
-    x = torch.abs(x)
-    x[(x >= 0.0) & (x <= 0.25)] = 0.0
-    x[(x > 0.25) & (x < 0.75)] = 0.5
-    x[(x >= 0.75) & (x <= 1.25)] = 1.0
-    x[(x > 1.25) & (x < 1.75)] = 1.5
-    x[(x >= 1.75) & (x <= 2.5)] = 2.0
-    x[(x > 2.5) & (x < 3.5)] = 3.0
-    x[(x >= 3.5) & (x <= 5.0)] = 4.0
-    x[x > 5.0] = 6.0
-    return x * sign
-
-
-def get_reciprocal(x):
-    if isinstance(x, torch.Tensor):
-        return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x)
-    elif isinstance(x, (float, int)):
-        return 0.0 if x == 0 else 1.0 / x
-    else:
-        raise TypeError("Input must be a float, int, or a torch.Tensor.")
-
-
-def ref_nvfp4_quant(x, global_scale, block_size):
-    assert global_scale.dtype == torch.float32
-    assert x.ndim == 2
-    m, n = x.shape
-    x = torch.reshape(x, (m, n // block_size, block_size))
-    vec_max = torch.max(torch.abs(x), dim=-1,
-                        keepdim=True)[0].to(torch.float32)
-    scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
-    scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
-    output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
-
-    scaled_x = x.to(torch.float32) * output_scale
-    clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
-    # both outputs are float32
-    return cast_to_fp4(clipped_x), scale.squeeze(-1)
-
 
 class ModelOptFp8Config(QuantizationConfig):
     """Config class for ModelOpt FP8."""
@@ -289,7 +197,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]:
 
     @classmethod
     def get_min_capability(cls) -> int:
-        return 89
+        return 80
 
     @classmethod
     def get_config_filenames(cls) -> List[str]:
@@ -483,20 +391,11 @@ def apply(
 
         # for input only the contracting dimension has a constraint.
         x_m, x_k = x.shape
-        w_n, w_k = layer.weight.shape
-        # print(f"{x.shape=}")
-        # print(f"{layer.weight.shape=}")
-        output_shape = [x_m, w_n]
-        block_size = 16
+        block_size = group_size = 16
 
         # quantize input to (FP4 and interleaved block scale)
-        # x_global_scale = layer.input_scale
        x_global_scale = 1 / layer.input_scale
-        # x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant)
         x_fp4, x_blockscale = ref_nvfp4_quant(x, x_global_scale, block_size)
-        # x_blockscale = self.swizzle_blockscale(x_blockscale)
-        # print(f"{x_fp4.shape=}")
-        # print(f"{x_blockscale.shape=}")
 
         # dequantize input
         x_fp4 = x_fp4.reshape(x_m, x_k // block_size, block_size)
@@ -507,20 +406,13 @@ def apply(
         # dequantize weight
         w_fp4 = layer.weight.data.view(torch.uint8)
         w_blockscale = layer.weight_scale_swizzled.data
-        w_global_scale = layer.weight_scale_2
-        # print(f"{w_fp4.shape=}")
-        # print(f"{w_blockscale.shape=}")
-        # print(f"{w_global_scale.shape=}")
+        w_global_scale = 1 / layer.weight_scale_2
         w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
-                                   output_dtype, x.device,
-                                   block_size).to(output_dtype)
-        # print(f"{w_dq.shape=}")
+                                   output_dtype, x.device, block_size)
 
         # matmul
         out = torch.matmul(x_dq, w_dq.t())
-        del x_dq, w_dq
-        # print(f"{out.shape=}")
-
-        if bias is not None:
-            out = out + bias
-        return out.view(*output_shape)
+        del w_dq, x_dq
+        if bias is not None:
+            out = out + bias
+        return out
diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
new file mode 100644
index 00000000000..5ffe8fc8412
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
@@ -0,0 +1,135 @@
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+from vllm.scalar_type import scalar_types
+
+__all__ = ["break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant"]
+
+FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max()
+
+kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
+                            dtype=torch.float32)
+
+
+def break_fp4_bytes(a, dtype):
+    assert a.dtype == torch.uint8
+    m, n = a.shape
+    # Vectorized nibble processing
+    a_flat = a.flatten()
+    high = (a_flat & 0xF0) >> 4  # Upper nibbles
+    low = a_flat & 0x0F  # Lower nibbles
+    # Combine nibbles for batch processing
+    combined = torch.stack((low, high), dim=1).flatten()
+    # Vectorized sign and magnitude extraction
+    signs = (combined & 0x08).to(torch.bool)  # Sign bits
+    abs_vals = (combined & 0x07).to(torch.long)
+    # Device-aware lookup and sign application
+    kE2M1 = kE2M1ToFloat.to(device=a.device)
+    values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
+    # Reshape to final form
+    return values.reshape(m, n * 2).to(dtype=dtype)
+
+
+def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
+    m_tiles = (m + 128 - 1) // 128
+    f = block_size * 4
+    k_tiles = (k + f - 1) // f
+    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
+    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
+    out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
+    return out[0:m, 0:k]
+
+
+def dequantize_to_dtype(tensor_fp4,
+                        tensor_sf,
+                        global_scale,
+                        dtype,
+                        device,
+                        block_size=16):
+    """Dequantize the fp4 tensor back to high precision."""
+    # Two fp4 values are packed into one uint8.
+    assert tensor_fp4.dtype == torch.uint8
+    m, packed_k = tensor_fp4.shape
+    k = packed_k * 2
+    tensor_f32 = break_fp4_bytes(tensor_fp4, torch.float32)
+    tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
+    tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
+    tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
+    tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
+
+    # scale the tensor
+    out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
+    return out.to(dtype)
+
+
+def dequantize_unfused(tensor_fp4,
+                       tensor_sf,
+                       global_scale,
+                       dtype,
+                       device,
+                       block_size=16):
+
+    assert tensor_fp4.dtype == torch.uint8
+    m, packed_k = tensor_fp4.shape
+    k = packed_k * 2
+    tensor_f32 = break_fp4_bytes(tensor_fp4, torch.float32)
+    tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
+    tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
+    tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
+
+    # scale the tensor
+    out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
+    return out.to(dtype)
+
+
+def cast_to_fp4(x):
+    sign = torch.sign(x)
+    x = torch.abs(x)
+    x[(x >= 0.0) & (x <= 0.25)] = 0.0
+    x[(x > 0.25) & (x < 0.75)] = 0.5
+    x[(x >= 0.75) & (x <= 1.25)] = 1.0
+    x[(x > 1.25) & (x < 1.75)] = 1.5
+    x[(x >= 1.75) & (x <= 2.5)] = 2.0
+    x[(x > 2.5) & (x < 3.5)] = 3.0
+    x[(x >= 3.5) & (x <= 5.0)] = 4.0
+    x[x > 5.0] = 6.0
+    return x * sign
+
+
+def get_reciprocal(x):
+    if isinstance(x, torch.Tensor):
+        return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x)
+    elif isinstance(x, (float, int)):
+        return 0.0 if x == 0 else 1.0 / x
+    else:
+        raise TypeError("Input must be a float, int, or a torch.Tensor.")
+
+
+def requantize_with_max(x, global_scale, scale):
+    assert global_scale.dtype == torch.float32
+    assert x.ndim == 2
+    m, n = x.shape
+    block_size = 16
+    x = torch.reshape(x, (m, n // block_size, block_size))
+    scale = scale.to(global_scale.dtype) / global_scale
+    scaled_x = x.to(torch.float32) / scale.unsqueeze(-1)
+    clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
+    return cast_to_fp4(clipped_x)
+
+
+def ref_nvfp4_quant(x, global_scale, block_size):
+    assert global_scale.dtype == torch.float32
+    assert x.ndim == 2
+    m, n = x.shape
+    x = torch.reshape(x, (m, n // block_size, block_size))
+    vec_max = torch.max(torch.abs(x), dim=-1,
+                        keepdim=True)[0].to(torch.float32)
+    scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
+    scale = torch.clamp(scale, max=448, min=-448)
+    scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
+    output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
+
+    scaled_x = x.to(torch.float32) * output_scale
+    clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
+    # both outputs are float32
+    return cast_to_fp4(clipped_x), scale.squeeze(-1)
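
A quick way to sanity-check the emulation helpers outside of a model is to round-trip a small tensor through ref_nvfp4_quant and undo the block scales the same way the modelopt apply() path does (x_fp4 * blockscale / global_scale). The sketch below is illustrative only: the (448 * FLOAT4_E2M1_MAX) / amax global-scale choice is an assumed convention, not something this diff fixes.

# Minimal sketch, assuming the (448 * FLOAT4_E2M1_MAX) / amax global scale.
import torch

from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
    FLOAT4_E2M1_MAX, ref_nvfp4_quant)

block_size = 16
x = torch.randn(8, 64, dtype=torch.float32)
x_global_scale = (448 * FLOAT4_E2M1_MAX) / x.abs().max().to(torch.float32)

# quantize: returns unpacked float32 FP4 values and per-block FP8-rounded scales
x_fp4, x_blockscale = ref_nvfp4_quant(x, x_global_scale, block_size)

# dequantize the same way ModelOptNvFp4LinearMethod.apply does
m, n = x.shape
x_dq = (x_fp4.reshape(m, n // block_size, block_size) *
        (x_blockscale / x_global_scale).unsqueeze(-1)).reshape(m, n)
print("max abs round-trip error:", (x - x_dq).abs().max().item())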