diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 588aa8deb18..bcc68d08bf0 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -759,12 +759,10 @@ def weight_loader_v2(self,
         tp_size = get_tensor_model_parallel_world_size()
 
         if isinstance(param, BlockQuantScaleParameter):
-            from vllm.model_executor.layers.quantization.fp8 import (
-                Fp8LinearMethod, Fp8MoEMethod)
             assert self.quant_method is not None
-            assert isinstance(self.quant_method,
-                              (Fp8LinearMethod, Fp8MoEMethod))
-            weight_block_size = self.quant_method.quant_config.weight_block_size
+            # Assume the weight block size has been set by the quant method.
+            assert hasattr(self, "weight_block_size")
+            weight_block_size = self.weight_block_size
             assert weight_block_size is not None
             block_n, _ = weight_block_size[0], weight_block_size[1]
             shard_offset = (
@@ -934,8 +932,10 @@ def weight_loader_v2(self,
         # Note(simon): This is needed for Qwen3's fp8 quantization.
         if isinstance(param, BlockQuantScaleParameter):
             assert self.quant_method is not None
-            assert hasattr(self.quant_method, "quant_config")
-            weight_block_size = self.quant_method.quant_config.weight_block_size
+            # Assume the weight block size has been set by the quant method.
+            assert hasattr(self, "weight_block_size")
+            weight_block_size = self.weight_block_size
+            assert weight_block_size is not None
             block_n, _ = weight_block_size[0], weight_block_size[1]
             shard_offset = (shard_offset + block_n - 1) // block_n
             shard_size = (shard_size + block_n - 1) // block_n
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 4f87b2a44f0..02a4dd562da 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -11,7 +11,6 @@
 from compressed_tensors.quantization import (QuantizationArgs,
                                              QuantizationStrategy,
                                              QuantizationType)
-from pydantic import BaseModel
 
 import vllm.envs as envs
 from vllm.logger import init_logger
@@ -221,7 +220,8 @@ def _check_scheme_supported(self,
         else:
             return False
 
-    def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):
+    def _is_fp4a4_nvfp4(self, weight_quant: QuantizationArgs,
+                        input_quant: QuantizationArgs):
 
         if weight_quant is None or input_quant is None:
             return False
@@ -241,8 +241,8 @@ def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):
         return (is_tensor_group_quant and is_float_type and is_4_bits
                 and is_group_size_16 and is_symmetric)
 
-    def _is_fp4a16_nvfp4(self, weight_quant: BaseModel,
-                         input_quant: BaseModel):
+    def _is_fp4a16_nvfp4(self, weight_quant: QuantizationArgs,
+                         input_quant: QuantizationArgs):
 
         is_weight_only = weight_quant is not None and input_quant is None
         is_tensor_group_quant = (
@@ -256,8 +256,8 @@ def _is_fp4a16_nvfp4(self, weight_quant: BaseModel,
         return (is_weight_only and is_tensor_group_quant and is_float_type
                 and is_4_bits and is_group_size_16 and is_symmetric)
 
-    def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
-                               input_quant: BaseModel) -> bool:
+    def _is_static_tensor_w8a8(self, weight_quant: QuantizationArgs,
+                               input_quant: QuantizationArgs) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
         weight_strategy = (
             weight_quant.strategy == QuantizationStrategy.TENSOR.value
@@ -270,8 +270,8 @@ def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
         # Only symmetric weight quantization supported.
         return is_8_bits and is_tensor and weight_quant.symmetric and is_static
 
-    def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
-                               input_quant: BaseModel) -> bool:
+    def _is_dynamic_token_w8a8(self, weight_quant: QuantizationArgs,
+                               input_quant: QuantizationArgs) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
         weight_strategy = (
             weight_quant.strategy == QuantizationStrategy.TENSOR.value
@@ -284,8 +284,8 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
         # Only symmetric weight quantization supported.
         return is_8_bits and is_token and weight_quant.symmetric and is_dynamic
 
-    def _is_fp8_w8a8(self, weight_quant: BaseModel,
-                     input_quant: BaseModel) -> bool:
+    def _is_fp8_w8a8(self, weight_quant: QuantizationArgs,
+                     input_quant: QuantizationArgs) -> bool:
         # Confirm weights and activations quantized.
         if weight_quant is None or input_quant is None:
             return False
@@ -295,11 +295,12 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel,
                              and input_quant.type == QuantizationType.FLOAT)
         is_symmetric_weight = weight_quant.symmetric
         is_static_weight = not weight_quant.dynamic
-        is_per_tensor_or_channel_weight = (weight_quant.strategy in [
-            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
+        is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [
+            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL,
+            QuantizationStrategy.BLOCK
         ])
         if not (is_floating_point and is_symmetric_weight and is_static_weight
-                and is_per_tensor_or_channel_weight):
+                and is_tensor_or_channel_or_block_weight):
             return False
 
         # Dynamic quantization is always supported if weights supported.
@@ -312,13 +313,16 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel,
             input_quant.strategy == QuantizationStrategy.TENSOR)
         return is_symmetric_activation and is_per_tensor_activation
 
-    def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel,
-                          input_quant: BaseModel) -> bool:
+    def _is_fp8_w8a8_sm90(self, weight_quant: QuantizationArgs,
+                          input_quant: QuantizationArgs) -> bool:
+        # Block quantization is not supported for SM90 CUTLASS yet
+        is_block_quant = weight_quant.strategy == QuantizationStrategy.BLOCK
         return (self._check_scheme_supported(90, error=False, match_exact=True)
-                and self._is_fp8_w8a8(weight_quant, input_quant))
+                and self._is_fp8_w8a8(weight_quant, input_quant)
+                and not is_block_quant)
 
-    def _is_fp8_w8a16(self, weight_quant: BaseModel,
-                      input_quant: BaseModel) -> bool:
+    def _is_fp8_w8a16(self, weight_quant: QuantizationArgs,
+                      input_quant: QuantizationArgs) -> bool:
         # Confirm weights quantized.
         if weight_quant is None:
             return False
@@ -328,20 +332,22 @@ def _is_fp8_w8a16(self, weight_quant: BaseModel,
             return False
 
         # Confirm weight scheme is supported.
+        is_floating_point = weight_quant.type == QuantizationType.FLOAT
         is_symmetric_weight = weight_quant.symmetric
         is_static_weight = not weight_quant.dynamic
-        is_per_tensor_or_channel_weight = (weight_quant.strategy in [
-            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
+        is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [
+            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL,
+            QuantizationStrategy.BLOCK
         ])
-        if not (is_symmetric_weight and is_static_weight  # noqa: SIM103
-                and is_per_tensor_or_channel_weight):
+        if not (is_floating_point and is_symmetric_weight  # noqa: SIM103
+                and is_static_weight and is_tensor_or_channel_or_block_weight):
             return False
 
         # All conditions satisfied.
         return True
 
-    def _is_wNa16_group_channel(self, weight_quant: BaseModel,
-                                input_quant: BaseModel) -> bool:
+    def _is_wNa16_group_channel(self, weight_quant: QuantizationArgs,
+                                input_quant: QuantizationArgs) -> bool:
         input_quant_none = input_quant is None
         is_channel_group = (
             weight_quant.strategy == QuantizationStrategy.CHANNEL.value
@@ -351,8 +357,8 @@ def _is_wNa16_group_channel(self, weight_quant: BaseModel,
         return (is_channel_group and input_quant_none and is_static)
 
     def _get_scheme_from_parts(
-            self, weight_quant: BaseModel,
-            input_quant: BaseModel) -> "CompressedTensorsScheme":
+            self, weight_quant: QuantizationArgs,
+            input_quant: QuantizationArgs) -> "CompressedTensorsScheme":
 
         # Detect If Mixed Precision
         if self._is_fp4a16_nvfp4(weight_quant, input_quant):
@@ -392,7 +398,7 @@ def _get_scheme_from_parts(
                 CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
             if is_fp8_w8a8_supported:
                 return CompressedTensorsW8A8Fp8(
-                    strategy=weight_quant.strategy,
+                    weight_quant=weight_quant,
                     is_static_input_scheme=(input_quant
                                             and not input_quant.dynamic))
             else:
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index fa4ce566809..bdc69bca504 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import enum
+import functools
 from enum import Enum
 from typing import Callable, Optional
 
@@ -12,6 +13,7 @@
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
@@ -32,7 +34,7 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
-from vllm.utils import has_pplx
+from vllm.utils import has_deep_gemm, has_pplx
 
 if current_platform.is_cuda_alike():
     from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
@@ -362,6 +364,12 @@ def apply(
                            device=x.device).to(x.dtype)
 
 
+def _is_col_major(x: torch.Tensor) -> bool:
+    assert x.dim() == 3
+    b, m, n = x.shape
+    return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
+
+
 class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
 
     def __init__(
@@ -380,17 +388,21 @@ def __init__(
         per_channel = (
             self.weight_quant.strategy == QuantizationStrategy.CHANNEL
             and self.input_quant.strategy == QuantizationStrategy.TOKEN)
-        if not (per_tensor or per_channel):
+        block_quant = (self.weight_quant.strategy == QuantizationStrategy.BLOCK
+                       and self.input_quant.strategy
+                       == QuantizationStrategy.GROUP
+                       and self.weight_quant.block_structure is not None)
+        if not (per_tensor or per_channel or block_quant):
             raise ValueError(
-                "For FP8 Fused MoE layers, we require per tensor "
-                "or channelwise, dynamic per token quantization. Found "
+                "For FP8 Fused MoE layers, we require tensor/tensor, "
+                "channel/token, or block/group quantization. Found "
Found " f"{self.weight_quant}, {self.input_quant}") self.static_input_scales = not self.input_quant.dynamic - if self.static_input_scales and per_channel: + if self.static_input_scales and not per_tensor: raise ValueError( - "For FP8 Fused MoE layer, we require either per tensor or " - "channelwise, dynamic per token quantization.") + "For FP8 Fused MoE layer with static input scales, we require " + "either per tensor quantization.") # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization @@ -404,6 +416,33 @@ def __init__( self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + if self.weight_quant.block_structure is None: + block_size = None + elif isinstance(self.weight_quant.block_structure, str): + block_size = [ + int(x) for x in self.weight_quant.block_structure.split("x") + ] + else: + block_size = self.weight_quant.block_structure + self.weight_block_size = block_size + self.block_quant = block_quant + + # Check for DeepGemm support. + self.allow_deep_gemm = False + if envs.VLLM_USE_DEEP_GEMM: + if not has_deep_gemm(): + logger.warning_once("Failed to import DeepGemm kernels.") + elif not self.block_quant: + logger.warning_once("Model is not block quantized. Not using " + " DeepGemm kernels") + elif (current_platform.is_cuda() + and current_platform.has_device_capability(90)): + logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.") + self.allow_deep_gemm = True + else: + logger.warning_once( + "DeepGemm not supported on the current platform.") + def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -416,6 +455,31 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, params_dtype = torch.float8_e4m3fn + if self.block_quant: + assert self.weight_block_size is not None + layer.weight_block_size = self.weight_block_size + tp_size = get_tensor_model_parallel_world_size() + block_n, block_k = ( + self.weight_block_size[0], + self.weight_block_size[1], + ) + # NOTE: To ensure proper alignment of the block-wise quantization + # scales, the output_size of the weights for both the gate and up + # layers must be divisible by block_n. + # Required by column parallel or enabling merged weights + if intermediate_size_per_partition % block_n != 0: + raise ValueError( + f"The output_size of gate's and up's weight = " + f"{intermediate_size_per_partition} is not divisible by " + f"weight quantization block_n = {block_n}.") + if (tp_size > 1 + and intermediate_size_per_partition % block_k != 0): + # Required by row parallel + raise ValueError( + f"The input_size of down's weight = " + f"{intermediate_size_per_partition} is not divisible by " + f"weight quantization block_k = {block_k}.") + # WEIGHTS w13_weight = torch.nn.Parameter(torch.empty( num_experts, @@ -450,10 +514,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, # Add PER-TENSOR quantization for FusedMoE.weight_loader. 
             extra_weight_attrs.update(
                 {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
-            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
         elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            assert self.static_input_scales is False
             w13_weight_scale = torch.nn.Parameter(torch.ones(
                 num_experts,
                 2 * intermediate_size_per_partition,
@@ -468,8 +530,35 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
             extra_weight_attrs.update(
                 {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
-            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        elif self.block_quant:
+            assert self.static_input_scales is False
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts,
+                    2 * ((intermediate_size_per_partition + block_n - 1) //
+                         block_n),
+                    (hidden_size + block_k - 1) // block_k,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts,
+                    (hidden_size + block_n - 1) // block_n,
+                    (intermediate_size_per_partition + block_k - 1) // block_k,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-BLOCK quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})
+
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
         # INPUT_SCALES
         if self.static_input_scales:
@@ -489,6 +578,59 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if self.rocm_aiter_moe_enabled:
+            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa E501
+                rocm_aiter_fused_experts, shuffle_weights)
+
+        # TODO (rob): refactor block quant into separate class.
+        if self.block_quant:
+            assert self.static_input_scales is False
+            if current_platform.is_fp8_fnuz():
+                w13_weight, w13_weight_scale, w13_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w13_weight, layer.w13_weight_scale,
+                        layer.w13_input_scale)
+                w2_weight, w2_weight_scale, w2_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w2_weight, layer.w2_weight_scale,
+                        layer.w2_input_scale)
+            else:
+                w13_weight = layer.w13_weight.data
+                w13_weight_scale = layer.w13_weight_scale.data
+                w2_weight = layer.w2_weight
+                w2_weight_scale = layer.w2_weight_scale
+
+            # torch.compile() cannot use Parameter subclasses.
+            layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                  requires_grad=False)
+            layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale,
+                                                        requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                 requires_grad=False)
+            layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
+                                                       requires_grad=False)
+            if self.rocm_aiter_moe_enabled:
+                # reshaping weights is required for aiter moe kernel.
+                shuffled_w13, shuffled_w2 = shuffle_weights(
+                    layer.w13_weight.data, layer.w2_weight.data)
+
+                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
+                                                      requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
+                                                     requires_grad=False)
+
+            # DeepGemm scales need to be transposed and aligned. We try to do
+            # it ahead of time for performance reasons.
+            if self.allow_deep_gemm:
+                # Lazy import to avoid CUDA initialization problems.
+                import deep_gemm as dg
+                if _is_col_major(layer.w13_weight_scale):
+                    layer.w13_weight_scale = \
+                        dg.get_col_major_tma_aligned_tensor(
+                            layer.w13_weight_scale).contiguous()
+                if _is_col_major(layer.w2_weight_scale):
+                    layer.w2_weight_scale = \
+                        dg.get_col_major_tma_aligned_tensor(
+                            layer.w2_weight_scale).contiguous()
+
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
@@ -556,9 +698,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 
         # Property to determine if AITER is used
         if self.rocm_aiter_moe_enabled:
-            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa E501
-                rocm_aiter_fused_experts, shuffle_weights)
-
             # reshaping weights is required for aiter moe kernel.
             shuffled_w13, shuffled_w2 = shuffle_weights(
                 layer.w13_weight.data, layer.w2_weight.data)
@@ -571,7 +710,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
         else:
             from vllm.model_executor.layers.fused_moe import fused_experts
-            self.fused_experts_func = fused_experts
+            self.fused_experts_func = functools.partial(  # type: ignore
+                fused_experts,
+                use_fp8_w8a8=True,
+                block_shape=self.weight_block_size,
+                allow_deep_gemm=self.allow_deep_gemm)
 
         if self.use_marlin:
             prepare_moe_fp8_layer_for_marlin(layer, False)
@@ -633,7 +776,8 @@ def apply(
                 w1_scale=layer.w13_weight_scale,
                 w2_scale=layer.w2_weight_scale,
                 a1_scale=layer.w13_input_scale,
-                a2_scale=layer.w2_input_scale)
+                a2_scale=layer.w2_input_scale,
+                block_shape=self.weight_block_size)
         if self.use_marlin:
             assert activation == "silu", (
                 f"{activation} not supported for Marlin MoE.")
@@ -661,7 +805,6 @@ def apply(
             inplace=True,
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
-            use_fp8_w8a8=True,
             per_channel_quant=self.weight_quant.strategy ==
             QuantizationStrategy.CHANNEL,
             global_num_experts=global_num_experts,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index 1e61e058cb8..f5ebcde0adf 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -4,15 +4,20 @@
 from typing import Callable, Optional
 
 import torch
-from compressed_tensors.quantization import QuantizationStrategy
+import torch.nn.functional as F
+from compressed_tensors.quantization import (QuantizationArgs,
+                                             QuantizationStrategy)
 from torch.nn import Parameter
 
+import vllm.envs as envs
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
-    requantize_with_max_scale)
-from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
+    Fp8LinearOp, cutlass_block_fp8_supported, maybe_create_device_identity,
+    normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
+from vllm.model_executor.parameter import (BlockQuantScaleParameter,
+                                           ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
 from vllm.platforms import current_platform
 
@@ -22,17 +27,133 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
 
-    def __init__(self, strategy: str, is_static_input_scheme: bool):
-        self.strategy = strategy
+    def __init__(self, weight_quant: QuantizationArgs,
+                 is_static_input_scheme: bool):
+        self.weight_quant = weight_quant
+        self.strategy = weight_quant.strategy
         self.out_dtype = torch.get_default_dtype()
         self.is_static_input_scheme = is_static_input_scheme
         self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
 
+        if self.weight_quant.block_structure is None:
+            block_size = None
+        elif isinstance(self.weight_quant.block_structure, str):
+            block_size = [
+                int(x) for x in self.weight_quant.block_structure.split("x")
+            ]
+        else:
+            block_size = self.weight_quant.block_structure
+        self.weight_block_size = block_size
+        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
+        # AITER is only supported on ROCm and only for FP8_FNUZ,
+        # which at the moment means the MI300 series.
+        self.use_aiter_and_is_supported = (current_platform.is_rocm()
+                                           and envs.VLLM_ROCM_USE_AITER
+                                           and envs.VLLM_ROCM_USE_AITER_LINEAR
+                                           and current_platform.is_fp8_fnuz())
+
     @classmethod
     def get_min_capability(cls) -> int:
         # lovelace and up
         return 89
 
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
+                       output_partition_sizes: list[int], input_size: int,
+                       output_size: int, params_dtype: torch.dtype,
+                       weight_loader: Callable, **kwargs):
+        maybe_create_device_identity()
+
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+        layer.weight_block_size = None
+
+        if self.strategy == QuantizationStrategy.BLOCK:
+            tp_size = get_tensor_model_parallel_world_size()
+            assert self.weight_block_size is not None
+            layer.weight_block_size = self.weight_block_size
+            block_n, block_k = (
+                layer.weight_block_size[0],
+                layer.weight_block_size[1],
+            )
+            # Required by row parallel
+            if (tp_size > 1
+                    and input_size // input_size_per_partition == tp_size
+                    and input_size_per_partition % block_k != 0):
+                raise ValueError(
+                    f"Weight input_size_per_partition = "
+                    f"{input_size_per_partition} is not divisible by "
+                    f"weight quantization block_k = {block_k}.")
+            # Required by column parallel or enabling merged weights
+            if (tp_size > 1 and output_size // output_size_per_partition
+                    == tp_size) or len(output_partition_sizes) > 1:
+                for output_partition_size in output_partition_sizes:
+                    if output_partition_size % block_n != 0:
+                        raise ValueError(
+                            f"Weight output_partition_size = "
+                            f"{output_partition_size} is not divisible by "
+                            f"weight quantization block_n = {block_n}.")
+
+        # WEIGHT
+        weight = ModelWeightParameter(data=torch.empty(
+            output_size_per_partition,
+            input_size_per_partition,
+            dtype=torch.float8_e4m3fn),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
+        layer.register_parameter("weight", weight)
+
+        # WEIGHT SCALE
+        # TODO: update create_xxx_parameter functions to return
+        # the newly added parameters
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            weight_scale = ChannelQuantScaleParameter(
+                data=torch.empty((sum(output_partition_sizes), 1),
+                                 dtype=torch.float32),
+                output_dim=0,
+                weight_loader=weight_loader)
+        elif self.strategy == QuantizationStrategy.BLOCK:
+            assert self.is_static_input_scheme is False
+            weight_scale = BlockQuantScaleParameter(
+                data=torch.empty(
+                    (output_size_per_partition + block_n - 1) // block_n,
+                    (input_size_per_partition + block_k - 1) // block_k,
+                    dtype=torch.float32,
+                ),
+                input_dim=1,
+                output_dim=0,
+                weight_loader=weight_loader,
+            )
+        else:
+            assert self.strategy == QuantizationStrategy.TENSOR
+            weight_scale = PerTensorScaleParameter(data=torch.empty(
+                len(output_partition_sizes), dtype=torch.float32),
+                                                   weight_loader=weight_loader)
+
+        # min requirement for fp8 kernels
+        weight_scale[:] = torch.finfo(torch.float32).min
+        layer.register_parameter("weight_scale", weight_scale)
+
+        # INPUT SCALE
+        if self.is_static_input_scheme:
+            input_scale = PerTensorScaleParameter(data=torch.empty(
+                len(output_partition_sizes), dtype=torch.float32),
+                                                  weight_loader=weight_loader)
+            input_scale[:] = torch.finfo(torch.float32).min
+            layer.register_parameter("input_scale", input_scale)
+
+    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
+        # Pad the weight tensor. This is an optimization on ROCm platform, which
+        # can benefit from tensors located far enough from one another in memory
+        if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
+                and weight.stride(-1) == 1
+                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
+            num_pad = 256 // weight.element_size()
+            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
+            torch.cuda.empty_cache()
+        return weight
+
     def process_weights_after_loading(self, layer) -> None:
         # If per tensor, when we have a fused module (e.g. QKV) with per
         # tensor scales (thus N scales being passed to the kernel),
@@ -76,8 +197,26 @@ def process_weights_after_loading(self, layer) -> None:
             else:
                 weight_scale = layer.weight_scale.data
 
+            # required by torch.compile to be torch.nn.Parameter
             layer.weight = Parameter(weight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+
+        elif self.strategy == QuantizationStrategy.BLOCK:
+            assert self.is_static_input_scheme is False
+            weight = layer.weight.data
+
+            if current_platform.is_fp8_fnuz():
+                weight, weight_scale, _ = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        weight=weight,
+                        weight_scale=layer.weight_scale)
+            else:
+                weight_scale = layer.weight_scale.data
+
+            weight = self._maybe_pad_weight(weight)
+            # required by torch.compile to be torch.nn.Parameter
+            layer.weight = Parameter(weight, requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
 
         else:
@@ -90,58 +229,23 @@ def process_weights_after_loading(self, layer) -> None:
         else:
             layer.input_scale = None
 
-    def create_weights(self, layer: torch.nn.Module,
-                       output_partition_sizes: list[int],
-                       input_size_per_partition: int,
-                       params_dtype: torch.dtype, weight_loader: Callable,
-                       **kwargs):
-        maybe_create_device_identity()
-
-        output_size_per_partition = sum(output_partition_sizes)
-        layer.logical_widths = output_partition_sizes
-
-        # WEIGHT
-        weight = ModelWeightParameter(data=torch.empty(
-            output_size_per_partition,
-            input_size_per_partition,
-            dtype=torch.float8_e4m3fn),
-                                      input_dim=1,
-                                      output_dim=0,
-                                      weight_loader=weight_loader)
-        layer.register_parameter("weight", weight)
-
-        # WEIGHT SCALE
-        # TODO: update create_xxx_parameter functions to return
-        # the newly added parameters
-        if self.strategy == QuantizationStrategy.CHANNEL:
-            weight_scale = ChannelQuantScaleParameter(
-                data=torch.empty((sum(output_partition_sizes), 1),
-                                 dtype=torch.float32),
-                output_dim=0,
-                weight_loader=weight_loader)
-        else:
-            assert self.strategy == QuantizationStrategy.TENSOR
-            weight_scale = PerTensorScaleParameter(data=torch.empty(
-                len(output_partition_sizes), dtype=torch.float32),
-                                                   weight_loader=weight_loader)
-
-        # min requirement for fp8 kernels
-        weight_scale[:] = torch.finfo(torch.float32).min
-        layer.register_parameter("weight_scale", weight_scale)
-
-        # INPUT SCALE
-        if self.is_static_input_scheme:
-            input_scale = PerTensorScaleParameter(data=torch.empty(
-                len(output_partition_sizes), dtype=torch.float32),
-                                                  weight_loader=weight_loader)
-            input_scale[:] = torch.finfo(torch.float32).min
-            layer.register_parameter("input_scale", input_scale)
-
     def apply_weights(self,
                       layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
+        if layer.weight_block_size is not None:
+            return torch.ops.vllm.apply_w8a8_block_fp8_linear(
+                input=x,
+                weight=layer.weight,
+                block_size=layer.weight_block_size,
+                weight_scale=layer.weight_scale,
+                input_scale=layer.input_scale,
+                bias=bias,
+                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
+                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
+            )
+
         return self.fp8_linear.apply(input=x,
                                      weight=layer.weight,
                                      weight_scale=layer.weight_scale,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 93472207fbb..0d6394b02d1 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -188,7 +188,8 @@ def __init__(self, quant_config: Fp8Config):
                                            and envs.VLLM_ROCM_USE_AITER_LINEAR
                                            and current_platform.is_fp8_fnuz())
 
-        self.block_quant = self.quant_config.weight_block_size is not None
+        self.weight_block_size = self.quant_config.weight_block_size
+        self.block_quant = self.weight_block_size is not None
         self.fp8_linear = Fp8LinearOp(
             # Default to using per_token quantization if cutlass is supported
             use_per_token_if_dynamic=cutlass_fp8_supported())
@@ -215,11 +216,11 @@ def create_weights(
 
         if self.block_quant:
             tp_size = get_tensor_model_parallel_world_size()
-            assert self.quant_config.weight_block_size is not None
-            layer.weight_block_size = self.quant_config.weight_block_size
+            assert self.weight_block_size is not None
+            layer.weight_block_size = self.weight_block_size
             block_n, block_k = (
-                self.quant_config.weight_block_size[0],
-                self.quant_config.weight_block_size[1],
+                self.weight_block_size[0],
+                self.weight_block_size[1],
             )
             # Required by row parallel
             if (tp_size > 1
@@ -399,12 +400,12 @@ def apply(
                                          bias=bias)
 
         if self.block_quant:
-            assert self.quant_config.weight_block_size is not None
+            assert layer.weight_block_size is not None
 
             return torch.ops.vllm.apply_w8a8_block_fp8_linear(
                 input=x,
                 weight=layer.weight,
-                block_size=self.quant_config.weight_block_size,
+                block_size=layer.weight_block_size,
                 weight_scale=layer.weight_scale_inv,
                 input_scale=layer.input_scale,
                 bias=bias,
@@ -436,7 +437,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     def __init__(self, quant_config: Fp8Config):
         from vllm.model_executor.layers.fused_moe import fused_experts
         self.quant_config = quant_config
-        self.block_quant = self.quant_config.weight_block_size is not None
+        self.weight_block_size = self.quant_config.weight_block_size
+        self.block_quant = self.weight_block_size is not None
 
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
@@ -466,7 +468,7 @@ def __init__(self, quant_config: Fp8Config):
             self.fused_experts = functools.partial(  # type: ignore
                 fused_experts,
                 use_fp8_w8a8=True,
-                block_shape=self.quant_config.weight_block_size,
+                block_shape=self.weight_block_size,
                 allow_deep_gemm=self.allow_deep_gemm)
 
     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
@@ -482,12 +484,12 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
         if self.block_quant:
-            assert self.quant_config.weight_block_size is not None
-            layer.weight_block_size = self.quant_config.weight_block_size
+            assert self.weight_block_size is not None
+            layer.weight_block_size = self.weight_block_size
             tp_size = get_tensor_model_parallel_world_size()
             block_n, block_k = (
-                self.quant_config.weight_block_size[0],
-                self.quant_config.weight_block_size[1],
+                self.weight_block_size[0],
+                self.weight_block_size[1],
             )
             # NOTE: To ensure proper alignment of the block-wise quantization
             # scales, the output_size of the weights for both the gate and up
@@ -793,13 +795,13 @@ def select_gemm_impl(self, prepare_finalize, moe):
                 use_int8_w8a16=False,
                 use_int4_w4a16=False,
                 per_channel_quant=False,
-                block_shape=self.quant_config.weight_block_size,
+                block_shape=self.weight_block_size,
                 allow_deep_gemm=self.allow_deep_gemm,
             )
         else:
             experts = TritonOrDeepGemmExperts(
                 use_fp8_w8a8=True,
-                block_shape=self.quant_config.weight_block_size,
+                block_shape=self.weight_block_size,
                 allow_deep_gemm=self.allow_deep_gemm,
             )
 
@@ -871,7 +873,7 @@ def apply(
                 if self.block_quant else layer.w2_weight_scale),
                 a1_scale=layer.w13_input_scale,
                 a2_scale=layer.w2_input_scale,
-                block_shape=self.quant_config.weight_block_size)
+                block_shape=self.weight_block_size)
         elif self.use_marlin:
             assert activation == "silu", (
                 f"{activation} not supported for Marlin MoE.")
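
Note (illustrative sketch, not part of the patch): the block-quantization paths added above all rely on the same shape rules. The compressed-tensors block_structure (either a string like "128x128" or a list such as [128, 128]) is normalized into weight_block_size = [block_n, block_k], and each weight gets one float32 scale per (block_n x block_k) tile, sized with ceiling division. The standalone Python sketch below mirrors that arithmetic under those assumptions; the helper names are illustrative, not vLLM APIs.

# Illustrative sketch (not vLLM code): shape math used by the block-quant
# paths in this patch.
from typing import Optional, Union


def parse_block_structure(
        block_structure: Optional[Union[str, list[int]]]
) -> Optional[list[int]]:
    # "128x128" -> [128, 128]; a list passes through; None means no block quant.
    if block_structure is None:
        return None
    if isinstance(block_structure, str):
        return [int(x) for x in block_structure.split("x")]
    return block_structure


def block_scale_shape(out_features: int, in_features: int,
                      weight_block_size: list[int]) -> tuple[int, int]:
    # One float32 scale per (block_n x block_k) tile, rounded up at the edges.
    block_n, block_k = weight_block_size
    return ((out_features + block_n - 1) // block_n,
            (in_features + block_k - 1) // block_k)


assert parse_block_structure("128x128") == [128, 128]
assert block_scale_shape(4096, 14336, [128, 128]) == (32, 112)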