
LFM2 #20797

Draft: wants to merge 8 commits into main
7 changes: 7 additions & 0 deletions CMakeLists.txt
@@ -79,6 +79,13 @@ endif()
#
find_package(Torch REQUIRED)

#
# Ignore nvToolsExt for cuda-12.9
#
if (NOT TARGET CUDA::nvToolsExt)
add_library(CUDA::nvToolsExt INTERFACE IMPORTED)
endif()

# Supported NVIDIA architectures.
# This check must happen after find_package(Torch) because that's when CMAKE_CUDA_COMPILER_VERSION gets defined
if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND
12 changes: 10 additions & 2 deletions vllm/config.py
@@ -1304,6 +1304,12 @@ def get_num_layers_by_block_type(
# Hybrid model Jamba
layers_block_type_value = getattr(self.hf_config,
"layers_block_type", None)

# NOTE(pp): Attribute for hybrid models in `transformers` >= 4.54.0.dev0
if layers_block_type_value is None:
layers_block_type_value = getattr(self.hf_text_config,
"layer_types", None)

if layers_block_type_value is not None:
if hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type
@@ -1313,8 +1319,10 @@
for t in layers_block_type_value[start:end])
else:
return self.get_num_layers(parallel_config)
return sum(t == block_type.value
for t in layers_block_type_value[start:end])
return sum(
1 for t in layers_block_type_value[start:end]
if (t == "full_attention" and "attention" == block_type.value) or (t == block_type.value)
)
Comment on lines +1322 to +1325
medium

This conditional logic for counting layer types is a bit complex and hard to read. It can be simplified by handling the special case for attention layers separately, which would make the code more readable and easier to maintain.

Suggested change
return sum(
1 for t in layers_block_type_value[start:end]
if (t == "full_attention" and "attention" == block_type.value) or (t == block_type.value)
)
if block_type == LayerBlockType.attention:
return sum(t in ("attention", "full_attention")
for t in layers_block_type_value[start:end])
return sum(t == block_type.value
for t in layers_block_type_value[start:end])
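For illustration, here is a small standalone sketch (not vLLM code; it assumes a hypothetical layer_types list and a minimal stand-in for the LayerBlockType enum) showing that the original expression and the suggested simplification count the same layers:

from enum import Enum

class LayerBlockType(Enum):  # stand-in for the vLLM enum, illustration only
    attention = "attention"
    conv = "conv"

layer_types = ["full_attention", "conv", "full_attention", "conv"]  # hypothetical
start, end = 0, len(layer_types)

for block_type in LayerBlockType:
    original = sum(
        1 for t in layer_types[start:end]
        if (t == "full_attention" and "attention" == block_type.value)
        or (t == block_type.value))
    if block_type == LayerBlockType.attention:
        simplified = sum(t in ("attention", "full_attention")
                         for t in layer_types[start:end])
    else:
        simplified = sum(t == block_type.value
                         for t in layer_types[start:end])
    assert original == simplified  # both count 2 attention and 2 conv layers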


# Hybrid model Minimax
attn_type_list = getattr(self.hf_config, "attn_type_list", None)
Expand Down
242 changes: 242 additions & 0 deletions vllm/model_executor/layers/conv.py
@@ -0,0 +1,242 @@

from typing import Any, Optional

import torch
import torch.nn as nn

from vllm import envs
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.distributed import divide, get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
update_metadata)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.model_executor.models.conv_cache import ConvCacheParams
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata


@CustomOp.register("short_conv")
class ShortConv(CustomOp):

def __init__(self, config, dim: int, layer_idx: int, prefix: str = ""):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.conv_dim = dim
self.L_cache = config.conv_L_cache
self.bias = config.conv_bias

self.conv = ColumnParallelLinear(
input_size=self.L_cache,
output_size=dim,
bias=self.bias,
prefix=f"{prefix}.conv1d",
)
# unsqueeze to fit the conv1d weight shape into the linear weight shape.
# Can't do this in `weight_loader` since it already exists in
# `ColumnParallelLinear` and `set_weight_attrs`
# doesn't allow overriding it.
self.conv.weight.data = self.conv.weight.data.unsqueeze(1)

self.in_proj = MergedColumnParallelLinear(
input_size=dim,
output_sizes=[dim] * 3,
bias=self.bias,
prefix=f"{prefix}.in_proj",
)
self.out_proj = RowParallelLinear(
input_size=dim,
output_size=dim,
bias=self.bias,
prefix=f"{prefix}.out_proj",
)

if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
# The inner tuple is (conv_state,)
self.kv_cache = [(torch.tensor([]))]

# For compatibility with MambaSpec utils
self.chunk_size = 1
self.prefix = prefix

def forward_native(self, hidden_states: torch.Tensor,
conv_cache_params: ConvCacheParams) -> torch.Tensor:
pass

def forward_cuda(
self,
hidden_states: torch.Tensor,
conv_cache_params: ConvCacheParams,
conv_metadata: Mamba2Metadata,
) -> torch.Tensor:
forward_context = get_forward_context()
# conv_metadata contains metadata necessary for the causal conv1d
# kernels to operate in continuous batching and in chunked prefill
# modes; it is computed at the top-level model forward since it stays
# the same and is reused for all conv layers in the same iteration
attn_metadata: Optional[AttentionMetadata] = forward_context.attn_metadata
if envs.VLLM_USE_V1:
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
conv_metadata = attn_metadata
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states
# prep_initial_states = attn_metadata.prep_initial_states
# chunk_size = attn_metadata.chunk_size
# seq_idx_p = attn_metadata.seq_idx
# chunk_indices_p = attn_metadata.chunk_indices
# chunk_offsets_p = attn_metadata.chunk_offsets
else:
conv_state = conv_cache_params.conv_state
state_indices_tensor = conv_cache_params.state_indices_tensor
has_initial_states_p = conv_metadata.has_initial_states
# prep_initial_states = conv_metadata.prep_initial_states
# chunk_size = conv_metadata.chunk_size
# seq_idx_p = conv_metadata.seq_idx
# chunk_indices_p = conv_metadata.chunk_indices
# chunk_offsets_p = conv_metadata.chunk_offsets

BCx, _ = self.in_proj(hidden_states)

B, C, x = BCx.chunk(3, dim=-1)

conv_weights = self.conv.weight.view(self.conv.weight.size(0),
self.conv.weight.size(2))

if envs.VLLM_USE_V1 and attn_metadata is None:
# V1 profile run
Bx = (B * x).contiguous()
hidden_states = C * Bx
contextualized_states, _ = self.out_proj(hidden_states)
return contextualized_states

num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)

medium

The variable name num_decodes is a bit misleading, as it stores the number of decode tokens, not requests. In the decode phase this is usually one token per request, but for clarity and consistency with num_prefill_tokens it would be better to name it num_decode_tokens. This would improve readability for future maintainers.

Suggested change
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_decode_tokens = attn_metadata.num_decode_tokens # token count (=request)

num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0

# NOTE: V0 puts prefill before decode, V1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Split along token dimension
if envs.VLLM_USE_V1:
B_d, B_p = torch.split(
B,
[num_decodes, num_prefill_tokens],
dim=0,
)
C_d, C_p = torch.split(
C,
[num_decodes, num_prefill_tokens],
dim=0,
)
x_d, x_p = torch.split(
x,
[num_decodes, num_prefill_tokens],
dim=0,
)
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)
else:
B_p, B_d = torch.split(
B,
[num_prefill_tokens, num_decodes],
dim=0,
)
C_p, C_d = torch.split(
C,
[num_prefill_tokens, num_decodes],
dim=0,
)
x_p, x_d = torch.split(
x,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
conv_cache_params.state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
if has_prefill else None)

conv_output_list = []

if has_prefill:
Bx_p = (B_p * x_p).transpose(0, 1)
if conv_metadata.cu_seqlen is None:
conv_metadata = update_metadata(
Bx_p, attn_metadata.query_start_loc, conv_metadata)
Bx = causal_conv1d_fn(
Bx_p,
conv_weights,
self.conv.bias,
activation=None,
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]

C_p = C_p.view(1, num_prefill_tokens, -1)
y = C_p * Bx
conv_output_list.append(y.view(num_prefill_tokens, -1))

if has_decode:
Bx_d = (B_d * x_d).contiguous()
Bx = causal_conv1d_update(
Bx_d,
conv_state,
conv_weights,
self.conv.bias,
activation=None,
conv_state_indices=state_indices_tensor_d)
C_d = C_d.view(num_decodes, -1)
y = C_d * Bx
conv_output_list.append(y.view(num_decodes, -1))

# Merge prefill and decode outputs before passing to gated MLP
hidden_states = torch.vstack(conv_output_list)

# Final linear projection
contextualized_states, _ = self.out_proj(hidden_states)

return contextualized_states


def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:

medium

The return type hint tuple[tuple[int, ...], tuple[int, ...]] indicates a tuple containing two tuples of integers. However, the function returns a tuple containing only one tuple: (conv_state_shape,).

To match the implementation and the expected usage with MambaSpec-like structures, the type hint should be tuple[tuple[int, ...], ...], which correctly represents a tuple containing one or more tuples of integers.

Suggested change
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
def get_state_shape(self) -> tuple[tuple[int, ...], ...]:
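As a quick standalone check (hypothetical shape values only), the variadic hint accepts the single-element tuple that get_state_shape() actually returns, since ShortConv keeps only a conv state and no SSM state:

conv_state_shape: tuple[int, ...] = (2, 1024)  # e.g. (L_cache - 1, dim // world_size), hypothetical sizes
shapes: tuple[tuple[int, ...], ...] = (conv_state_shape,)  # one-element tuple type-checks
assert len(shapes) == 1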

world_size = get_tensor_model_parallel_world_size()
# contiguous along 'dim' axis
conv_state_shape = (
self.L_cache - 1,
divide(self.conv_dim, world_size),
)
return (conv_state_shape,)
73 changes: 73 additions & 0 deletions vllm/model_executor/models/conv_cache.py
@@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass

import torch

from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache


@dataclass
class ConvCacheParams:
conv_state: torch.Tensor = torch.Tensor()
state_indices_tensor: torch.Tensor = torch.Tensor()

def at_layer_idx(self, layer_idx):
return ConvCacheParams(self.conv_state[layer_idx],
self.state_indices_tensor)


class ConvCacheManager(ConstantSizeCache):

def __init__(
self,
vllm_config: VllmConfig,
dtype: torch.dtype,
num_conv_layers: int,
conv_state_shape: tuple[int, int]):

max_batch_size = vllm_config.scheduler_config.max_num_seqs
if not vllm_config.model_config.enforce_eager:
max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size)

# Initialize parent class
super().__init__(max_batch_size)

# Note(pp): this is for the V0 runner.
# assume conv_state = (dim, state_len).
assert conv_state_shape[0] > conv_state_shape[1]
conv_state = torch.empty(size=(num_conv_layers, max_batch_size) +
(conv_state_shape[1], conv_state_shape[0]),
dtype=dtype,
device="cuda").transpose(-1, -2)
self._lfm2_cache = conv_state

@property
def cache(self):
return self._lfm2_cache

def _copy_cache(self, from_index: int, to_index: int):
for cache_t in self.cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)

def current_run_tensors(self, **kwargs) -> ConvCacheParams:
"""
Return the tensors for the current run's conv state.
"""
cache_tensor, state_indices_tensor = super().current_run_tensors(
**kwargs)
return ConvCacheParams(cache_tensor, state_indices_tensor)

def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer of adjusted size.
The buffer is used to maintain the LFM2 conv cache during the CUDA
graph replay runs.
"""
return self._lfm2_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
device="cuda")