From 3a502236d1ef13219dfc4274444e45aee4d88e5d Mon Sep 17 00:00:00 2001 From: Paul Pak Date: Thu, 10 Jul 2025 19:44:28 +0000 Subject: [PATCH 1/8] [cmake] ignore nvToolsExt for cuda-12.9 --- CMakeLists.txt | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 538f9adcb24..dcb153b7f6a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,6 +79,13 @@ endif() # find_package(Torch REQUIRED) +# +# Ignore nvToolsExt for cuda-12.9 +# +if (NOT TARGET CUDA::nvToolsExt) + add_library(CUDA::nvToolsExt INTERFACE IMPORTED) +endif() + # Supported NVIDIA architectures. # This check must happen after find_package(Torch) because that's when CMAKE_CUDA_COMPILER_VERSION gets defined if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND From 6fd86d2fa3d6476ac7a89643ddc4b51360ea3156 Mon Sep 17 00:00:00 2001 From: Paul Pak Date: Thu, 10 Jul 2025 21:07:42 +0000 Subject: [PATCH 2/8] [model_executor][models] LFM2 architecture --- vllm/model_executor/layers/conv.py | 235 +++++++++ vllm/model_executor/models/conv_cache.py | 71 +++ vllm/model_executor/models/lfm2.py | 577 +++++++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/utils/__init__.py | 1 + 5 files changed, 885 insertions(+) create mode 100644 vllm/model_executor/layers/conv.py create mode 100644 vllm/model_executor/models/conv_cache.py create mode 100755 vllm/model_executor/models/lfm2.py diff --git a/vllm/model_executor/layers/conv.py b/vllm/model_executor/layers/conv.py new file mode 100644 index 00000000000..7eba01f1d36 --- /dev/null +++ b/vllm/model_executor/layers/conv.py @@ -0,0 +1,235 @@ + +from typing import Any, Optional + +import torch +import torch.nn as nn + +from vllm import envs +from vllm.config import get_current_vllm_config +from vllm.forward_context import get_forward_context +from vllm.model_executor.custom_op import CustomOp +from vllm.distributed import divide, get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.model_executor.models.conv_cache import ConvCacheParams +from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata +from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata + + +@CustomOp.register("short_conv") +class ShortConv(CustomOp): + + def __init__(self, config, dim: int, layer_idx: int, prefix: str = ""): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.conv_dim = dim + self.L_cache = config.conv_L_cache + self.bias = config.conv_bias + + self.conv = ColumnParallelLinear( + input_size=self.L_cache, + output_size=dim, + bias=self.bias, + prefix=f"{prefix}.conv1d", + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. 
+ # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv.weight.data = self.conv.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear( + input_size=dim, + output_sizes=[dim] * 3, + bias=self.bias, + prefix=f"{prefix}.in_proj", + ) + self.out_proj = RowParallelLinear( + input_size=dim, + output_size=dim, + bias=self.bias, + prefix=f"{prefix}.out_proj", + ) + + if envs.VLLM_USE_V1: + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The outer list is for v0 PP virtual engine. Though this code path + # only runs for v1, we have to do this to unify with the interface + # of Attention + v0 PP. + # The inner tuple is (conv_state,) + self.kv_cache = [(torch.tensor([]))] + + # For compatibility with MambaSpec utils + self.chunk_size = 1 + self.prefix = prefix + + def forward_native(self, hidden_states: torch.Tensor, + conv_cache_params: ConvCacheParams) -> torch.Tensor: + pass + + def forward_cuda( + self, + hidden_states: torch.Tensor, + conv_cache_params: ConvCacheParams, + conv_metadata: Mamba2Metadata, + ) -> torch.Tensor: + forward_context = get_forward_context() + # mamba2_metadata contains metadata necessary for the mamba2 triton + # kernels to operate in continuous batching and in chunked prefill + # modes; they are computed at top-level model forward since they + # stay the same and reused for all mamba layers in the same iteration + attn_metadata: Optional[AttentionMetadata] = get_forward_context().attn_metadata + if envs.VLLM_USE_V1: + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, Mamba2AttentionMetadata) + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0] + state_indices_tensor = attn_metadata.state_indices_tensor + has_initial_states_p = attn_metadata.has_initial_states + # prep_initial_states = attn_metadata.prep_initial_states + # chunk_size = attn_metadata.chunk_size + # seq_idx_p = attn_metadata.seq_idx + # chunk_indices_p = attn_metadata.chunk_indices + # chunk_offsets_p = attn_metadata.chunk_offsets + else: + conv_state = conv_cache_params.conv_state + state_indices_tensor = conv_cache_params.state_indices_tensor + has_initial_states_p = conv_metadata.has_initial_states + # prep_initial_states = conv_metadata.prep_initial_states + # chunk_size = conv_metadata.chunk_size + # seq_idx_p = conv_metadata.seq_idx + # chunk_indices_p = conv_metadata.chunk_indices + # chunk_offsets_p = conv_metadata.chunk_offsets + + BCx, _ = self.in_proj(hidden_states) + + B, C, x = BCx.chunk(3, dim=-1) + + conv_weights = self.conv.weight.view(self.conv.weight.size(0), + self.conv.weight.size(2)) + + if envs.VLLM_USE_V1 and attn_metadata is None: + # V1 profile run + Bx = (B * x).contiguous() + hidden_states = C * Bx + contextualized_states, _ = self.out_proj(hidden_states) + return contextualized_states + + num_prefills = attn_metadata.num_prefills # request count + num_decodes = attn_metadata.num_decode_tokens # token count (=request) + num_prefill_tokens = attn_metadata.num_prefill_tokens # token count + has_prefill = num_prefills > 0 + has_decode = num_decodes > 0 + + # NOTE: V0 put prefill before decode, v1 puts decode before prefill + # Separate 
prefill and decode by splitting varlen input + # Split along token dimension + if envs.VLLM_USE_V1: + B_d, B_p = torch.split( + B, + [num_decodes, num_prefill_tokens], + dim=0, + ) + C_d, C_p = torch.split( + C, + [num_decodes, num_prefill_tokens], + dim=0, + ) + x_d, x_p = torch.split( + x, + [num_decodes, num_prefill_tokens], + dim=0, + ) + # Split along batch dimension + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor, + [num_decodes, num_prefills], + dim=0, + ) + query_start_loc_p = ( + attn_metadata.query_start_loc[-num_prefills - 1:] - + num_decodes if has_prefill else None) + else: + B_p, B_d = torch.split( + B, + [num_prefill_tokens, num_decodes], + dim=0, + ) + C_p, C_d = torch.split( + C, + [num_prefill_tokens, num_decodes], + dim=0, + ) + x_p, x_d = torch.split( + x, + [num_prefill_tokens, num_decodes], + dim=0, + ) + # Split along batch dimension + state_indices_tensor_p, state_indices_tensor_d = torch.split( + conv_cache_params.state_indices_tensor, + [num_prefills, num_decodes], + dim=0, + ) + query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + + 1] + if has_prefill else None) + + conv_output_list = [] + + if has_prefill: + Bx_p = (B_p * x_p).contiguous() + Bx = causal_conv1d_fn( + Bx_p.transpose(0, 1), + conv_weights, + self.conv.bias, + activation=None, + conv_states=conv_state, + has_initial_state=has_initial_states_p, + cache_indices=state_indices_tensor_p, + query_start_loc=query_start_loc_p).transpose( + 0, 1)[:num_prefill_tokens] + + C_p = C_p.view(1, num_prefill_tokens, -1) + y = C_p * Bx + conv_output_list.append(y.view(num_prefill_tokens, -1)) + + if has_decode: + Bx_d = (B_d * x_d).contiguous() + Bx = causal_conv1d_update( + Bx_d, + conv_state, + conv_weights, + self.conv.bias, + activation=None, + conv_state_indices=state_indices_tensor_d) + C_d = C_d.view(num_decodes, -1) + y = C_d * Bx + conv_output_list.append(y.view(num_decodes, -1)) + + # Merge prefill and decode outputs before passing to gated MLP + hidden_states = torch.vstack(conv_output_list) + + # Final linear projection + contextualized_states, _ = self.out_proj(hidden_states) + + return contextualized_states + + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + world_size = get_tensor_model_parallel_world_size() + conv_state_shape = ( + divide(self.conv_dim, world_size), + self.L_cache - 1, + ) + return (conv_state_shape,) \ No newline at end of file diff --git a/vllm/model_executor/models/conv_cache.py b/vllm/model_executor/models/conv_cache.py new file mode 100644 index 00000000000..0f7f62443d8 --- /dev/null +++ b/vllm/model_executor/models/conv_cache.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass + +import torch + +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.config import VllmConfig +from vllm.model_executor.models.constant_size_cache import ConstantSizeCache + + +@dataclass +class ConvCacheParams: + conv_state: torch.Tensor = torch.Tensor() + state_indices_tensor: torch.Tensor = torch.Tensor() + + def at_layer_idx(self, layer_idx): + return ConvCacheParams(self.conv_state[layer_idx], + self.state_indices_tensor) + + +class ConvCacheManager(ConstantSizeCache): + + def __init__( + self, + vllm_config: VllmConfig, + dtype: torch.dtype, + num_conv_layers: int, + conv_state_shape: tuple[int, int]): + + max_batch_size = vllm_config.scheduler_config.max_num_seqs + if not 
vllm_config.model_config.enforce_eager: + max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size) + + # Initialize parent class + super().__init__(max_batch_size) + + conv_state = torch.empty(size=(num_conv_layers, max_batch_size) + + conv_state_shape, + dtype=dtype, + device="cuda") + + self._lfm2_cache = conv_state + + @property + def cache(self): + return self._lfm2_cache + + def _copy_cache(self, from_index: int, to_index: int): + for cache_t in self.cache: + cache_t[:, to_index].copy_(cache_t[:, from_index], + non_blocking=True) + + def current_run_tensors(self, **kwargs) -> ConvCacheParams: + """ + Return the tensors for the current run's conv state. + """ + cache_tensor, state_indices_tensor = super().current_run_tensors( + **kwargs) + return ConvCacheParams(cache_tensor, state_indices_tensor) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + """ + Provide the CUDA graph capture runs with a buffer in adjusted size. + The buffer is used to maintain the LFM2 Cache during the CUDA graph + replay runs. + """ + return self._lfm2_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size, + dtype=torch.int32, + device="cuda") diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py new file mode 100755 index 00000000000..5e53a821cbf --- /dev/null +++ b/vllm/model_executor/models/lfm2.py @@ -0,0 +1,577 @@ +# SPDX-License-Identifier: Apache-2.0 +from collections.abc import Iterable +from typing import Any, Optional + +import torch +import torch.nn as nn +from transformers import LFM2Config + +from vllm import envs +from vllm.attention import Attention +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.conv import ShortConv +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba2_metadata import ( + Mamba2Metadata, prepare_mamba2_metadata) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.conv_cache import (ConvCacheManager, + ConvCacheParams) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import LayerBlockType + +from .interfaces import (HasInnerState, IsHybrid, + SupportsLoRA, SupportsPP, + SupportsQuant) +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) +from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata + + +class LFM2MLP(nn.Module): + + def __init__( + self, + dim: int, + ff_dim: int, + multiple_of: int, + auto_adjust_ff_dim: 
bool, + ffn_dim_multiplier: Optional[float], + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + if auto_adjust_ff_dim: + ff_dim = int(2 * ff_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + ff_dim = int(ffn_dim_multiplier * ff_dim) + ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of) + + self.w1 = MergedColumnParallelLinear( + input_size=dim, + output_sizes=[ff_dim] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.w2 = RowParallelLinear( + input_size=ff_dim, + output_size=dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.w1(x) + x = self.act_fn(gate_up) + x, _ = self.w2(x) + return x + + +class LFM2Attention(nn.Module): + + def __init__( + self, + config: LFM2Config, + layer_idx: int, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = hidden_size + self.num_kv_heads = num_kv_heads + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=True, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + self.q_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps) + self.k_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + n_tokens, _ = hidden_states.shape + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q.view(n_tokens, self.num_heads, self.head_dim).contiguous() + k = k.view(n_tokens, self.num_kv_heads, self.head_dim).contiguous() + q = self.q_layernorm(q) + k = self.k_layernorm(k) + q, k = self.rotary_emb(positions, q, k) + q = q.view(n_tokens, self.num_heads * self.head_dim) + k = k.view(n_tokens, self.num_kv_heads * self.head_dim) + attn_output = self.attn(q, k, v) + output, _ = self.out_proj(attn_output) + return output + + +class LFM2AttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: LFM2Config, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.prefix = prefix + self.config = config + self.layer_idx = layer_idx + + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + + self.self_attn = LFM2Attention( + config=config, + layer_idx=layer_idx, + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + self.feed_forward = LFM2MLP( + dim=config.block_dim, + ff_dim=config.block_ff_dim, + multiple_of=config.block_multiple_of, + auto_adjust_ff_dim=config.block_auto_adjust_ff_dim, + ffn_dim_multiplier=config.block_ffn_dim_multiplier, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + self.ffn_norm = 
RMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.operator_norm(hidden_states) + else: + hidden_states, residual = self.operator_norm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + hidden_states, residual = self.ffn_norm(hidden_states, residual) + return self.feed_forward(hidden_states), residual + + +class LFM2ShortConvDecoderLayer(nn.Module): + + def __init__( + self, + config: LFM2Config, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.conv = ShortConv( + config=config, + dim=config.conv_dim, + layer_idx=layer_idx, + prefix=f"{prefix}.conv", + ) + + self.feed_forward = LFM2MLP( + dim=config.block_dim, + ff_dim=config.block_ff_dim, + multiple_of=config.block_multiple_of, + auto_adjust_ff_dim=config.block_auto_adjust_ff_dim, + ffn_dim_multiplier=config.block_ffn_dim_multiplier, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + conv_cache_params: ConvCacheParams, + conv_metadata: Mamba2Metadata, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.operator_norm(hidden_states) + else: + hidden_states, residual = self.operator_norm( + hidden_states, residual) + hidden_states = self.conv( + hidden_states, + conv_cache_params=conv_cache_params, + conv_metadata=conv_metadata, + ) + hidden_states, residual = self.ffn_norm(hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +class LFM2Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size) + + def get_layer(prefix: str): + layer_idx = int(prefix.rsplit(".", 1)[1]) + is_attn = layer_idx in config.full_attn_idxs + layer_class = (LFM2AttentionDecoderLayer + if is_attn else LFM2ShortConvDecoderLayer) + return layer_class( + config, + layer_idx, + cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + if get_pp_group().is_last_rank: + self.embedding_norm = RMSNorm(config.hidden_size, + eps=config.norm_eps) + else: + self.embedding_norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: 
torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + conv_cache_params: ConvCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + attn_metadata = get_forward_context().attn_metadata + + if not envs.VLLM_USE_V1: + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=1, + attn_metadata=attn_metadata, + ) + else: + # v1 get mamba2_metadata from forward_context + mamba2_metadata = None + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + kv_cache_index = 0 + state_cache_index = 0 + for layer in self.layers[self.start_layer:self.end_layer]: + layer_conv_cache_params = None + if isinstance(layer, LFM2AttentionDecoderLayer): + kv_cache_index += 1 + if isinstance(layer, LFM2ShortConvDecoderLayer): + current_state_layer = state_cache_index + layer_conv_cache_params = conv_cache_params.at_layer_idx( + current_state_layer) if conv_cache_params else None + state_cache_index += 1 + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + conv_cache_params=layer_conv_cache_params, + conv_metadata=mamba2_metadata, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.embedding_norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".w1", ".w1", 0), + (".w1", ".w3", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class LFM2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + IsHybrid, SupportsQuant): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "w1": [ + "w1", + "w3", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert (not cache_config.enable_prefix_caching + ), "LFM2 currently does not support prefix caching" + + 
super().__init__() + self.config = config + self.vllm_config = vllm_config + self.scheduler_config = scheduler_config + self.model_config = vllm_config.model_config + + self.model = LFM2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = self.config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else + lora_config.lora_vocab_padding_size), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + else: + self.lm_head = PPMissingLayer() + + # Used to track and store by the Mamba cache between steps. + self.lfm2_cache: Optional[ConvCacheManager] = None + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + conv_cache_params = None + if not envs.VLLM_USE_V1: + if self.lfm2_cache is None: + num_conv_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, LayerBlockType.conv) + self.lfm2_cache = ConvCacheManager( + vllm_config=self.vllm_config, + dtype=self.lm_head.weight.dtype, + num_conv_layers=num_conv_layers, + conv_state_shape=self._get_conv_cache_shape(), + ) + + conv_cache_params = self.lfm2_cache.current_run_tensors(**kwargs) + + hidden_states = self.model(input_ids, positions, conv_cache_params, + intermediate_tensors, inputs_embeds) + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.lfm2_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.lfm2_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_conv_cache_shape(self) -> tuple[tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.conv_dim + conv_state_shape = ( + hidden_size // world_size, + self.config.conv_L_cache - 1, + ) + return conv_state_shape + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 03e45bd26d7..5e78319c15c 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -82,6 +82,7 @@ "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), + "LFM2ForCausalLM": ("lfm2", "LFM2ForCausalLM"), # For 
decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"), diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index cf7320a19e4..71074c55b66 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -223,6 +223,7 @@ class Device(enum.Enum): class LayerBlockType(enum.Enum): attention = "attention" mamba = "mamba" + conv = "conv" class Counter: From aaf7df1f42dd63d73b61eb013723595f15bd17c7 Mon Sep 17 00:00:00 2001 From: Paul Pak Date: Thu, 10 Jul 2025 21:47:00 +0000 Subject: [PATCH 3/8] [configs] use layer_types from huggingface hybrids >= 4.54.0.dev0 --- vllm/config.py | 6 ++++++ vllm/model_executor/models/lfm2.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 1e9d119ebf8..a4fd25be5fa 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1304,6 +1304,12 @@ def get_num_layers_by_block_type( # Hybrid model Jamba layers_block_type_value = getattr(self.hf_config, "layers_block_type", None) + + # Hybrid models in `transformers` >= 4.54.0.dev0 + if layers_block_type_value is None: + layers_block_type_value = getattr(self.hf_text_config, + "layer_types", None) + if layers_block_type_value is not None: if hasattr(self.hf_text_config, "model_type") and (self.hf_text_config.model_type diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index 5e53a821cbf..709b4c2f21d 100755 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -331,7 +331,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def get_layer(prefix: str): layer_idx = int(prefix.rsplit(".", 1)[1]) - is_attn = layer_idx in config.full_attn_idxs + is_attn = self.config.layer_types[layer_idx] == "full_attention" layer_class = (LFM2AttentionDecoderLayer if is_attn else LFM2ShortConvDecoderLayer) return layer_class( From 6c80cafda7067d24accd3721917881c787dd002e Mon Sep 17 00:00:00 2001 From: Paul Pak Date: Fri, 11 Jul 2025 04:59:04 +0000 Subject: [PATCH 4/8] [model_runner][v1] ShortConvSpec for ShortConv layers; compatibility with Mamba2 --- vllm/v1/attention/backends/mamba_attn.py | 22 +++++- vllm/v1/core/single_type_kv_cache_manager.py | 3 +- vllm/v1/kv_cache_interface.py | 8 ++ vllm/v1/worker/gpu_model_runner.py | 79 ++++++++++++++------ 4 files changed, 84 insertions(+), 28 deletions(-) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 9dea08b6583..0e23048da01 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union import torch @@ -10,7 +10,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) -from vllm.v1.kv_cache_interface import MambaSpec +from vllm.v1.kv_cache_interface import (MambaSpec, ShortConvSpec) from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: @@ -28,6 +28,15 @@ def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int: return chunk_sizes.pop() +def get_short_conv_chunk_size(vllm_config: VllmConfig) -> int: + from vllm.model_executor.layers.conv import ShortConv + layers = get_layers_from_vllm_config(vllm_config, ShortConv) + chunk_sizes = 
set(layer.chunk_size for layer in layers.values()) + assert len( + chunk_sizes) == 1, "All ShortConv layers must have the same chunk size" + return chunk_sizes.pop() + + def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, chunk_size: int, total_seqlens: int): @@ -97,12 +106,17 @@ class Mamba2AttentionMetadata: class Mamba2AttentionMetadataBuilder( AttentionMetadataBuilder[Mamba2AttentionMetadata]): - def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec, + def __init__(self, runner: "GPUModelRunner", kv_cache_spec: Union[MambaSpec, ShortConvSpec], block_table: BlockTable): self.runner = runner self.kv_cache_spec = kv_cache_spec self.block_table = block_table - self.chunk_size = get_mamba2_chunk_size(runner.vllm_config) + if isinstance(kv_cache_spec, MambaSpec): + self.chunk_size = get_mamba2_chunk_size(runner.vllm_config) + elif isinstance(kv_cache_spec, ShortConvSpec): + self.chunk_size = get_short_conv_chunk_size(runner.vllm_config) + else: + raise ValueError(f"Unsupported KV cache spec: {kv_cache_spec}") def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 5b471803807..05b5ac43033 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -8,7 +8,7 @@ from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, - MambaSpec, SlidingWindowSpec) + MambaSpec, SlidingWindowSpec, ShortConvSpec) from vllm.v1.request import Request @@ -434,6 +434,7 @@ def allocate_new_blocks(self, request_id: str, FullAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, MambaSpec: MambaManager, + ShortConvSpec: MambaManager, } diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 43456a987de..dced01489ab 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -183,6 +183,14 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: return self.page_size_bytes +@dataclass +class ShortConvSpec(MambaSpec): + + @property + def type_id(self) -> str: + return f"short_conv_{self.shapes}_{self.dtype}" + + @dataclass class KVCacheTensor: """ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9cda4dbb961..32eba6c34fb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -51,7 +51,7 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheSpec, MambaSpec, - SlidingWindowSpec) + SlidingWindowSpec, ShortConvSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) from vllm.v1.pool.metadata import PoolingMetadata @@ -2325,7 +2325,8 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: raise NotImplementedError( "Non-Attention backend is not supported by V1 " "GPUModelRunner.") - elif isinstance(kv_cache_spec, MambaSpec): + elif isinstance(kv_cache_spec, MambaSpec) or isinstance(kv_cache_spec, ShortConvSpec): + # ShortConv uses many of the same attributes, excluding chunking logic from Mamba2 attn_backend_i = Mamba2AttentionBackend else: raise ValueError( @@ -2460,8 +2461,8 @@ def _reshape_kv_cache_tensors( 
kv_caches[layer_name] = kv_cache_raw_tensors[ layer_name].view(dtype).view(kv_cache_shape).permute( *inv_order) - elif isinstance(kv_cache_spec, MambaSpec): - has_mamba = True + elif isinstance(kv_cache_spec, (MambaSpec, ShortConvSpec)): + has_fixed_state_layers = True raw_tensor = kv_cache_raw_tensors[layer_name] dtype = kv_cache_spec.dtype num_element_per_page = (kv_cache_spec.page_size_bytes // @@ -2485,7 +2486,7 @@ def _reshape_kv_cache_tensors( else: raise NotImplementedError - if has_attn and has_mamba: + if has_attn and has_fixed_state_layers: self._verify_hybrid_attention_mamba_layout(kv_cache_config, kv_cache_raw_tensors) @@ -2629,7 +2630,12 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaMixer2) - if len(mamba_layers) > 0: + short_conv_layers = get_layers_from_vllm_config(self.vllm_config, + ShortConv) + + has_mamba = len(mamba_layers) > 0 + has_conv_layer = len(short_conv_layers) > 0 + if has_mamba: if self.vllm_config.speculative_config is not None: raise NotImplementedError( "Mamba with speculative decoding is not supported yet.") @@ -2641,9 +2647,9 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: "Prefix caching is not supported for Mamba yet.") max_model_len = self.vllm_config.model_config.max_model_len - page_size_padded = self._maybe_pad_mamba_page_size( - attn_layers, mamba_layers, kv_cache_spec, max_model_len, - block_size) + page_size_padded = self._maybe_pad_fixed_state_page_size( + attn_layers, mamba_layers, kv_cache_spec, MambaSpec, + max_model_len, block_size) # Set block_size to max_model_len, so that mamba model will always # have only one block in the KV cache. @@ -2654,31 +2660,58 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: block_size=max_model_len, page_size_padded=page_size_padded) + elif has_conv_layer: + if self.vllm_config.speculative_config is not None: + raise NotImplementedError( + "ShortConv's with speculative decoding is not supported yet.") + if not self.vllm_config.model_config.enforce_eager: + raise NotImplementedError( + "ShortConv's with cuda graph is not supported yet.") + if self.vllm_config.cache_config.enable_prefix_caching: + raise NotImplementedError( + "Prefix caching is not supported for ShortConv's yet.") + max_model_len = self.vllm_config.model_config.max_model_len + + page_size_padded = self._maybe_pad_fixed_state_page_size( + attn_layers, short_conv_layers, kv_cache_spec, ShortConvSpec, + max_model_len, block_size) + + # Set block_size to max_model_len, so that mamba model will always + # have only one block in the KV cache. + for layer_name, short_conv_module in short_conv_layers.items(): + kv_cache_spec[layer_name] = ShortConvSpec( + shapes=short_conv_module.get_state_shape(), + dtype=self.kv_cache_dtype, + block_size=max_model_len, + page_size_padded=page_size_padded) + return kv_cache_spec - def _maybe_pad_mamba_page_size( + def _maybe_pad_fixed_state_page_size( self, attn_layers: dict[str, Attention], - mamba_layers: dict[str, MambaMixer2], + state_layers: dict[str, Union[MambaMixer2, ShortConv]], kv_cache_spec: dict[str, KVCacheSpec], + state_spec: type[MambaSpec | ShortConvSpec], max_model_len: int, block_size: int, ) -> Optional[int]: """ Ensure that page size of attention KV cache groups is greater than or - equal to the mamba KV cache groups. If not, we suggest to the user + equal to the Mamba/ShortConv KV cache groups. If not, we suggest to the user how to set the attention block size to ensure that it is. 
- If the attention page size is strictly greater than the mamba page size, - we pad the mamba page size to make them equal. + If the attention page size is strictly greater than the fixed state page size, + we pad the fixed state page size to make them equal. Args: attn_layers: Attention layers - mamba_layers: Mamba layers + state_layers: Mamba or ShortConv layers kv_cache_spec: KV cache spec (populated with attention layers) + state_spec: MambaSpec or ShortConvSpec Returns: - Optional[int]: Mamba page size with padding (None if no padding). + Optional[int]: State page size with padding (None if no padding). """ if len(attn_layers) == 0: @@ -2686,23 +2719,23 @@ def _maybe_pad_mamba_page_size( attn_layer_name = next(iter(attn_layers)) attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes - mamba_layer_name = next(iter(mamba_layers)) - mamba_page_size = MambaSpec( - shapes=mamba_layers[mamba_layer_name].get_state_shape(), + state_layer_name = next(iter(state_layers)) + state_page_size = state_spec( + shapes=state_layers[state_layer_name].get_state_shape(), dtype=self.kv_cache_dtype, block_size=max_model_len).page_size_bytes - if attn_page_size < mamba_page_size: + if attn_page_size < state_page_size: # attention page size (for 16 tokens) attn_page_size_16 = 16 * attn_page_size // block_size # some attention backends (e.g. FA) only support setting # block size to multiple of 16, so let's suggest a value # that would work (note: FA is currently not compatible - # with mamba layers, use FlashInfer instead). - suggest_attn_block_size = 16 * cdiv(mamba_page_size, + # with mamba or short-conv layers, use FlashInfer instead). + suggest_attn_block_size = 16 * cdiv(state_page_size, attn_page_size_16) raise ValueError( "Attention block size should be increased to at least " f"{suggest_attn_block_size} in order to match " - "the mamba page size") + "the state page size") return attn_page_size From d17c95f70b95e44d2fa85f8aa750aef3b3622d1d Mon Sep 17 00:00:00 2001 From: Paul Pak Date: Fri, 11 Jul 2025 04:59:38 +0000 Subject: [PATCH 5/8] [configs] need to detect full_attention key in layer_types for transformers >= 4.54.0.dev0 --- vllm/config.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index a4fd25be5fa..dfdd83080d3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1305,7 +1305,7 @@ def get_num_layers_by_block_type( layers_block_type_value = getattr(self.hf_config, "layers_block_type", None) - # Hybrid models in `transformers` >= 4.54.0.dev0 + # NOTE(pp): Attribute for hybrid models in `transformers` >= 4.54.0.dev0 if layers_block_type_value is None: layers_block_type_value = getattr(self.hf_text_config, "layer_types", None) @@ -1319,8 +1319,10 @@ def get_num_layers_by_block_type( for t in layers_block_type_value[start:end]) else: return self.get_num_layers(parallel_config) - return sum(t == block_type.value - for t in layers_block_type_value[start:end]) + return sum( + 1 for t in layers_block_type_value[start:end] + if (t == "full_attention" and "attention" == block_type.value) or (t == block_type.value) + ) # Hybrid model Minimax attn_type_list = getattr(self.hf_config, "attn_type_list", None) From 1bc8835c7d0a1073fe9fedba6e2ec0fea100d4fe Mon Sep 17 00:00:00 2001 From: Paul Pak Date: Fri, 11 Jul 2025 05:11:57 +0000 Subject: [PATCH 6/8] [layers][conv] update ShortConv layer to be compatible with triton causal_conv1d kernel --- vllm/model_executor/layers/conv.py | 15 +++++++++++---- vllm/model_executor/models/conv_cache.py | 8 
+++++--- vllm/v1/kv_cache_interface.py | 23 ++++++++++++++++++++++- vllm/v1/worker/gpu_model_runner.py | 1 + 4 files changed, 39 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/conv.py b/vllm/model_executor/layers/conv.py index 7eba01f1d36..2c32a0f0c96 100644 --- a/vllm/model_executor/layers/conv.py +++ b/vllm/model_executor/layers/conv.py @@ -12,6 +12,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata, + update_metadata) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.attention.backends.abstract import AttentionMetadata @@ -91,9 +93,10 @@ def forward_cuda( if attn_metadata is not None: assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] + conv_metadata = attn_metadata assert isinstance(attn_metadata, Mamba2AttentionMetadata) self_kv_cache = self.kv_cache[forward_context.virtual_engine] - conv_state = self_kv_cache[0] + conv_state = self_kv_cache[0].transpose(-1, -2) state_indices_tensor = attn_metadata.state_indices_tensor has_initial_states_p = attn_metadata.has_initial_states # prep_initial_states = attn_metadata.prep_initial_states @@ -188,9 +191,12 @@ def forward_cuda( conv_output_list = [] if has_prefill: - Bx_p = (B_p * x_p).contiguous() + Bx_p = (B_p * x_p).transpose(0, 1) + if conv_metadata.cu_seqlen is None: + conv_metadata = update_metadata( + Bx_p, attn_metadata.query_start_loc, conv_metadata) Bx = causal_conv1d_fn( - Bx_p.transpose(0, 1), + Bx_p, conv_weights, self.conv.bias, activation=None, @@ -228,8 +234,9 @@ def forward_cuda( def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: world_size = get_tensor_model_parallel_world_size() + # contiguous along 'dim' axis conv_state_shape = ( - divide(self.conv_dim, world_size), self.L_cache - 1, + divide(self.conv_dim, world_size), ) return (conv_state_shape,) \ No newline at end of file diff --git a/vllm/model_executor/models/conv_cache.py b/vllm/model_executor/models/conv_cache.py index 0f7f62443d8..f6739c59309 100644 --- a/vllm/model_executor/models/conv_cache.py +++ b/vllm/model_executor/models/conv_cache.py @@ -36,11 +36,13 @@ def __init__( # Initialize parent class super().__init__(max_batch_size) + # Note(pp): this is for the V0 runner. + # assume conv_state = (dim, state_len). + assert conv_state_shape[0] > conv_state_shape[1] conv_state = torch.empty(size=(num_conv_layers, max_batch_size) + - conv_state_shape, + (conv_state_shape[1], conv_state_shape[0]), dtype=dtype, - device="cuda") - + device="cuda").transpose(-1, -2) self._lfm2_cache = conv_state @property diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index dced01489ab..fe673646965 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -184,12 +184,33 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: @dataclass -class ShortConvSpec(MambaSpec): +class ShortConvSpec(KVCacheSpec): + """Nearly identical to MambaSpec above. """ + shapes: tuple[tuple[int, ...], ...] 
+ dtype: torch.dtype + page_size_padded: Optional[int] = None + + def __post_init__(self): + self.num_elements = sum(prod(shape) for shape in self.shapes) @property def type_id(self) -> str: return f"short_conv_{self.shapes}_{self.dtype}" + @property + def page_size_bytes(self) -> int: + page_size = self.num_elements * get_dtype_size(self.dtype) + if self.page_size_padded is not None: + assert self.page_size_padded >= page_size + return self.page_size_padded + return page_size + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + # We allocate 1 block for each request now, so max_memory_usage_bytes is + # the same as page_size_bytes. + # Need to update this when supporting prefix caching. + return self.page_size_bytes + @dataclass class KVCacheTensor: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 32eba6c34fb..dc2ad228f30 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -31,6 +31,7 @@ set_forward_context) from vllm.logger import init_logger from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from vllm.model_executor.layers.conv import ShortConv from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader from vllm.model_executor.models.interfaces import (has_step_pooler, From e550362bcd50c5831f46e995d1d5a048d28cc918 Mon Sep 17 00:00:00 2001 From: Paul Pak Date: Fri, 11 Jul 2025 05:20:04 +0000 Subject: [PATCH 7/8] [transformers][ovis] tmp: AIMv2Config doesn't need to be registered on 4.54.0.dev0 --- vllm/transformers_utils/configs/ovis.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/transformers_utils/configs/ovis.py b/vllm/transformers_utils/configs/ovis.py index c2728f0ed64..5d6f69fcd2c 100644 --- a/vllm/transformers_utils/configs/ovis.py +++ b/vllm/transformers_utils/configs/ovis.py @@ -73,7 +73,10 @@ def __init__( IMAGE_ATOM_ID = -300 IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305] -AutoConfig.register("aimv2", AIMv2Config) +try: + AutoConfig.register("aimv2", AIMv2Config) +except Exception as e: + pass # ---------------------------------------------------------------------- From 05af65a78bc9380d967c84f6580071d06f7b03ad Mon Sep 17 00:00:00 2001 From: Paul Pak Date: Fri, 11 Jul 2025 16:17:00 +0000 Subject: [PATCH 8/8] [models][lfm2] LFM2->Lfm2 to match config --- vllm/model_executor/models/conv_cache.py | 2 +- vllm/model_executor/models/lfm2.py | 38 ++++++++++++------------ vllm/model_executor/models/registry.py | 2 +- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/models/conv_cache.py b/vllm/model_executor/models/conv_cache.py index f6739c59309..b4a70b1bbad 100644 --- a/vllm/model_executor/models/conv_cache.py +++ b/vllm/model_executor/models/conv_cache.py @@ -65,7 +65,7 @@ def current_run_tensors(self, **kwargs) -> ConvCacheParams: def get_seqlen_agnostic_capture_inputs(self, batch_size: int): """ Provide the CUDA graph capture runs with a buffer in adjusted size. - The buffer is used to maintain the LFM2 Cache during the CUDA graph + The buffer is used to maintain the Lfm2 Cache during the CUDA graph replay runs. 
""" return self._lfm2_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size, diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index 709b4c2f21d..fda1f9b799f 100755 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from transformers import LFM2Config +from transformers import Lfm2Config from vllm import envs from vllm.attention import Attention @@ -44,7 +44,7 @@ from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata -class LFM2MLP(nn.Module): +class Lfm2MLP(nn.Module): def __init__( self, @@ -87,11 +87,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class LFM2Attention(nn.Module): +class Lfm2Attention(nn.Module): def __init__( self, - config: LFM2Config, + config: Lfm2Config, layer_idx: int, hidden_size: int, num_heads: int, @@ -184,11 +184,11 @@ def forward( return output -class LFM2AttentionDecoderLayer(nn.Module): +class Lfm2AttentionDecoderLayer(nn.Module): def __init__( self, - config: LFM2Config, + config: Lfm2Config, layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, @@ -208,7 +208,7 @@ def __init__( max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - self.self_attn = LFM2Attention( + self.self_attn = Lfm2Attention( config=config, layer_idx=layer_idx, hidden_size=config.hidden_size, @@ -222,7 +222,7 @@ def __init__( prefix=f"{prefix}.self_attn", ) - self.feed_forward = LFM2MLP( + self.feed_forward = Lfm2MLP( dim=config.block_dim, ff_dim=config.block_ff_dim, multiple_of=config.block_multiple_of, @@ -253,11 +253,11 @@ def forward( return self.feed_forward(hidden_states), residual -class LFM2ShortConvDecoderLayer(nn.Module): +class Lfm2ShortConvDecoderLayer(nn.Module): def __init__( self, - config: LFM2Config, + config: Lfm2Config, layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, @@ -272,7 +272,7 @@ def __init__( prefix=f"{prefix}.conv", ) - self.feed_forward = LFM2MLP( + self.feed_forward = Lfm2MLP( dim=config.block_dim, ff_dim=config.block_ff_dim, multiple_of=config.block_multiple_of, @@ -308,7 +308,7 @@ def forward( return hidden_states, residual -class LFM2Model(nn.Module): +class Lfm2Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -332,8 +332,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def get_layer(prefix: str): layer_idx = int(prefix.rsplit(".", 1)[1]) is_attn = self.config.layer_types[layer_idx] == "full_attention" - layer_class = (LFM2AttentionDecoderLayer - if is_attn else LFM2ShortConvDecoderLayer) + layer_class = (Lfm2AttentionDecoderLayer + if is_attn else Lfm2ShortConvDecoderLayer) return layer_class( config, layer_idx, @@ -391,9 +391,9 @@ def forward( state_cache_index = 0 for layer in self.layers[self.start_layer:self.end_layer]: layer_conv_cache_params = None - if isinstance(layer, LFM2AttentionDecoderLayer): + if isinstance(layer, Lfm2AttentionDecoderLayer): kv_cache_index += 1 - if isinstance(layer, LFM2ShortConvDecoderLayer): + if isinstance(layer, Lfm2ShortConvDecoderLayer): current_state_layer = state_cache_index layer_conv_cache_params = conv_cache_params.at_layer_idx( current_state_layer) if conv_cache_params else None @@ -449,7 +449,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params -class LFM2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, 
+class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant): packed_modules_mapping = { "qkv_proj": [ @@ -477,7 +477,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config assert (not cache_config.enable_prefix_caching - ), "LFM2 currently does not support prefix caching" + ), "Lfm2 currently does not support prefix caching" super().__init__() self.config = config @@ -485,7 +485,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.scheduler_config = scheduler_config self.model_config = vllm_config.model_config - self.model = LFM2Model(vllm_config=vllm_config, + self.model = Lfm2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 5e78319c15c..cb046744b2d 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -82,7 +82,7 @@ "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), - "LFM2ForCausalLM": ("lfm2", "LFM2ForCausalLM"), + "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"), # For decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
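
Usage sketch (not part of the patches above, added for illustration): assuming the series is applied and an Lfm2 checkpoint whose config maps to the newly registered Lfm2ForCausalLM is available — the checkpoint name below is a placeholder — the architecture can be exercised through the standard vLLM entry point. The enforce_eager and prefix-caching settings mirror the restrictions asserted in lfm2.py and gpu_model_runner.py in this series.

    # Hypothetical smoke test; checkpoint name is a placeholder, any HF repo
    # whose config routes to Lfm2ForCausalLM should work once these patches land.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="LiquidAI/LFM2-1.2B",   # placeholder Lfm2 checkpoint
        enforce_eager=True,           # ShortConv layers do not support CUDA graphs yet
        enable_prefix_caching=False,  # prefix caching is not supported for Lfm2
    )
    out = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
    print(out[0].outputs[0].text)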