@@ -1,18 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from copy import deepcopy
-from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
-import vllm.envs as envs
-from vllm.distributed import divide
 from vllm.logger import init_logger
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
-from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
 
 if TYPE_CHECKING:
-    from transformers.configuration_utils import PretrainedConfig
-
     from vllm.config import VllmConfig
 
 logger = init_logger(__name__)
@@ -198,197 +191,10 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
         }
 
 
-class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
-
-    @classmethod
-    def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int) -> int:
-        """Compute the increase in group numbers to account for
-        replication in order to accompany the head shards."""
-
-        # in the case ngroups % tp_size == 0, this will be zero
-        if ngroups % tp_size == 0:
-            return 0
-
-        # for n_groups == 1, this is exactly tp_size - n_groups
-        return tp_size - ngroups
-
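To make the replication rule above concrete, here is a minimal standalone sketch of the same arithmetic; the group counts and TP sizes are hypothetical examples, not values taken from any particular model:

    def extra_groups_for_head_shards(ngroups: int, tp_size: int) -> int:
        # already evenly divisible: no replication needed
        if ngroups % tp_size == 0:
            return 0
        # otherwise pad up to one group per tensor-parallel shard
        return tp_size - ngroups

    assert extra_groups_for_head_shards(1, 8) == 7  # 1 + 7 = 8 groups, one per TP rank
    assert extra_groups_for_head_shards(8, 4) == 0  # 8 groups split evenly over 4 ranks
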
-    @dataclass
-    class MambaConfig:
-        expand: int
-        n_groups: int
-        n_heads: int
-        d_head: int
-        d_state: int
-        d_conv: int
-
-    @classmethod
-    def parse_mamba_config(cls, config: "PretrainedConfig") -> MambaConfig:
-        return cls.MambaConfig(
-            expand=config.mamba_expand,
-            n_groups=config.mamba_n_groups,
-            n_heads=config.mamba_n_heads,
-            d_head=config.mamba_d_head,
-            d_state=config.mamba_d_state,
-            d_conv=config.mamba_d_conv,
-        )
-
-    @classmethod
-    def get_mamba_cache_shape(
-            cls, vllm_config: "VllmConfig"
-    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
-
-        parallel_config = vllm_config.parallel_config
-        hf_config = vllm_config.model_config.hf_config
-        mamba_config = cls.parse_mamba_config(hf_config)
-
-        world_size = parallel_config.tensor_parallel_size
-        hidden_size = hf_config.hidden_size
-        intermediate_size = mamba_config.expand * hidden_size
-
-        # if n_groups is not divisible by world_size, need to extend the shards
-        # to ensure all groups needed by a head are sharded along with it
-        n_groups = (mamba_config.n_groups + cls.extra_groups_for_head_shards(
-            mamba_config.n_groups, world_size))
-
-        # - heads and n_groups are TP-ed
-        conv_dim = (intermediate_size + 2 * n_groups * mamba_config.d_state)
-        conv_state_shape = (
-            divide(conv_dim, world_size),
-            mamba_config.d_conv - 1,
-        )
-
-        # These are not TP-ed as they depend on A, dt_bias, D
-        # - they are typically small
-        # e.g., (n_heads, d_head, d_state) = (128, 64, 128)
-        temporal_state_shape = (
-            divide(mamba_config.n_heads, world_size),
-            mamba_config.d_head,
-            mamba_config.d_state,
-        )
-
-        return conv_state_shape, temporal_state_shape
-
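For a sense of the resulting shapes, a standalone rework of the computation above with made-up Mamba2-style hyperparameters (the hidden size, group/head counts, and TP degree are illustrative assumptions only):

    # Hypothetical hyperparameters, not taken from a real checkpoint.
    hidden_size, expand = 4096, 2
    n_groups, n_heads, d_head, d_state, d_conv = 8, 128, 64, 128, 4
    tp_size = 2

    intermediate_size = expand * hidden_size                      # 8192
    # n_groups (8) already divides by tp_size (2): no extra replicated groups
    conv_dim = intermediate_size + 2 * n_groups * d_state         # 10240
    conv_state_shape = (conv_dim // tp_size, d_conv - 1)          # (5120, 3)
    temporal_state_shape = (n_heads // tp_size, d_head, d_state)  # (64, 64, 128)
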
-    @classmethod
-    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
-        """
-        Ensure that the page size of the attention layers is greater than or
-        equal to that of the mamba layers. If not, automatically set the
-        attention block size to ensure that it is. If the attention page size
-        is strictly greater than the mamba page size, we pad the mamba page
-        size to make them equal.
-
-        Args:
-            vllm_config: vLLM Config
-        """
-
-        if not envs.VLLM_USE_V1:
-            return
-
-        cache_config = vllm_config.cache_config
-        model_config = vllm_config.model_config
-        parallel_config = vllm_config.parallel_config
-
-        if cache_config.cache_dtype == "auto":
-            kv_cache_dtype = model_config.dtype
-        else:
-            kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
-
-        # get attention page size (for 1 token)
-        attn_page_size_1_token = FullAttentionSpec(
-            block_size=1,
-            num_kv_heads=model_config.get_num_kv_heads(parallel_config),
-            head_size=model_config.get_head_size(),
-            dtype=kv_cache_dtype,
-            use_mla=model_config.use_mla).page_size_bytes
-
-        # get mamba page size
-        mamba_page_size = MambaSpec(
-            shapes=cls.get_mamba_cache_shape(vllm_config),
-            dtype=kv_cache_dtype,
-            block_size=model_config.max_model_len,
-        ).page_size_bytes
-
-        # some attention backends (e.g. FA) only support setting
-        # the block size to a multiple of 16, so suggest a value
-        # that would work (note: FA is currently not compatible
-        # with mamba layers, use FlashInfer instead).
-        attn_block_size = 16 * cdiv(mamba_page_size,
-                                    16 * attn_page_size_1_token)
-
-        # override the attention block size if either (a) the
-        # user has not set it or (b) the user has set it
-        # too small.
-        if (cache_config.block_size is None
-                or cache_config.block_size < attn_block_size):
-            cache_config.block_size = attn_block_size
-            logger.info(
-                "Setting attention block size to %d tokens "
-                "to ensure that attention page size is >= mamba page size.",
-                attn_block_size)
-
-        # compute the new attention page size
-        attn_page_size = \
-            cache_config.block_size * attn_page_size_1_token
-
-        assert attn_page_size >= mamba_page_size
-
-        if attn_page_size == mamba_page_size:
-            # no need to pad the mamba page size
-            return
-
-        # pad the mamba page size to exactly match attention
-        if (cache_config.mamba_page_size_padded is None
-                or cache_config.mamba_page_size_padded != attn_page_size):
-            cache_config.mamba_page_size_padded = attn_page_size
-            mamba_padding_pct = 100 * (attn_page_size -
-                                       mamba_page_size) / mamba_page_size
-            logger.info(
-                "Padding mamba page size by %.2f%% to ensure "
-                "that mamba page size and attention page size are "
-                "exactly equal.", mamba_padding_pct)
-
-
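The block-size suggestion in the method above is plain integer arithmetic on page sizes. A hedged standalone illustration, using invented byte counts rather than real FullAttentionSpec/MambaSpec outputs:

    def cdiv(a: int, b: int) -> int:
        return -(-a // b)  # ceiling division, as in vllm.utils.cdiv

    attn_page_size_1_token = 4096  # hypothetical: bytes of KV cache per token
    mamba_page_size = 2_700_000    # hypothetical: bytes of mamba state per sequence

    # smallest multiple of 16 tokens whose attention page covers the mamba page
    attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1_token)  # 672
    attn_page_size = attn_block_size * attn_page_size_1_token                  # 2752512

    assert attn_page_size >= mamba_page_size
    padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size
    print(f"pad mamba page size by {padding_pct:.2f}%")  # ~1.94%
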
-class NemotronHModelConfig(HybridAttentionMambaModelConfig):
-
-    @classmethod
-    def parse_mamba_config(
-            cls, config: "PretrainedConfig"
-    ) -> HybridAttentionMambaModelConfig.MambaConfig:
-        return HybridAttentionMambaModelConfig.MambaConfig(
-            expand=config.expand,
-            n_groups=config.n_groups,
-            n_heads=config.mamba_num_heads,
-            d_head=config.mamba_head_dim,
-            d_state=config.ssm_state_size,
-            d_conv=config.conv_kernel,
-        )
-
-
-class Zamba2ModelConfig(HybridAttentionMambaModelConfig):
-
-    @classmethod
-    def parse_mamba_config(
-            cls, config: "PretrainedConfig"
-    ) -> HybridAttentionMambaModelConfig.MambaConfig:
-        return HybridAttentionMambaModelConfig.MambaConfig(
-            expand=config.mamba_expand,
-            n_groups=config.mamba_ngroups,
-            n_heads=config.n_mamba_heads,
-            d_head=config.mamba_headdim,
-            d_state=config.mamba_d_state,
-            d_conv=config.mamba_d_conv,
-        )
-
-
 MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "GteModel": SnowflakeGteNewModelConfig,
     "GteNewModel": GteNewModelConfig,
     "NomicBertModel": NomicBertModelConfig,
     "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
     "XLMRobertaModel": JinaRobertaModelConfig,
-    "FalconH1ForCausalLM": HybridAttentionMambaModelConfig,
-    "BambaForCausalLM": HybridAttentionMambaModelConfig,
-    "GraniteMoeHybridForCausalLM": HybridAttentionMambaModelConfig,
-    "NemotronHForCausalLM": NemotronHModelConfig,
-    "Zamba2ForCausalLM": Zamba2ModelConfig,
 }