 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from copy import deepcopy
+from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
+import vllm.envs as envs
+from vllm.distributed import divide
 from vllm.logger import init_logger
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
+from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
 
 if TYPE_CHECKING:
+    from transformers.configuration_utils import PretrainedConfig
+
     from vllm.config import VllmConfig
 
 logger = init_logger(__name__)
@@ -200,11 +207,198 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
 }
 
 
+class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
+
+    @classmethod
+    def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int) -> int:
+        """Compute the increase in group numbers to account for
+        replication in order to accompany the head shards."""
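+        # Illustrative example (hypothetical values): with ngroups=1 and
+        # tp_size=4 the groups do not divide evenly, so this returns
+        # 4 - 1 = 3 extra groups, i.e. the single group is replicated so
+        # that each of the 4 head shards owns a local copy.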
+
+        # in the case ngroups % tp_size == 0, this will be zero
+        if ngroups % tp_size == 0:
+            return 0
+
+        # for n_groups == 1, this is exactly tp_size - n_groups
+        return tp_size - ngroups
+
+    @dataclass
+    class MambaConfig:
+        expand: int
+        n_groups: int
+        n_heads: int
+        d_head: int
+        d_state: int
+        d_conv: int
+
+    @classmethod
+    def parse_mamba_config(cls, config: "PretrainedConfig") -> MambaConfig:
+        return cls.MambaConfig(
+            expand=config.mamba_expand,
+            n_groups=config.mamba_n_groups,
+            n_heads=config.mamba_n_heads,
+            d_head=config.mamba_d_head,
+            d_state=config.mamba_d_state,
+            d_conv=config.mamba_d_conv,
+        )
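+
+    # Note: model families whose HF configs use different attribute names
+    # (e.g. NemotronHModelConfig and Zamba2ModelConfig below) override
+    # parse_mamba_config to map those names onto the same MambaConfig fields.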
+
+    @classmethod
+    def get_mamba_cache_shape(
+            cls, vllm_config: "VllmConfig"
+    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
+
+        parallel_config = vllm_config.parallel_config
+        hf_config = vllm_config.model_config.hf_config
+        mamba_config = cls.parse_mamba_config(hf_config)
+
+        world_size = parallel_config.tensor_parallel_size
+        hidden_size = hf_config.hidden_size
+        intermediate_size = mamba_config.expand * hidden_size
+
+        # if n_groups is not divisible by world_size, need to extend the shards
+        # to ensure all groups needed by a head are sharded along with it
+        n_groups = (mamba_config.n_groups + cls.extra_groups_for_head_shards(
+            mamba_config.n_groups, world_size))
+
+        # - heads and n_groups are TP-ed
+        conv_dim = (intermediate_size + 2 * n_groups * mamba_config.d_state)
+        conv_state_shape = (
+            divide(conv_dim, world_size),
+            mamba_config.d_conv - 1,
+        )
+
+        # These are not TP-ed as they depend on A, dt_bias, D
+        # - they are typically small
+        #   e.g., (n_heads, d_head, d_state) = (128, 64, 128)
+        temporal_state_shape = (
+            divide(mamba_config.n_heads, world_size),
+            mamba_config.d_head,
+            mamba_config.d_state,
+        )
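+        # Illustrative example (hypothetical values): hidden_size=4096,
+        # expand=2, n_groups=1, d_state=128, d_conv=4, n_heads=128,
+        # d_head=64, tensor_parallel_size=1 gives conv_dim = 8192 + 256
+        # = 8448, so conv_state_shape = (8448, 3) and
+        # temporal_state_shape = (128, 64, 128).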
+
+        return conv_state_shape, temporal_state_shape
+
+    @classmethod
+    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+        """
+        Ensure that the page size of the attention layers is greater than or
+        equal to that of the mamba layers. If not, automatically set the
+        attention block size to ensure that it is. If the attention page size
+        is strictly greater than the mamba page size, we pad the mamba page
+        size to make them equal.
+
+        Args:
+            vllm_config: vLLM Config
+        """
+
+        if not envs.VLLM_USE_V1:
+            return
+
+        cache_config = vllm_config.cache_config
+        model_config = vllm_config.model_config
+        parallel_config = vllm_config.parallel_config
+
+        if cache_config.cache_dtype == "auto":
+            kv_cache_dtype = model_config.dtype
+        else:
+            kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
+
+        # get attention page size (for 1 token)
+        attn_page_size_1_token = FullAttentionSpec(
+            block_size=1,
+            num_kv_heads=model_config.get_num_kv_heads(parallel_config),
+            head_size=model_config.get_head_size(),
+            dtype=kv_cache_dtype,
+            use_mla=model_config.use_mla).page_size_bytes
+
+        # get mamba page size
+        mamba_page_size = MambaSpec(
+            shapes=cls.get_mamba_cache_shape(vllm_config),
+            dtype=kv_cache_dtype,
+            block_size=model_config.max_model_len,
+        ).page_size_bytes
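+        # attn_page_size_1_token is the per-token KV-cache footprint of one
+        # full-attention layer; mamba_page_size is the size of one mamba
+        # layer's state (conv state + temporal/SSM state), which does not
+        # grow with sequence length, so block_size=max_model_len lets a
+        # single page hold the whole state for one request.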
+
+        # some attention backends (e.g. FA) only support setting
+        # block size to a multiple of 16, so let's suggest a value
+        # that would work (note: FA is currently not compatible
+        # with mamba layers, use FlashInfer instead).
+        attn_block_size = 16 * cdiv(mamba_page_size,
+                                    16 * attn_page_size_1_token)
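+        # Illustrative example (hypothetical values): if mamba_page_size is
+        # 2097152 bytes and attn_page_size_1_token is 4096 bytes, then
+        # cdiv(2097152, 16 * 4096) = 32 and attn_block_size = 16 * 32 = 512
+        # tokens, the smallest multiple of 16 whose page covers the mamba page.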
+
+        # override attention block size if either (a) the
+        # user has not set it or (b) the user has set it
+        # too small.
+        if (cache_config.block_size is None
+                or cache_config.block_size < attn_block_size):
+            cache_config.block_size = attn_block_size
+            logger.info(
+                "Setting attention block size to %d tokens "
+                "to ensure that attention page size is >= mamba page size.",
+                attn_block_size)
+
+        # compute new attention page size
+        attn_page_size = \
+            cache_config.block_size * attn_page_size_1_token
+
+        assert attn_page_size >= mamba_page_size
+
+        if attn_page_size == mamba_page_size:
+            # don't need to pad mamba page size
+            return
+
+        # pad mamba page size to exactly match attention
+        if (cache_config.mamba_page_size_padded is None
+                or cache_config.mamba_page_size_padded != attn_page_size):
+            cache_config.mamba_page_size_padded = attn_page_size
+            mamba_padding_pct = 100 * (attn_page_size -
+                                       mamba_page_size) / mamba_page_size
+            logger.info(
+                "Padding mamba page size by %.2f%% to ensure "
+                "that mamba page size and attention page size are "
+                "exactly equal.", mamba_padding_pct)
+
+
+class NemotronHModelConfig(HybridAttentionMambaModelConfig):
+
+    @classmethod
+    def parse_mamba_config(
+            cls, config: "PretrainedConfig"
+    ) -> HybridAttentionMambaModelConfig.MambaConfig:
+        return HybridAttentionMambaModelConfig.MambaConfig(
+            expand=config.expand,
+            n_groups=config.n_groups,
+            n_heads=config.mamba_num_heads,
+            d_head=config.mamba_head_dim,
+            d_state=config.ssm_state_size,
+            d_conv=config.conv_kernel,
+        )
+
+
+class Zamba2ModelConfig(HybridAttentionMambaModelConfig):
+
+    @classmethod
+    def parse_mamba_config(
+            cls, config: "PretrainedConfig"
+    ) -> HybridAttentionMambaModelConfig.MambaConfig:
+        return HybridAttentionMambaModelConfig.MambaConfig(
+            expand=config.mamba_expand,
+            n_groups=config.mamba_ngroups,
+            n_heads=config.n_mamba_heads,
+            d_head=config.mamba_headdim,
+            d_state=config.mamba_d_state,
+            d_conv=config.mamba_d_conv,
+        )
+
+
 MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "GteModel": SnowflakeGteNewModelConfig,
     "GteNewModel": GteNewModelConfig,
     "NomicBertModel": NomicBertModelConfig,
     "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
     "XLMRobertaModel": JinaRobertaModelConfig,
+    "FalconH1ForCausalLM": HybridAttentionMambaModelConfig,
+    "BambaForCausalLM": HybridAttentionMambaModelConfig,
+    "GraniteMoeHybridForCausalLM": HybridAttentionMambaModelConfig,
+    "NemotronHForCausalLM": NemotronHModelConfig,
+    "Zamba2ForCausalLM": Zamba2ModelConfig,
     "JinaVLForRanking": JinaVLForSequenceClassificationConfig,
 }