diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index ba1498e6531..ede3a5070f3 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -4,6 +4,9 @@
 import functools
 import json
 import os
+import warnings
+from dataclasses import dataclass, field
+from enum import Enum
 from typing import Any, Callable, Optional
 
 import torch
@@ -29,6 +32,242 @@
 logger = init_logger(__name__)
 
 
+class QuantizationType(Enum):
+    """Supported quantization types for MoE layers."""
+    NONE = "none"
+    FP8_W8A8 = "fp8_w8a8"
+    INT8_W8A8 = "int8_w8a8"
+    INT8_W8A16 = "int8_w8a16"
+    INT4_W4A16 = "int4_w4a16"
+
+
+@dataclass
+class FusedMoeQuantConfig:
+    """Configuration for FusedMoE quantization settings.
+
+    This class encapsulates quantization-related parameters for MoE layers,
+    providing a clean interface and preventing conflicting configurations.
+
+    Args:
+        quantization_type: The type of quantization to use
+        activation_dtype: Data type for activations (auto-inferred if None)
+        per_channel_quant: Whether to use per-channel quantization
+        block_shape: Block dimensions for block-wise quantization
+
+    Examples:
+        >>> # Create FP8 quantization config
+        >>> config = FusedMoeQuantConfig.create_fp8_w8a8()
+        >>>
+        >>> # Create INT8 weight, FP16 activation config
+        >>> config = FusedMoeQuantConfig.create_int8_w8a16()
+    """
+
+    quantization_type: QuantizationType = QuantizationType.NONE
+    activation_dtype: Optional[torch.dtype] = None
+    per_channel_quant: bool = False
+    block_shape: Optional[list[int]] = None
+
+    # Cached properties for performance (private fields)
+    _use_fp8_w8a8: Optional[bool] = field(default=None, init=False, repr=False)
+    _use_int8_w8a8: Optional[bool] = field(default=None,
+                                           init=False,
+                                           repr=False)
+    _use_int8_w8a16: Optional[bool] = field(default=None,
+                                            init=False,
+                                            repr=False)
+    _use_int4_w4a16: Optional[bool] = field(default=None,
+                                            init=False,
+                                            repr=False)
+
+    def __post_init__(self):
+        """Validate configuration and cache properties after initialization."""
+        self._validate_config()
+        self._cache_properties()
+
+    def _validate_config(self):
+        """Validate the quantization configuration."""
+        # Validate activation dtype for each quantization type
+        valid_activation_dtypes = {
+            QuantizationType.NONE: [None],
+            QuantizationType.FP8_W8A8: [torch.float8_e4m3fn, None],
+            QuantizationType.INT8_W8A8: [torch.int8, None],
+            QuantizationType.INT8_W8A16: [torch.float16, torch.bfloat16, None],
+            QuantizationType.INT4_W4A16: [torch.float16, torch.bfloat16, None],
+        }
+
+        expected_dtypes = valid_activation_dtypes.get(self.quantization_type,
+                                                      [])
+        if self.activation_dtype not in expected_dtypes:
+            raise ValueError(
+                f"Invalid activation_dtype {self.activation_dtype} for "
+                f"{self.quantization_type}. Expected one of: {expected_dtypes}"
+            )
+
+        # Auto-infer activation dtype if not specified
+        if self.activation_dtype is None and self.quantization_type != QuantizationType.NONE:
+            default_activation_dtypes = {
+                QuantizationType.FP8_W8A8: torch.float8_e4m3fn,
+                QuantizationType.INT8_W8A8: torch.int8,
+                QuantizationType.INT8_W8A16: torch.float16,
+                QuantizationType.INT4_W4A16: torch.float16,
+            }
+            self.activation_dtype = default_activation_dtypes.get(
+                self.quantization_type)
+
+        # Validate block_shape
+        if self.block_shape is not None:
+            if not isinstance(self.block_shape, list) or len(
+                    self.block_shape) != 2:
+                raise ValueError("block_shape must be a list of two integers")
+            if not all(isinstance(x, int) and x > 0 for x in self.block_shape):
+                raise ValueError(
+                    "block_shape values must be positive integers")
+
+    def _cache_properties(self):
+        """Cache boolean properties for performance in hot paths."""
+        self._use_fp8_w8a8 = self.quantization_type == QuantizationType.FP8_W8A8
+        self._use_int8_w8a8 = self.quantization_type == QuantizationType.INT8_W8A8
+        self._use_int8_w8a16 = self.quantization_type == QuantizationType.INT8_W8A16
+        self._use_int4_w4a16 = self.quantization_type == QuantizationType.INT4_W4A16
+
+    @property
+    def weight_dtype(self) -> Optional[torch.dtype]:
+        """Get the weight data type for this quantization configuration."""
+        weight_dtype_mapping = {
+            QuantizationType.NONE: None,
+            QuantizationType.FP8_W8A8: torch.float8_e4m3fn,
+            QuantizationType.INT8_W8A8: torch.int8,
+            QuantizationType.INT8_W8A16: torch.int8,
+            QuantizationType.INT4_W4A16:
+            None,  # INT4 handled specially in kernels
+        }
+        return weight_dtype_mapping.get(self.quantization_type)
+
+    @property
+    def use_fp8_w8a8(self) -> bool:
+        """Backward compatibility: FP8 weight and activation."""
+        return self._use_fp8_w8a8
+
+    @property
+    def use_int8_w8a8(self) -> bool:
+        """Backward compatibility: INT8 weight and activation."""
+        return self._use_int8_w8a8
+
+    @property
+    def use_int8_w8a16(self) -> bool:
+        """Backward compatibility: INT8 weight, FP16/BF16 activation."""
+        return self._use_int8_w8a16
+
+    @property
+    def use_int4_w4a16(self) -> bool:
+        """Backward compatibility: INT4 weight, FP16/BF16 activation."""
+        return self._use_int4_w4a16
+
+    @property
+    def is_quantized(self) -> bool:
+        """Check if any quantization is enabled."""
+        return self.quantization_type != QuantizationType.NONE
+
+    @classmethod
+    def create_fp8_w8a8(
+            cls,
+            per_channel_quant: bool = False,
+            block_shape: Optional[list[int]] = None) -> 'FusedMoeQuantConfig':
+        """Factory method for FP8 weight and activation quantization."""
+        return cls(quantization_type=QuantizationType.FP8_W8A8,
+                   per_channel_quant=per_channel_quant,
+                   block_shape=block_shape)
+
+    @classmethod
+    def create_int8_w8a8(
+            cls,
+            per_channel_quant: bool = False,
+            block_shape: Optional[list[int]] = None) -> 'FusedMoeQuantConfig':
+        """Factory method for INT8 weight and activation quantization."""
+        return cls(quantization_type=QuantizationType.INT8_W8A8,
+                   per_channel_quant=per_channel_quant,
+                   block_shape=block_shape)
+
+    @classmethod
+    def create_int8_w8a16(
+            cls,
+            activation_dtype: torch.dtype = torch.float16,
+            per_channel_quant: bool = False,
+            block_shape: Optional[list[int]] = None) -> 'FusedMoeQuantConfig':
+        """Factory method for INT8 weight, FP16/BF16 activation quantization."""
+        if activation_dtype not in [torch.float16, torch.bfloat16]:
+            raise ValueError(
+                "activation_dtype must be torch.float16 or torch.bfloat16")
+        return cls(quantization_type=QuantizationType.INT8_W8A16,
+                   activation_dtype=activation_dtype,
+                   per_channel_quant=per_channel_quant,
+                   block_shape=block_shape)
+
+    @classmethod
+    def create_int4_w4a16(
+            cls,
+            activation_dtype: torch.dtype = torch.float16,
+            per_channel_quant: bool = False,
+            block_shape: Optional[list[int]] = None) -> 'FusedMoeQuantConfig':
+        """Factory method for INT4 weight, FP16/BF16 activation quantization."""
+        if activation_dtype not in [torch.float16, torch.bfloat16]:
+            raise ValueError(
+                "activation_dtype must be torch.float16 or torch.bfloat16")
+        return cls(quantization_type=QuantizationType.INT4_W4A16,
+                   activation_dtype=activation_dtype,
+                   per_channel_quant=per_channel_quant,
+                   block_shape=block_shape)
+
+    @classmethod
+    def create_no_quant(cls) -> 'FusedMoeQuantConfig':
+        """Factory method for no quantization (default floating point)."""
+        return cls(quantization_type=QuantizationType.NONE)
+
+    @classmethod
+    def from_legacy_flags(
+            cls,
+            use_fp8_w8a8: bool = False,
+            use_int8_w8a8: bool = False,
+            use_int8_w8a16: bool = False,
+            use_int4_w4a16: bool = False,
+            per_channel_quant: bool = False,
+            block_shape: Optional[list[int]] = None) -> 'FusedMoeQuantConfig':
+        """Create config from legacy boolean flags for backward compatibility.
+
+        Warning:
+            This method is deprecated and will be removed in a future version.
+            Use factory methods like create_fp8_w8a8() instead.
+        """
+        # Issue deprecation warning for legacy usage
+        if any([use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16]):
+            warnings.warn(
+                "Using legacy quantization flags (use_fp8_w8a8, use_int8_w8a8, etc.) "
+                "is deprecated. Please use FusedMoeQuantConfig factory methods instead "
+                "(e.g., FusedMoeQuantConfig.create_fp8_w8a8()). "
+                "Legacy support will be removed in vLLM v0.7.0.",
+                FutureWarning,
+                stacklevel=3)
+
+        # Validate that only one quantization type is enabled
+        flags = [use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16]
+        if sum(flags) > 1:
+            raise ValueError(
+                "Only one quantization type can be enabled at a time")
+
+        if use_fp8_w8a8:
+            return cls.create_fp8_w8a8(per_channel_quant, block_shape)
+        elif use_int8_w8a8:
+            return cls.create_int8_w8a8(per_channel_quant, block_shape)
+        elif use_int8_w8a16:
+            return cls.create_int8_w8a16(torch.float16, per_channel_quant,
+                                         block_shape)
+        elif use_int4_w4a16:
+            return cls.create_int4_w4a16(torch.float16, per_channel_quant,
+                                         block_shape)
+        else:
+            return cls.create_no_quant()
+
+
 @triton.jit
 def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
                           token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N,
@@ -462,40 +701,56 @@ def fused_moe_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
-def invoke_fused_moe_kernel(A: torch.Tensor,
-                            B: torch.Tensor,
-                            C: torch.Tensor,
-                            A_scale: Optional[torch.Tensor],
-                            B_scale: Optional[torch.Tensor],
-                            B_zp: Optional[torch.Tensor],
-                            topk_weights: Optional[torch.Tensor],
-                            sorted_token_ids: torch.Tensor,
-                            expert_ids: torch.Tensor,
-                            num_tokens_post_padded: torch.Tensor,
-                            mul_routed_weight: bool,
-                            top_k: int,
-                            config: dict[str, Any],
-                            compute_type: tl.dtype,
-                            use_fp8_w8a8: bool,
-                            use_int8_w8a8: bool,
-                            use_int8_w8a16: bool,
-                            use_int4_w4a16: bool,
-                            per_channel_quant: bool,
-                            block_shape: Optional[list[int]] = None) -> None:
+def invoke_fused_moe_kernel(
+        A: torch.Tensor,
+        B: torch.Tensor,
+        C: torch.Tensor,
+        A_scale: Optional[torch.Tensor],
+        B_scale: Optional[torch.Tensor],
+        B_zp: Optional[torch.Tensor],
+        topk_weights: Optional[torch.Tensor],
+        sorted_token_ids: torch.Tensor,
+        expert_ids: torch.Tensor,
+        num_tokens_post_padded: torch.Tensor,
+        mul_routed_weight: bool,
+        top_k: int,
+        config: dict[str, Any],
+        compute_type: tl.dtype,
+        fused_moe_quant_config: Optional[FusedMoeQuantConfig] = None,
+        # Deprecated: keep for backward compatibility
+        use_fp8_w8a8: Optional[bool] = None,
+        use_int8_w8a8: Optional[bool] = None,
+        use_int8_w8a16: Optional[bool] = None,
+        use_int4_w4a16: Optional[bool] = None,
+        per_channel_quant: Optional[bool] = None,
+        block_shape: Optional[list[int]] = None) -> None:
     assert topk_weights is not None or not mul_routed_weight
     assert topk_weights is None or topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
-    if use_fp8_w8a8 or use_int8_w8a8:
+    # Handle backward compatibility: create config from legacy flags if needed
+    if fused_moe_quant_config is None:
+        fused_moe_quant_config = FusedMoeQuantConfig.from_legacy_flags(
+            use_fp8_w8a8=use_fp8_w8a8 or False,
+            use_int8_w8a8=use_int8_w8a8 or False,
+            use_int8_w8a16=use_int8_w8a16 or False,
+            use_int4_w4a16=use_int4_w4a16 or False,
+            per_channel_quant=per_channel_quant or False,
+            block_shape=block_shape)
+
+    if fused_moe_quant_config.use_fp8_w8a8 or fused_moe_quant_config.use_int8_w8a8:
         assert B_scale is not None
-        assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
+        assert (fused_moe_quant_config.block_shape is None or triton.cdiv(
+            B.shape[-2], fused_moe_quant_config.block_shape[0])
                 == B_scale.shape[-2])
-        assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
+        assert (fused_moe_quant_config.block_shape is None or triton.cdiv(
+            B.shape[-1], fused_moe_quant_config.block_shape[1])
                 == B_scale.shape[-1])
-    elif use_int8_w8a16 or use_int4_w4a16:
+    elif fused_moe_quant_config.use_int8_w8a16 or fused_moe_quant_config.use_int4_w4a16:
         assert B_scale is not None
-        assert block_shape is None or block_shape[0] == 0
+        assert fused_moe_quant_config.block_shape is None or fused_moe_quant_config.block_shape[
+            0] == 0
     else:
         assert A_scale is None
        assert B_scale is None
@@ -514,30 +769,31 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
     grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
         B.shape[1], META['BLOCK_SIZE_N']), )
 
-    if (use_int8_w8a16 or use_int4_w4a16) and \
-            block_shape is not None and block_shape[1] > 0:
+    if (fused_moe_quant_config.use_int8_w8a16 or fused_moe_quant_config.use_int4_w4a16) and \
+            fused_moe_quant_config.block_shape is not None and fused_moe_quant_config.block_shape[1] > 0:
         assert B_scale is not None and B_scale.ndim == 3
         assert B_zp is None or B_zp.ndim == 3
 
         use_moe_wna16_cuda = should_moe_wna16_use_cuda(
            num_valid_tokens=num_tokens,
-            group_size=block_shape[1],
+            group_size=fused_moe_quant_config.block_shape[1],
             num_experts=B.shape[0],
-            bit=4 if use_int4_w4a16 else 8)
+            bit=4 if fused_moe_quant_config.use_int4_w4a16 else 8)
         config = config.copy()
         config.update(
-            get_moe_wna16_block_config(config=config,
-                                       use_moe_wna16_cuda=use_moe_wna16_cuda,
-                                       num_valid_tokens=num_tokens,
-                                       size_k=A.shape[1],
-                                       size_n=B.shape[1],
-                                       num_experts=B.shape[1],
-                                       group_size=block_shape[1],
-                                       real_top_k=top_k,
-                                       block_size_m=config["BLOCK_SIZE_M"]))
+            get_moe_wna16_block_config(
+                config=config,
+                use_moe_wna16_cuda=use_moe_wna16_cuda,
+                num_valid_tokens=num_tokens,
+                size_k=A.shape[1],
+                size_n=B.shape[1],
+                num_experts=B.shape[1],
+                group_size=fused_moe_quant_config.block_shape[1],
+                real_top_k=top_k,
+                block_size_m=config["BLOCK_SIZE_M"]))
         if use_moe_wna16_cuda:
-            bit = 4 if use_int4_w4a16 else 8
+            bit = 4 if fused_moe_quant_config.use_int4_w4a16 else 8
             ops.moe_wna16_gemm(A, C, B, B_scale, B_zp,
                                topk_weights if mul_routed_weight else None,
                                sorted_token_ids, expert_ids,
@@ -574,21 +830,23 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
             B_zp.stride(2) if B_zp is not None else 0,
             B_zp.stride(1) if B_zp is not None else 0,
             block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
-            group_size=block_shape[1],
+            group_size=fused_moe_quant_config.block_shape[1],
             MUL_ROUTED_WEIGHT=mul_routed_weight,
             top_k=top_k,
             compute_type=compute_type,
             has_zp=B_zp is not None,
-            use_int4_w4a16=use_int4_w4a16,
-            use_int8_w8a16=use_int8_w8a16,
+            use_int4_w4a16=fused_moe_quant_config.use_int4_w4a16,
+            use_int8_w8a16=fused_moe_quant_config.use_int8_w8a16,
             **config,
         )
     else:
         config = config.copy()
         BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
-        if block_shape is not None:
-            BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0],
-                                                 block_shape[1]))
+        if fused_moe_quant_config.block_shape is not None:
+            BLOCK_SIZE_K = min(
+                BLOCK_SIZE_K,
+                min(fused_moe_quant_config.block_shape[0],
+                    fused_moe_quant_config.block_shape[1]))
         fused_moe_kernel[grid](
             A,
             B,
@@ -620,15 +878,17 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
             if B_scale is not None and B_scale.ndim == 3 else 0,
             B_scale.stride(1)
             if B_scale is not None and B_scale.ndim >= 2 else 0,
-            0 if block_shape is None else block_shape[0],
-            0 if block_shape is None else block_shape[1],
+            0 if fused_moe_quant_config.block_shape is None else
+            fused_moe_quant_config.block_shape[0],
+            0 if fused_moe_quant_config.block_shape is None else
+            fused_moe_quant_config.block_shape[1],
             MUL_ROUTED_WEIGHT=mul_routed_weight,
             top_k=top_k,
             compute_type=compute_type,
-            use_fp8_w8a8=use_fp8_w8a8,
-            use_int8_w8a8=use_int8_w8a8,
-            use_int8_w8a16=use_int8_w8a16,
-            per_channel_quant=per_channel_quant,
+            use_fp8_w8a8=fused_moe_quant_config.use_fp8_w8a8,
+            use_int8_w8a8=fused_moe_quant_config.use_int8_w8a8,
+            use_int8_w8a16=fused_moe_quant_config.use_int8_w8a16,
+            per_channel_quant=fused_moe_quant_config.per_channel_quant,
             BLOCK_SIZE_K=BLOCK_SIZE_K,
             **config,
         )
@@ -1137,33 +1397,46 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
     return torch_vllm_outplace_fused_experts
 
 
-def fused_experts(hidden_states: torch.Tensor,
-                  w1: torch.Tensor,
-                  w2: torch.Tensor,
-                  topk_weights: torch.Tensor,
-                  topk_ids: torch.Tensor,
-                  inplace: bool = False,
-                  activation: str = "silu",
-                  apply_router_weight_on_input: bool = False,
-                  use_fp8_w8a8: bool = False,
-                  use_int8_w8a8: bool = False,
-                  use_int8_w8a16: bool = False,
-                  use_int4_w4a16: bool = False,
-                  per_channel_quant: bool = False,
-                  global_num_experts: int = -1,
-                  expert_map: Optional[torch.Tensor] = None,
-                  w1_scale: Optional[torch.Tensor] = None,
-                  w2_scale: Optional[torch.Tensor] = None,
-                  w1_zp: Optional[torch.Tensor] = None,
-                  w2_zp: Optional[torch.Tensor] = None,
-                  a1_scale: Optional[torch.Tensor] = None,
-                  a2_scale: Optional[torch.Tensor] = None,
-                  block_shape: Optional[list[int]] = None,
-                  allow_deep_gemm: bool = False) -> torch.Tensor:
+def fused_experts(
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        inplace: bool = False,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        fused_moe_quant_config: Optional[FusedMoeQuantConfig] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        w1_scale: Optional[torch.Tensor] = None,
+        w2_scale: Optional[torch.Tensor] = None,
+        w1_zp: Optional[torch.Tensor] = None,
+        w2_zp: Optional[torch.Tensor] = None,
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None,
+        allow_deep_gemm: bool = False,
+        # Deprecated: keep for backward compatibility
+        use_fp8_w8a8: Optional[bool] = None,
+        use_int8_w8a8: Optional[bool] = None,
+        use_int8_w8a16: Optional[bool] = None,
+        use_int4_w4a16: Optional[bool] = None,
+        per_channel_quant: Optional[bool] = None,
+        block_shape: Optional[list[int]] = None) -> torch.Tensor:
+    # Handle backward compatibility
+    if fused_moe_quant_config is None:
+        fused_moe_quant_config = FusedMoeQuantConfig.from_legacy_flags(
+            use_fp8_w8a8=use_fp8_w8a8 or False,
+            use_int8_w8a8=use_int8_w8a8 or False,
+            use_int8_w8a16=use_int8_w8a16 or False,
+            use_int4_w4a16=use_int4_w4a16 or False,
+            per_channel_quant=per_channel_quant or False,
+            block_shape=block_shape)
+
     # For now, disable DeepGemm for small N (<= 512) until better
     # permute/unpermute ops are available.
     N = w1.shape[1]
-    if (allow_deep_gemm and use_fp8_w8a8 and N > 512
+    if (allow_deep_gemm and fused_moe_quant_config.use_fp8_w8a8 and N > 512
             and _valid_deep_gemm(hidden_states, w1, w2)):
         assert apply_router_weight_on_input is False
         return deep_gemm_moe_fp8(
@@ -1191,11 +1464,11 @@ def fused_experts(hidden_states: torch.Tensor,
             topk_ids=topk_ids,
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
-            use_fp8_w8a8=use_fp8_w8a8,
-            use_int8_w8a8=use_int8_w8a8,
-            use_int8_w8a16=use_int8_w8a16,
-            use_int4_w4a16=use_int4_w4a16,
-            per_channel_quant=per_channel_quant,
+            use_fp8_w8a8=fused_moe_quant_config.use_fp8_w8a8,
+            use_int8_w8a8=fused_moe_quant_config.use_int8_w8a8,
+            use_int8_w8a16=fused_moe_quant_config.use_int8_w8a16,
+            use_int4_w4a16=fused_moe_quant_config.use_int4_w4a16,
+            per_channel_quant=fused_moe_quant_config.per_channel_quant,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
             w1_scale=w1_scale,
@@ -1204,7 +1477,7 @@ def fused_experts(hidden_states: torch.Tensor,
             w2_zp=w2_zp,
             a1_scale=a1_scale,
             a2_scale=a2_scale,
-            block_shape=block_shape)
+            block_shape=fused_moe_quant_config.block_shape)
 
 
 def fused_experts_impl(
@@ -1216,11 +1489,7 @@ def fused_experts_impl(
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
-    use_fp8_w8a8: bool = False,
-    use_int8_w8a8: bool = False,
-    use_int8_w8a16: bool = False,
-    use_int4_w4a16: bool = False,
-    per_channel_quant: bool = False,
+    fused_moe_quant_config: Optional[FusedMoeQuantConfig] = None,
     global_num_experts: int = -1,
     expert_map: Optional[torch.Tensor] = None,
     w1_scale: Optional[torch.Tensor] = None,
@@ -1229,10 +1498,26 @@ def fused_experts_impl(
     w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    # Deprecated: keep for backward compatibility
+    use_fp8_w8a8: Optional[bool] = None,
+    use_int8_w8a8: Optional[bool] = None,
+    use_int8_w8a16: Optional[bool] = None,
+    use_int4_w4a16: Optional[bool] = None,
+    per_channel_quant: Optional[bool] = None,
     block_shape: Optional[list[int]] = None,
 ) -> torch.Tensor:
+    # Handle backward compatibility
+    if fused_moe_quant_config is None:
+        fused_moe_quant_config = FusedMoeQuantConfig.from_legacy_flags(
+            use_fp8_w8a8=use_fp8_w8a8 or False,
+            use_int8_w8a8=use_int8_w8a8 or False,
+            use_int8_w8a16=use_int8_w8a16 or False,
+            use_int4_w4a16=use_int4_w4a16 or False,
+            per_channel_quant=per_channel_quant or False,
+            block_shape=block_shape)
+
     # Check constraints.
-    if use_int4_w4a16:
+    if fused_moe_quant_config.use_int4_w4a16:
         assert hidden_states.shape[1] // 2 == w1.shape[
             2], "Hidden size mismatch"
     else:
@@ -1417,11 +1702,7 @@ def fused_moe(
     num_expert_group: Optional[int] = None,
     topk_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
-    use_fp8_w8a8: bool = False,
-    use_int8_w8a8: bool = False,
-    use_int8_w8a16: bool = False,
-    use_int4_w4a16: bool = False,
-    per_channel_quant: bool = False,
+    fused_moe_quant_config: Optional[FusedMoeQuantConfig] = None,
     global_num_experts: int = -1,
     expert_map: Optional[torch.Tensor] = None,
     w1_scale: Optional[torch.Tensor] = None,
@@ -1430,6 +1711,12 @@ def fused_moe(
     w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    # Deprecated: keep for backward compatibility
+    use_fp8_w8a8: Optional[bool] = None,
+    use_int8_w8a8: Optional[bool] = None,
+    use_int8_w8a16: Optional[bool] = None,
+    use_int4_w4a16: Optional[bool] = None,
+    per_channel_quant: Optional[bool] = None,
     block_shape: Optional[list[int]] = None,
 ) -> torch.Tensor:
     """
@@ -1482,6 +1769,16 @@ def fused_moe(
     - torch.Tensor: The output tensor after applying the MoE layer.
     """
+    # Handle backward compatibility
+    if fused_moe_quant_config is None:
+        fused_moe_quant_config = FusedMoeQuantConfig.from_legacy_flags(
+            use_fp8_w8a8=use_fp8_w8a8 or False,
+            use_int8_w8a8=use_int8_w8a8 or False,
+            use_int8_w8a16=use_int8_w8a16 or False,
+            use_int4_w4a16=use_int4_w4a16 or False,
+            per_channel_quant=per_channel_quant or False,
+            block_shape=block_shape)
+
 
     if use_grouped_topk:
         assert num_expert_group is not None and topk_group is not None
         topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
@@ -1501,11 +1798,7 @@ def fused_moe(
                          topk_ids,
                          inplace=inplace,
                          activation=activation,
-                         use_fp8_w8a8=use_fp8_w8a8,
-                         use_int8_w8a8=use_int8_w8a8,
-                         use_int8_w8a16=use_int8_w8a16,
-                         use_int4_w4a16=use_int4_w4a16,
-                         per_channel_quant=per_channel_quant,
+                         fused_moe_quant_config=fused_moe_quant_config,
                          global_num_experts=global_num_experts,
                          expert_map=expert_map,
                         w1_scale=w1_scale,
@@ -1513,34 +1806,42 @@ def fused_moe(
                         w1_zp=w1_zp,
                         w2_zp=w2_zp,
                         a1_scale=a1_scale,
-                         a2_scale=a2_scale,
-                         block_shape=block_shape)
+                         a2_scale=a2_scale)
 
 
 class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
     def __init__(
         self,
-        use_fp8_w8a8: bool,
-        use_int8_w8a8: bool,
-        use_int8_w8a16: bool,
-        use_int4_w4a16: bool,
-        per_channel_quant: bool,
-        block_shape: Optional[list[int]] = None,
+        fused_moe_quant_config: Optional[FusedMoeQuantConfig] = None,
         block_m: Optional[int] = None,
+        # Deprecated: keep for backward compatibility
+        use_fp8_w8a8: Optional[bool] = None,
+        use_int8_w8a8: Optional[bool] = None,
+        use_int8_w8a16: Optional[bool] = None,
+        use_int4_w4a16: Optional[bool] = None,
+        per_channel_quant: Optional[bool] = None,
+        block_shape: Optional[list[int]] = None,
     ):
         super().__init__()
-        self.use_fp8_w8a8 = use_fp8_w8a8
-        self.use_int4_w4a16 = use_int4_w4a16
-        self.use_int8_w8a8 = use_int8_w8a8
-        self.use_int8_w8a16 = use_int8_w8a16
-        self.block_shape = block_shape
+
+        # Handle backward compatibility
+        if fused_moe_quant_config is None:
+            fused_moe_quant_config = FusedMoeQuantConfig.from_legacy_flags(
+                use_fp8_w8a8=use_fp8_w8a8 or False,
+                use_int8_w8a8=use_int8_w8a8 or False,
+                use_int8_w8a16=use_int8_w8a16 or False,
+                use_int4_w4a16=use_int4_w4a16 or False,
+                per_channel_quant=per_channel_quant or False,
+                block_shape=block_shape)
+
+        self.fused_moe_quant_config = fused_moe_quant_config
         self.block_m = block_m
-        self.qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
-                                      use_int8_w8a8=use_int8_w8a8,
-                                      use_int8_w8a16=use_int8_w8a16,
-                                      use_int4_w4a16=use_int4_w4a16)
-        self.per_channel_quant = per_channel_quant
+        self.qtype = get_config_qtype(
+            use_fp8_w8a8=fused_moe_quant_config.use_fp8_w8a8,
+            use_int8_w8a8=fused_moe_quant_config.use_int8_w8a8,
+            use_int8_w8a16=fused_moe_quant_config.use_int8_w8a16,
+            use_int4_w4a16=fused_moe_quant_config.use_int4_w4a16)
 
     def workspace_shapes(
         self,
@@ -1577,7 +1878,7 @@ def apply(
         expert_num_tokens: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Check constraints.
-        if self.use_int4_w4a16:
+        if self.fused_moe_quant_config.use_int4_w4a16:
             assert hidden_states.size(-1) // 2 == w1.size(2), (
                 "Hidden size mismatch")
         else:
@@ -1600,10 +1901,11 @@ def apply(
         if global_num_experts == -1:
             global_num_experts = E
 
-        config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
-                                            use_int8_w8a16=self.use_int8_w8a16,
-                                            use_int4_w4a16=self.use_int4_w4a16,
-                                            dtype=hidden_states.dtype)
+        config_dtype = get_config_dtype_str(
+            use_fp8_w8a8=self.fused_moe_quant_config.use_fp8_w8a8,
+            use_int8_w8a16=self.fused_moe_quant_config.use_int8_w8a16,
+            use_int4_w4a16=self.fused_moe_quant_config.use_int4_w4a16,
+            dtype=hidden_states.dtype)
 
         config = try_get_optimal_moe_config(
             w1.shape,
@@ -1611,7 +1913,7 @@ def apply(
             top_k_num,
             config_dtype,
             num_tokens,
-            block_shape=self.block_shape,
+            block_shape=self.fused_moe_quant_config.block_shape,
         )
 
         if hidden_states.dtype == torch.bfloat16:
@@ -1639,26 +1941,22 @@ def apply(
             moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
                                  global_num_experts, expert_map))
 
-        invoke_fused_moe_kernel(hidden_states,
-                                w1,
-                                intermediate_cache1,
-                                a1q_scale,
-                                w1_scale,
-                                w1_zp,
-                                None,
-                                sorted_token_ids,
-                                expert_ids,
-                                num_tokens_post_padded,
-                                False,
-                                top_k_num,
-                                config,
-                                compute_type=compute_type,
-                                use_fp8_w8a8=self.use_fp8_w8a8,
-                                use_int8_w8a8=self.use_int8_w8a8,
-                                use_int8_w8a16=self.use_int8_w8a16,
-                                use_int4_w4a16=self.use_int4_w4a16,
-                                per_channel_quant=self.per_channel_quant,
-                                block_shape=self.block_shape)
+        invoke_fused_moe_kernel(
+            hidden_states,
+            w1,
+            intermediate_cache1,
+            a1q_scale,
+            w1_scale,
+            w1_zp,
+            None,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            False,
+            top_k_num,
+            config,
+            compute_type=compute_type,
+            fused_moe_quant_config=self.fused_moe_quant_config)
 
         self.activation(activation, intermediate_cache2,
                         intermediate_cache1.view(-1, N))
@@ -1666,59 +1964,61 @@ def apply(
         a2q_scale: Optional[torch.Tensor] = None
 
         qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
-            intermediate_cache2, a2_scale, self.qtype, self.per_channel_quant,
-            self.block_shape)
-
-        invoke_fused_moe_kernel(qintermediate_cache2,
-                                w2,
-                                intermediate_cache3,
-                                a2q_scale,
-                                w2_scale,
-                                w2_zp,
-                                None,
-                                sorted_token_ids,
-                                expert_ids,
-                                num_tokens_post_padded,
-                                False,
-                                1,
-                                config,
-                                compute_type=compute_type,
-                                use_fp8_w8a8=self.use_fp8_w8a8,
-                                use_int8_w8a8=self.use_int8_w8a8,
-                                use_int8_w8a16=self.use_int8_w8a16,
-                                use_int4_w4a16=self.use_int4_w4a16,
-                                per_channel_quant=self.per_channel_quant,
-                                block_shape=self.block_shape)
+            intermediate_cache2, a2_scale, self.qtype,
+            self.fused_moe_quant_config.per_channel_quant,
+            self.fused_moe_quant_config.block_shape)
+
+        invoke_fused_moe_kernel(
+            qintermediate_cache2,
+            w2,
+            intermediate_cache3,
+            a2q_scale,
+            w2_scale,
+            w2_zp,
+            None,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            False,
+            1,
+            config,
+            compute_type=compute_type,
+            fused_moe_quant_config=self.fused_moe_quant_config)
 
         return intermediate_cache3
 
 
 def modular_triton_fused_moe(
-    use_fp8_w8a8: bool,
-    use_int8_w8a8: bool,
-    use_int8_w8a16: bool,
-    use_int4_w4a16: bool,
-    per_channel_quant: bool,
+    fused_moe_quant_config: Optional[FusedMoeQuantConfig] = None,
+    # Deprecated: keep for backward compatibility
+    use_fp8_w8a8: Optional[bool] = None,
+    use_int8_w8a8: Optional[bool] = None,
+    use_int8_w8a16: Optional[bool] = None,
+    use_int4_w4a16: Optional[bool] = None,
+    per_channel_quant: Optional[bool] = None,
    block_shape: Optional[list[int]] = None,
 ) -> mk.FusedMoEModularKernel:
+    # Handle backward compatibility
+    if fused_moe_quant_config is None:
+        fused_moe_quant_config = FusedMoeQuantConfig.from_legacy_flags(
+            use_fp8_w8a8=use_fp8_w8a8 or False,
+            use_int8_w8a8=use_int8_w8a8 or False,
+            use_int8_w8a16=use_int8_w8a16 or False,
+            use_int4_w4a16=use_int4_w4a16 or False,
+            per_channel_quant=per_channel_quant or False,
+            block_shape=block_shape)
+
     qtype = get_config_qtype(
-        use_fp8_w8a8=use_fp8_w8a8,
-        use_int8_w8a8=use_int8_w8a8,
-        use_int8_w8a16=use_int8_w8a16,
-        use_int4_w4a16=use_int4_w4a16,
+        use_fp8_w8a8=fused_moe_quant_config.use_fp8_w8a8,
+        use_int8_w8a8=fused_moe_quant_config.use_int8_w8a8,
+        use_int8_w8a16=fused_moe_quant_config.use_int8_w8a16,
+        use_int4_w4a16=fused_moe_quant_config.use_int4_w4a16,
     )
     return mk.FusedMoEModularKernel(
         MoEPrepareAndFinalizeNoEP(
             quant_dtype=qtype,
-            per_channel_quant=per_channel_quant,
-            block_shape=block_shape,
-        ),
-        TritonExperts(
-            use_fp8_w8a8=use_fp8_w8a8,
-            use_int8_w8a8=use_int8_w8a8,
-            use_int8_w8a16=use_int8_w8a16,
-            use_int4_w4a16=use_int4_w4a16,
-            per_channel_quant=per_channel_quant,
-            block_shape=block_shape,
+            per_channel_quant=fused_moe_quant_config.per_channel_quant,
+            block_shape=fused_moe_quant_config.block_shape,
         ),
+        TritonExperts(fused_moe_quant_config=fused_moe_quant_config),
     )
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index 16a9f0959b5..5da481baeee 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -101,7 +101,10 @@ def init_device(self):
         # fix this. It will be removed after the bug in XLA compiler is fixed.
         os.environ["LIBTPU_INIT_ARGS"] = (
             os.environ.get("LIBTPU_INIT_ARGS", "") +
-            " --xla_tpu_force_1d_allreduce_at_chunk_count=1")
+            " --xla_tpu_force_1d_allreduce_at_chunk_count=1"
+            " --xla_jf_conv_input_fusion=False")
+        # --xla_jf_conv_input_fusion=False is used to improve the perf of
+        # quantized matmul.
         torch.set_grad_enabled(False)
         torch.set_default_dtype(self.model_config.dtype)
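
Usage note (editorial, not part of the patch): the sketch below illustrates how a caller would build the FusedMoeQuantConfig introduced above and how the deprecated boolean flags map onto it. It only exercises the config class itself; the import path follows the file patched in this diff, and the [128, 128] block size is an arbitrary example value.

# Illustrative sketch only; assumes FusedMoeQuantConfig and QuantizationType
# are importable from the module modified above.
import torch

from vllm.model_executor.layers.fused_moe.fused_moe import (
    FusedMoeQuantConfig, QuantizationType)

# New style: describe the quantization scheme once, as a single config object.
cfg = FusedMoeQuantConfig.create_fp8_w8a8(block_shape=[128, 128])
assert cfg.quantization_type is QuantizationType.FP8_W8A8
assert cfg.use_fp8_w8a8 and cfg.is_quantized
assert cfg.activation_dtype is torch.float8_e4m3fn  # auto-inferred in __post_init__

# Legacy style: the old boolean flags are translated into the same config
# object and emit a FutureWarning during the deprecation window.
legacy = FusedMoeQuantConfig.from_legacy_flags(use_fp8_w8a8=True,
                                               block_shape=[128, 128])
assert legacy.quantization_type is QuantizationType.FP8_W8A8

# The resulting object is what fused_experts(), fused_moe(), TritonExperts and
# modular_triton_fused_moe() now accept via the fused_moe_quant_config keyword.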