First working PoC for bge-m3 sparse embeddings #14526

Draft · wants to merge 4 commits into main

Changes from 1 commit
158 changes: 154 additions & 4 deletions vllm/model_executor/models/roberta.py
@@ -1,20 +1,23 @@
# SPDX-License-Identifier: Apache-2.0

import itertools
from typing import Iterable, Optional, Tuple
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import RobertaConfig

from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import CrossEncodingPooler
from vllm.config import PoolerConfig, VllmConfig
from vllm.model_executor.layers.pooler import (AllPool, CrossEncodingPooler,
Pooler, PoolingType)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.transformers_utils.config import (
get_cross_encoder_activation_function)
@@ -160,6 +163,105 @@ def forward(self, features, **kwargs):
return x


def filter_secondary_weights(
all_weights: Iterable[Tuple[str, torch.Tensor]],
secondary_weights: list[str],
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
torch.Tensor]]]:
all_weights1, all_weights2 = itertools.tee(all_weights)

def filtered(n):
return any(n.startswith(f) for f in secondary_weights)

return ((n, w) for n, w in all_weights1 if filtered(n)), \
((n, w) for n, w in all_weights2 if not filtered(n))


class M3SparsePooler(AllPool):
"""A layer that pools specific information from hidden states.

This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method.
2. Normalizes output if specified.
3. Returns structured results as `PoolerOutput`.

Attributes:
pooling_type: The type of pooling to use.
normalize: Whether to normalize the pooled data.
"""

def __init__(self, dense_pooler: Pooler, sparse_linear: nn.Module,
bos_token_id: int, eos_token_id: int) -> None:
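        # The AllPool base keeps per-token hidden states (no reduction across
        # tokens), which the sparse head needs to emit one lexical weight per
        # token.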
super().__init__(normalize=False, softmax=False)
self.dense_pooler = dense_pooler
self.sparse_linear = sparse_linear
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id

def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:

seq_ids = []
is_sparse = []

for i, (group, pool_param) in enumerate(pooling_metadata.seq_groups):
seq_ids.append(group[0])
ad = pool_param.additional_data
is_sparse.append(ad is not None
and ad.get("sparse_embeddings", False))

if not any(is_sparse):
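            # Fast path: no sequence in this batch asked for sparse
            # embeddings, so delegate the whole batch to the dense pooler.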
return self.dense_pooler(hidden_states, pooling_metadata)
else:

split_hidden_states = self.extract_states(hidden_states,
pooling_metadata)
dense_hidden_states = []

seq_groups: List[Tuple[List[int], PoolingParams]] = []
prompt_lens: List[int] = []

for i, t in enumerate(split_hidden_states):
if not is_sparse[i]:
dense_hidden_states.append(t)
seq_groups.append(pooling_metadata.seq_groups[i])
prompt_lens.append(pooling_metadata.prompt_lens[i])

dense_output = []

if dense_hidden_states:
dense_hidden_states = torch.cat(dense_hidden_states)
dense_metadata = PoolingMetadata(
seq_groups=seq_groups,
seq_data=pooling_metadata.seq_data,
prompt_lens=prompt_lens)
dense_output = self.dense_pooler(dense_hidden_states,
dense_metadata).outputs

dense_it = iter(dense_output)

pooled_outputs = []
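            # Walk the sequences in order: sparse requests get per-token
            # lexical weights relu(sparse_linear(h_t)) with BOS/EOS dropped,
            # while dense requests consume the precomputed dense outputs.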

            for i, seq_hidden in enumerate(split_hidden_states):
                if is_sparse[i]:
                    pooled_data = torch.squeeze(
                        torch.relu(self.sparse_linear(seq_hidden)))
token_ids = pooling_metadata.seq_data[
seq_ids[i]].prompt_token_ids
if token_ids[0] == self.bos_token_id:
pooled_data = pooled_data[1:]
if token_ids[-1] == self.eos_token_id:
pooled_data = pooled_data[:-1]
pooled_outputs.append(self.build_output(pooled_data))
else:
pooled_outputs.append(next(dense_it))

return PoolerOutput(outputs=pooled_outputs)


class RobertaEmbeddingModel(BertEmbeddingModel):
"""A model that uses Roberta to provide embedding functionalities.

@@ -171,14 +273,53 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
_pooler: An instance of Pooler used for pooling operations.
"""

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.is_m3 = vllm_config.model_config.model == "BAAI/bge-m3"
self.hidden_size = vllm_config.model_config.hf_config.hidden_size

self.bos_token_id = vllm_config.model_config.hf_config.bos_token_id
self.eos_token_id = vllm_config.model_config.hf_config.eos_token_id

super().__init__(vllm_config=vllm_config, prefix=prefix)

self.secondary_weights = []
self.secondary_weight_names = []

if self.is_m3:
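            # BAAI/bge-m3 stores its sparse (lexical-weight) head in a
            # separate sparse_linear.pt file in the model repo, so it is
            # registered as a secondary weight source and loaded after the
            # main checkpoint.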
self.secondary_weight_names.append("sparse_linear.")
self.secondary_weights.append(
DefaultModelLoader.Source(
model_or_path=vllm_config.model_config.model,
revision=None,
prefix="sparse_linear.",
allow_patterns_overrides=["sparse_linear.pt"]))

def _build_model(self,
vllm_config: VllmConfig,
prefix: str = "") -> BertModel:
return BertModel(vllm_config=vllm_config,
prefix=prefix,
embedding_class=RobertaEmbedding)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
dense_pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.CLS,
normalize=True,
softmax=False)
if not self.is_m3:
return dense_pooler
else:
self.sparse_linear = nn.Linear(self.hidden_size, 1)
return M3SparsePooler(dense_pooler, self.sparse_linear,
self.bos_token_id, self.eos_token_id)

def load_weights(self, all_weights: Iterable[Tuple[str, torch.Tensor]]):

secondary, weights = filter_secondary_weights(
all_weights, self.secondary_weight_names)

weights = self.hf_to_vllm_mapper.apply(weights)
        # Lazily split the weights into "roberta"-prefixed ones and everything
        # else (without materializing them in memory). Needed for models such
        # as FacebookAI/roberta-base.
@@ -190,6 +331,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loaded = self.model.load_weights(task_weights)
assert len(loaded), "Unable to load RobertaEmbeddingModel"

params_dict = dict(self.named_parameters())

for name, loaded_weight in secondary:
if name.startswith("sparse_linear"):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)


class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
SupportsV0Only):
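For reference, a rough sketch of how a client could request the new sparse embeddings, assuming the offline LLM.encode API and that PoolingParams.additional_data is plumbed through as in this PoC; the prompt and surrounding setup are illustrative only:

from vllm import LLM, PoolingParams

llm = LLM(model="BAAI/bge-m3", task="embed")

# Default path: dense embedding via the CLS pooler with normalization.
dense_out = llm.encode(["what is bge-m3?"])

# Sparse path: per-token lexical weights from the sparse_linear head,
# selected via the additional_data flag checked in M3SparsePooler.
sparse_out = llm.encode(
    ["what is bge-m3?"],
    pooling_params=PoolingParams(additional_data={"sparse_embeddings": True}),
)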