diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 5dd3aa2973cd..d8420d2899ad 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -116,6 +116,7 @@
     "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
     "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
     "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
+    "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
     "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
     "GlmForCausalLM": ("glm", "GlmForCausalLM"),
diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
index ba92eef12707..52853274e560 100644
--- a/vllm/model_executor/models/roberta.py
+++ b/vllm/model_executor/models/roberta.py
@@ -1,20 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0

 import itertools
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple

 import torch
 from torch import nn
 from transformers import RobertaConfig

-from vllm.config import VllmConfig
-from vllm.model_executor.layers.pooler import CrossEncodingPooler
+from vllm.config import PoolerConfig, VllmConfig
+from vllm.model_executor.layers.pooler import (AllPool, CrossEncodingPooler,
+                                               Pooler, PoolingType)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
+from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.model_executor.pooling_metadata import PoolingMetadata
+from vllm.pooling_params import PoolingParams
 from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.transformers_utils.config import (
     get_cross_encoder_activation_function)
@@ -191,6 +194,160 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         assert len(loaded), "Unable to load RobertaEmbeddingModel"


+def filter_secondary_weights(
+    all_weights: Iterable[Tuple[str, torch.Tensor]],
+    secondary_weights: list[str],
+) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
+                                                              torch.Tensor]]]:
+    all_weights1, all_weights2 = itertools.tee(all_weights)
+
+    def filtered(n):
+        return any(n.startswith(f) for f in secondary_weights)
+
+    return ((n, w) for n, w in all_weights1 if filtered(n)), \
+        ((n, w) for n, w in all_weights2 if not filtered(n))
+
+
+class M3SparsePooler(AllPool):
+    """A pooler that implements M3 sparse pooling.
+
+    This layer does the following:
+    1. By default, returns dense embeddings.
+    2. If the pooling params "additional_data" contain
+       "sparse_embeddings", returns sparse embeddings.
+
+    Attributes:
+        dense_pooler: The default pooler.
+        sparse_linear: The linear module applied to the
+            logits to obtain the token weights.
+        bos_token_id and eos_token_id: The special tokens
+            inserted by the tokenizer. These are removed for
+            sparse embeddings.
+    """
+
+    def __init__(self, dense_pooler: Pooler, sparse_linear: nn.Module,
+                 bos_token_id: int, eos_token_id: int) -> None:
+        super().__init__(normalize=False, softmax=False)
+        self.dense_pooler = dense_pooler
+        self.sparse_linear = sparse_linear
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+
+        seq_ids = []
+        is_sparse = []
+
+        for i, (group, pool_param) in enumerate(pooling_metadata.seq_groups):
+            seq_ids.append(group[0])
+            ad = pool_param.additional_data
+            is_sparse.append(ad is not None
+                             and ad.get("sparse_embeddings", False))
+
+        if not any(is_sparse):
+            return self.dense_pooler(hidden_states, pooling_metadata)
+        else:
+
+            split_hidden_states = self.extract_states(hidden_states,
+                                                      pooling_metadata)
+            dense_hidden_states = []
+
+            seq_groups: List[Tuple[List[int], PoolingParams]] = []
+            prompt_lens: List[int] = []
+
+            for i, t in enumerate(split_hidden_states):
+                if not is_sparse[i]:
+                    dense_hidden_states.append(t)
+                    seq_groups.append(pooling_metadata.seq_groups[i])
+                    prompt_lens.append(pooling_metadata.prompt_lens[i])
+
+            dense_output = []
+
+            if dense_hidden_states:
+                dense_hidden_states = torch.cat(dense_hidden_states)
+                dense_metadata = PoolingMetadata(
+                    seq_groups=seq_groups,
+                    seq_data=pooling_metadata.seq_data,
+                    prompt_lens=prompt_lens)
+                dense_output = self.dense_pooler(dense_hidden_states,
+                                                 dense_metadata).outputs
+
+            dense_it = iter(dense_output)
+
+            pooled_outputs = []
+
+            for i, hidden_states in enumerate(split_hidden_states):
+                if is_sparse[i]:
+                    pooled_data = torch.squeeze(
+                        torch.relu(self.sparse_linear(hidden_states)))
+                    token_ids = pooling_metadata.seq_data[
+                        seq_ids[i]].prompt_token_ids
+                    if token_ids[0] == self.bos_token_id:
+                        pooled_data = pooled_data[1:]
+                    if token_ids[-1] == self.eos_token_id:
+                        pooled_data = pooled_data[:-1]
+                    pooled_outputs.append(self.build_output(pooled_data))
+                else:
+                    pooled_outputs.append(next(dense_it))
+
+            return PoolerOutput(outputs=pooled_outputs)
+
+
+class BgeM3EmbeddingModel(RobertaEmbeddingModel):
+    """A model that extends RobertaEmbeddingModel with sparse embeddings.
+
+    This class supports loading an additional sparse_linear.pt file
+    to create sparse embeddings as described in https://arxiv.org/abs/2402.03216
+    """
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

+        self.hidden_size = vllm_config.model_config.hf_config.hidden_size
+
+        self.bos_token_id = vllm_config.model_config.hf_config.bos_token_id
+        self.eos_token_id = vllm_config.model_config.hf_config.eos_token_id
+
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        self.secondary_weight_prefix = "sparse_linear."
+
+        self.secondary_weights = [
+            DefaultModelLoader.Source(
+                model_or_path=vllm_config.model_config.model,
+                revision=None,
+                prefix=self.secondary_weight_prefix,
+                allow_patterns_overrides=["sparse_linear.pt"])
+        ]
+
+    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
+        dense_pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.CLS,
+            normalize=True,
+            softmax=False)
+        self.sparse_linear = nn.Linear(self.hidden_size, 1)
+        return M3SparsePooler(dense_pooler, self.sparse_linear,
+                              self.bos_token_id, self.eos_token_id)
+
+    def load_weights(self, all_weights: Iterable[Tuple[str, torch.Tensor]]):
+        secondary, weights = filter_secondary_weights(
+            all_weights, [self.secondary_weight_prefix])
+
+        super().load_weights(weights)
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in secondary:
+            if name.startswith(self.secondary_weight_prefix):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
 class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
                                        SupportsV0Only):
     """A model that uses Roberta to provide embedding functionalities.
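Illustrative usage (not part of the diff): a minimal offline sketch of how the new sparse path could be exercised, assuming the existing LLM.encode API, that PoolingParams.additional_data is forwarded to the pooler as read above, and that the checkpoint is pointed at the newly registered BgeM3EmbeddingModel architecture (e.g. via hf_overrides, since the BAAI/bge-m3 config declares XLMRobertaModel). Model path and override are placeholders, not something this PR adds.

# Hypothetical usage sketch, not part of this PR.
from vllm import LLM, PoolingParams

# Assumption: hf_overrides selects the architecture registered in this PR.
llm = LLM(model="BAAI/bge-m3",
          task="embed",
          hf_overrides={"architectures": ["BgeM3EmbeddingModel"]})

# Default path: dense CLS embeddings from the wrapped dense pooler.
dense_out = llm.encode(["what is BGE M3?"])

# Sparse path: per-token weights (ReLU over sparse_linear logits, BOS/EOS
# stripped), selected through the additional_data flag M3SparsePooler checks.
sparse_out = llm.encode(
    ["what is BGE M3?"],
    pooling_params=PoolingParams(additional_data={"sparse_embeddings": True}))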