From d7da3136c3effe2bc7e536742a8bbf8c6767a9d3 Mon Sep 17 00:00:00 2001
From: Max de Bayser
Date: Sun, 9 Mar 2025 18:28:22 -0300
Subject: [PATCH 1/2] First working PoC for bge-m3 sparse embeddings

Here I'm loading the extra `sparse_linear.pt` file using the
secondary_weights loading introduced in the ultravox model when I
detect that the model name is `BAAI/bge-m3`. It's a bit ugly but I
don't know if there is a more generic way to do this.

Currently, since the only permissible pooling return type is
torch.tensor, I'm just returning the token weights tensor directly.
If the user wants to match tokens to the weights, they have to call
`tokenize` and remove the bos and eos tokens; then the indices of
both vectors should match.

To request sparse vectors the user has to pass
"additional_data": {"sparse_embeddings": true} in the request. This
means that all sequences in that request will be treated as sparse.
If the user wants to mix, separate calls have to be made for each
type of embedding. The FlagEmbedding API allows returning more than
one type of embedding at the same time, but currently, due to the
limitation of the pooling return type, we can only return a single
tensor per sequence.
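For illustration, a sparse request plus the token/weight matching
described above could look roughly like the sketch below. This is not
part of the patch and is untested; it assumes that
`PoolingParams.additional_data` is forwarded unchanged to the pooler
and that the pooled token weights are exposed as `outputs.data` (the
exact output attribute may differ between vLLM versions).

    from vllm import LLM, PoolingParams

    llm = LLM(model="BAAI/bge-m3", task="embed")
    prompt = "What is BGE-M3?"

    # Every sequence in this call is pooled as sparse token weights,
    # because additional_data applies to the whole request.
    (out, ) = llm.encode(
        prompt,
        pooling_params=PoolingParams(
            additional_data={"sparse_embeddings": True}))
    weights = out.outputs.data  # assumed attribute; one weight per token

    # Match tokens to weights: tokenize the prompt and drop the bos/eos
    # tokens, which the pooler has already stripped from the weights.
    tokenizer = llm.get_tokenizer()
    token_ids = tokenizer(prompt)["input_ids"][1:-1]
    token_weights = dict(
        zip(tokenizer.convert_ids_to_tokens(token_ids), weights.tolist()))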
Signed-off-by: Max de Bayser
---
 vllm/model_executor/models/roberta.py | 158 +++++++++++++++++++++++++-
 1 file changed, 154 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
index ba92eef12707..d7945dbad4a9 100644
--- a/vllm/model_executor/models/roberta.py
+++ b/vllm/model_executor/models/roberta.py
@@ -1,20 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0

 import itertools
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple

 import torch
 from torch import nn
 from transformers import RobertaConfig

-from vllm.config import VllmConfig
-from vllm.model_executor.layers.pooler import CrossEncodingPooler
+from vllm.config import PoolerConfig, VllmConfig
+from vllm.model_executor.layers.pooler import (AllPool, CrossEncodingPooler,
+                                               Pooler, PoolingType)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
+from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.model_executor.pooling_metadata import PoolingMetadata
+from vllm.pooling_params import PoolingParams
 from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.transformers_utils.config import (
     get_cross_encoder_activation_function)
@@ -160,6 +163,105 @@ def forward(self, features, **kwargs):
         return x


+def filter_secondary_weights(
+        all_weights: Iterable[Tuple[str, torch.Tensor]],
+        secondary_weights: list[str],
+) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
+                                                              torch.Tensor]]]:
+    all_weights1, all_weights2 = itertools.tee(all_weights)
+
+    def filtered(n):
+        return any(n.startswith(f) for f in secondary_weights)
+
+    return ((n, w) for n, w in all_weights1 if filtered(n)), \
+           ((n, w) for n, w in all_weights2 if not filtered(n))
+
+
+class M3SparsePooler(AllPool):
+    """A layer that pools specific information from hidden states.
+
+    This layer does the following:
+    1. Extracts specific tokens or aggregates data based on pooling method.
+    2. Normalizes output if specified.
+    3. Returns structured results as `PoolerOutput`.
+
+    Attributes:
+        pooling_type: The type of pooling to use.
+        normalize: Whether to normalize the pooled data.
+    """
+
+    def __init__(self, dense_pooler: Pooler, sparse_linear: nn.Module,
+                 bos_token_id: int, eos_token_id: int) -> None:
+        super().__init__(normalize=False, softmax=False)
+        self.dense_pooler = dense_pooler
+        self.sparse_linear = sparse_linear
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+
+        seq_ids = []
+        is_sparse = []
+
+        for i, (group, pool_param) in enumerate(pooling_metadata.seq_groups):
+            seq_ids.append(group[0])
+            ad = pool_param.additional_data
+            is_sparse.append(ad is not None
+                             and ad.get("sparse_embeddings", False))
+
+        if not any(is_sparse):
+            return self.dense_pooler(hidden_states, pooling_metadata)
+        else:
+
+            split_hidden_states = self.extract_states(hidden_states,
+                                                      pooling_metadata)
+            dense_hidden_states = []
+
+            seq_groups: List[Tuple[List[int], PoolingParams]] = []
+            prompt_lens: List[int] = []
+
+            for i, t in enumerate(split_hidden_states):
+                if not is_sparse[i]:
+                    dense_hidden_states.append(t)
+                    seq_groups.append(pooling_metadata.seq_groups[i])
+                    prompt_lens.append(pooling_metadata.prompt_lens[i])
+
+            dense_output = []
+
+            if dense_hidden_states:
+                dense_hidden_states = torch.cat(dense_hidden_states)
+                dense_metadata = PoolingMetadata(
+                    seq_groups=seq_groups,
+                    seq_data=pooling_metadata.seq_data,
+                    prompt_lens=prompt_lens)
+                dense_output = self.dense_pooler(dense_hidden_states,
+                                                 dense_metadata).outputs
+
+            dense_it = iter(dense_output)
+
+            pooled_outputs = []
+
+            for i, hidden_states in enumerate(split_hidden_states):
+                if is_sparse[i]:
+                    pooled_data = torch.squeeze(
+                        torch.relu(self.sparse_linear(hidden_states)))
+                    token_ids = pooling_metadata.seq_data[
+                        seq_ids[i]].prompt_token_ids
+                    if token_ids[0] == self.bos_token_id:
+                        pooled_data = pooled_data[1:]
+                    if token_ids[-1] == self.eos_token_id:
+                        pooled_data = pooled_data[:-1]
+                    pooled_outputs.append(self.build_output(pooled_data))
+                else:
+                    pooled_outputs.append(next(dense_it))
+
+        return PoolerOutput(outputs=pooled_outputs)
+
+
 class RobertaEmbeddingModel(BertEmbeddingModel):
     """A model that uses Roberta to provide embedding functionalities.

     This class encapsulates the BertModel and provides an interface for
     embedding operations and customized pooling functions.

     Attributes:
         model: An instance of BertModel used for forward operations.
@@ -171,6 +273,28 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
     _pooler: An instance of Pooler used for pooling operations.
     """

+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+
+        self.is_m3 = vllm_config.model_config.model == "BAAI/bge-m3"
+        self.hidden_size = vllm_config.model_config.hf_config.hidden_size
+
+        self.bos_token_id = vllm_config.model_config.hf_config.bos_token_id
+        self.eos_token_id = vllm_config.model_config.hf_config.eos_token_id
+
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+
+        self.secondary_weights = []
+        self.secondary_weight_names = []
+
+        if self.is_m3:
+            self.secondary_weight_names.append("sparse_linear.")
+            self.secondary_weights.append(
+                DefaultModelLoader.Source(
+                    model_or_path=vllm_config.model_config.model,
+                    revision=None,
+                    prefix="sparse_linear.",
+                    allow_patterns_overrides=["sparse_linear.pt"]))
+
     def _build_model(self,
                      vllm_config: VllmConfig,
                      prefix: str = "") -> BertModel:
@@ -178,7 +302,24 @@ def _build_model(self,
                          prefix=prefix,
                          embedding_class=RobertaEmbedding)

-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
+        dense_pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.CLS,
+            normalize=True,
+            softmax=False)
+        if not self.is_m3:
+            return dense_pooler
+        else:
+            self.sparse_linear = nn.Linear(self.hidden_size, 1)
+            return M3SparsePooler(dense_pooler, self.sparse_linear,
+                                  self.bos_token_id, self.eos_token_id)
+
+    def load_weights(self, all_weights: Iterable[Tuple[str, torch.Tensor]]):
+
+        secondary, weights = filter_secondary_weights(
+            all_weights, self.secondary_weight_names)
+
         weights = self.hf_to_vllm_mapper.apply(weights)
         # Separate weights in "roberta"-prefixed and all else (not in memory).
         # For use with models like FacebookAI/roberta-base.
@@ -190,6 +331,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             loaded = self.model.load_weights(task_weights)
         assert len(loaded), "Unable to load RobertaEmbeddingModel"

+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in secondary:
+            if name.startswith("sparse_linear"):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+

 class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
                                        SupportsV0Only):
""" + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + + self.is_m3 = vllm_config.model_config.model == "BAAI/bge-m3" + self.hidden_size = vllm_config.model_config.hf_config.hidden_size + + self.bos_token_id = vllm_config.model_config.hf_config.bos_token_id + self.eos_token_id = vllm_config.model_config.hf_config.eos_token_id + + super().__init__(vllm_config=vllm_config, prefix=prefix) + + self.secondary_weights = [] + self.secondary_weight_names = [] + + if self.is_m3: + self.secondary_weight_names.append("sparse_linear.") + self.secondary_weights.append( + DefaultModelLoader.Source( + model_or_path=vllm_config.model_config.model, + revision=None, + prefix="sparse_linear.", + allow_patterns_overrides=["sparse_linear.pt"])) + def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel: @@ -178,7 +302,24 @@ def _build_model(self, prefix=prefix, embedding_class=RobertaEmbedding) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: + dense_pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.CLS, + normalize=True, + softmax=False) + if not self.is_m3: + return dense_pooler + else: + self.sparse_linear = nn.Linear(self.hidden_size, 1) + return M3SparsePooler(dense_pooler, self.sparse_linear, + self.bos_token_id, self.eos_token_id) + + def load_weights(self, all_weights: Iterable[Tuple[str, torch.Tensor]]): + + secondary, weights = filter_secondary_weights( + all_weights, self.secondary_weight_names) + weights = self.hf_to_vllm_mapper.apply(weights) # Separate weights in "roberta"-prefixed and all else (not in memory). # For use with models like FacebookAI/roberta-base. @@ -190,6 +331,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loaded = self.model.load_weights(task_weights) assert len(loaded), "Unable to load RobertaEmbeddingModel" + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in secondary: + if name.startswith("sparse_linear"): + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsV0Only): From c73dbcdf59098f3604a65f679275c0e8732ab692 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Thu, 13 Mar 2025 11:28:26 -0300 Subject: [PATCH 2/2] Move the M3 support to a separate class This is cleaner and can be activated by the user by setting `--hf-overrides '{"architectures": ["BgeM3EmbeddingModel"]}'` Signed-off-by: Max de Bayser --- vllm/model_executor/models/registry.py | 1 + vllm/model_executor/models/roberta.py | 111 +++++++++++++------------ 2 files changed, 60 insertions(+), 52 deletions(-) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 5dd3aa2973cd..d8420d2899ad 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -116,6 +116,7 @@ "RobertaModel": ("roberta", "RobertaEmbeddingModel"), "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"), "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), + "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"), "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 
Signed-off-by: Max de Bayser
---
 vllm/model_executor/models/registry.py |   1 +
 vllm/model_executor/models/roberta.py  | 111 +++++++++++++------------
 2 files changed, 60 insertions(+), 52 deletions(-)

diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 5dd3aa2973cd..d8420d2899ad 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -116,6 +116,7 @@
     "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
     "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
     "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
+    "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
     "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
     "GlmForCausalLM": ("glm", "GlmForCausalLM"),
diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
index d7945dbad4a9..52853274e560 100644
--- a/vllm/model_executor/models/roberta.py
+++ b/vllm/model_executor/models/roberta.py
@@ -163,6 +163,37 @@ def forward(self, features, **kwargs):
         return x


+class RobertaEmbeddingModel(BertEmbeddingModel):
+    """A model that uses Roberta to provide embedding functionalities.
+
+    This class encapsulates the BertModel and provides an interface for
+    embedding operations and customized pooling functions.
+
+    Attributes:
+        model: An instance of BertModel used for forward operations.
+        _pooler: An instance of Pooler used for pooling operations.
+    """
+
+    def _build_model(self,
+                     vllm_config: VllmConfig,
+                     prefix: str = "") -> BertModel:
+        return BertModel(vllm_config=vllm_config,
+                         prefix=prefix,
+                         embedding_class=RobertaEmbedding)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        weights = self.hf_to_vllm_mapper.apply(weights)
+        # Separate weights in "roberta"-prefixed and all else (not in memory).
+        # For use with models like FacebookAI/roberta-base.
+        bert_weights, task_weights = roberta_task_weights_filter(weights)
+        loaded = self.model.load_weights(bert_weights)
+        if not len(loaded):
+            # Fix for models like `sentence-transformers/stsb-roberta-base-v2`
+            # which use the same architecture, but have no "roberta" prefix.
+            loaded = self.model.load_weights(task_weights)
+        assert len(loaded), "Unable to load RobertaEmbeddingModel"
+
+
 def filter_secondary_weights(
         all_weights: Iterable[Tuple[str, torch.Tensor]],
         secondary_weights: list[str],
@@ -178,16 +209,20 @@ def filtered(n):


 class M3SparsePooler(AllPool):
-    """A layer that pools specific information from hidden states.
+    """A pooler that implements M3 sparse pooling.

     This layer does the following:
-    1. Extracts specific tokens or aggregates data based on pooling method.
-    2. Normalizes output if specified.
-    3. Returns structured results as `PoolerOutput`.
+    1. By default returns dense embeddings.
+    2. If the pooling params "additional_data" contain
+       "sparse_embeddings", returns sparse embeddings.

     Attributes:
-        pooling_type: The type of pooling to use.
-        normalize: Whether to normalize the pooled data.
+        dense_pooler: The default pooler.
+        sparse_linear: The linear module applied to the
+           token hidden states to obtain the token weights.
+        bos_token_id and eos_token_id: The special tokens
+           inserted by the tokenizer. These are removed for
+           sparse embeddings.
     """

     def __init__(self, dense_pooler: Pooler, sparse_linear: nn.Module,
@@ -262,45 +297,30 @@ def forward(
         return PoolerOutput(outputs=pooled_outputs)


-class RobertaEmbeddingModel(BertEmbeddingModel):
-    """A model that uses Roberta to provide embedding functionalities.
-
-    This class encapsulates the BertModel and provides an interface for
-    embedding operations and customized pooling functions.
+class BgeM3EmbeddingModel(RobertaEmbeddingModel):
+    """A model that extends RobertaEmbeddingModel with sparse embeddings.

-    Attributes:
-        model: An instance of BertModel used for forward operations.
-        _pooler: An instance of Pooler used for pooling operations.
+    This class supports loading an additional sparse_linear.pt file
+    to create sparse embeddings as described in https://arxiv.org/abs/2402.03216
     """

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

-        self.is_m3 = vllm_config.model_config.model == "BAAI/bge-m3"
         self.hidden_size = vllm_config.model_config.hf_config.hidden_size

         self.bos_token_id = vllm_config.model_config.hf_config.bos_token_id
         self.eos_token_id = vllm_config.model_config.hf_config.eos_token_id

         super().__init__(vllm_config=vllm_config, prefix=prefix)
+        self.secondary_weight_prefix = "sparse_linear."

-        self.secondary_weights = []
-        self.secondary_weight_names = []
-
-        if self.is_m3:
-            self.secondary_weight_names.append("sparse_linear.")
-            self.secondary_weights.append(
-                DefaultModelLoader.Source(
-                    model_or_path=vllm_config.model_config.model,
-                    revision=None,
-                    prefix="sparse_linear.",
-                    allow_patterns_overrides=["sparse_linear.pt"]))
-
-    def _build_model(self,
-                     vllm_config: VllmConfig,
-                     prefix: str = "") -> BertModel:
-        return BertModel(vllm_config=vllm_config,
-                         prefix=prefix,
-                         embedding_class=RobertaEmbedding)
+        self.secondary_weights = [
+            DefaultModelLoader.Source(
+                model_or_path=vllm_config.model_config.model,
+                revision=None,
+                prefix=self.secondary_weight_prefix,
+                allow_patterns_overrides=["sparse_linear.pt"])
+        ]

     def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
         dense_pooler = Pooler.from_config_with_defaults(
@@ -308,33 +328,20 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
             pooling_type=PoolingType.CLS,
             normalize=True,
             softmax=False)
-        if not self.is_m3:
-            return dense_pooler
-        else:
-            self.sparse_linear = nn.Linear(self.hidden_size, 1)
-            return M3SparsePooler(dense_pooler, self.sparse_linear,
-                                  self.bos_token_id, self.eos_token_id)
+        self.sparse_linear = nn.Linear(self.hidden_size, 1)
+        return M3SparsePooler(dense_pooler, self.sparse_linear,
+                              self.bos_token_id, self.eos_token_id)

     def load_weights(self, all_weights: Iterable[Tuple[str, torch.Tensor]]):
-
         secondary, weights = filter_secondary_weights(
-            all_weights, self.secondary_weight_names)
+            all_weights, [self.secondary_weight_prefix])

-        weights = self.hf_to_vllm_mapper.apply(weights)
-        # Separate weights in "roberta"-prefixed and all else (not in memory).
-        # For use with models like FacebookAI/roberta-base.
-        bert_weights, task_weights = roberta_task_weights_filter(weights)
-        loaded = self.model.load_weights(bert_weights)
-        if not len(loaded):
-            # Fix for models like `sentence-transformers/stsb-roberta-base-v2`
-            # which use the same architecture, but have no "roberta" prefix.
-            loaded = self.model.load_weights(task_weights)
-        assert len(loaded), "Unable to load RobertaEmbeddingModel"
+        super().load_weights(weights)

         params_dict = dict(self.named_parameters())

         for name, loaded_weight in secondary:
-            if name.startswith("sparse_linear"):
+            if name.startswith(self.secondary_weight_prefix):
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
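For reference, the extra checkpoint is expected to be a plain PyTorch
state dict for the `nn.Linear(hidden_size, 1)` head built in
`_build_pooler`, with its keys mapped under the `sparse_linear.` prefix
by the secondary-weights loader. A quick sanity check could look like
the following sketch (assuming `sparse_linear.pt` has already been
downloaded from the model repository):

    import torch
    from torch import nn

    # Assumed layout: {"weight": [1, hidden_size], "bias": [1]},
    # i.e. exactly what nn.Linear(hidden_size, 1).state_dict() produces.
    state = torch.load("sparse_linear.pt", map_location="cpu")
    hidden_size = state["weight"].shape[1]

    sparse_linear = nn.Linear(hidden_size, 1)
    sparse_linear.load_state_dict(state)
    print(f"loaded sparse head for hidden_size={hidden_size}")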