First working PoC for bge-m3 sparse embeddings #14526

Draft: wants to merge 4 commits into main
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -116,6 +116,7 @@
"RobertaModel": ("roberta", "RobertaEmbeddingModel"),
"RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
"BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
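The registry maps the `architectures` string that vLLM reads from a model's Hugging Face config to the module and class that implement it. BAAI/bge-m3 ships with `XLMRobertaModel` in its `config.json`, so the new class presumably has to be selected explicitly. A minimal sketch using vLLM's public `hf_overrides` and `task` arguments; the override value and the exact invocation are assumptions, not part of this diff:

# Illustration only, not part of this diff: selecting the newly registered
# architecture for BAAI/bge-m3, whose config.json declares XLMRobertaModel.
from vllm import LLM

llm = LLM(
    model="BAAI/bge-m3",
    task="embed",
    # Assumed override so the registry entry above resolves to
    # BgeM3EmbeddingModel instead of the stock RobertaEmbeddingModel.
    hf_overrides={"architectures": ["BgeM3EmbeddingModel"]},
)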
163 changes: 160 additions & 3 deletions vllm/model_executor/models/roberta.py
@@ -1,20 +1,23 @@
# SPDX-License-Identifier: Apache-2.0

import itertools
from typing import Iterable, Optional, Tuple
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import RobertaConfig

from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import CrossEncodingPooler
from vllm.config import PoolerConfig, VllmConfig
from vllm.model_executor.layers.pooler import (AllPool, CrossEncodingPooler,
                                               Pooler, PoolingType)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.transformers_utils.config import (
    get_cross_encoder_activation_function)
@@ -191,6 +194,160 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        assert len(loaded), "Unable to load RobertaEmbeddingModel"


def filter_secondary_weights(
    all_weights: Iterable[Tuple[str, torch.Tensor]],
    secondary_weights: list[str],
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
                                                              torch.Tensor]]]:
    all_weights1, all_weights2 = itertools.tee(all_weights)

    def filtered(n):
        return any(n.startswith(f) for f in secondary_weights)

    return ((n, w) for n, w in all_weights1 if filtered(n)), \
        ((n, w) for n, w in all_weights2 if not filtered(n))


class M3SparsePooler(AllPool):
    """A pooler that implements M3 sparse pooling.

    This layer does the following:
    1. By default, it returns dense embeddings.
    2. If the pooling params' "additional_data" contains
       "sparse_embeddings", it returns sparse embeddings.

    Attributes:
        dense_pooler: The default (dense) pooler.
        sparse_linear: The linear module applied to the last hidden
            states to obtain the per-token weights.
        bos_token_id and eos_token_id: The special tokens inserted by
            the tokenizer; these are removed for sparse embeddings.
    """

    def __init__(self, dense_pooler: Pooler, sparse_linear: nn.Module,
                 bos_token_id: int, eos_token_id: int) -> None:
        super().__init__(normalize=False, softmax=False)
        self.dense_pooler = dense_pooler
        self.sparse_linear = sparse_linear
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:

        seq_ids = []
        is_sparse = []

        for i, (group, pool_param) in enumerate(pooling_metadata.seq_groups):
            seq_ids.append(group[0])
            ad = pool_param.additional_data
            is_sparse.append(ad is not None
                             and ad.get("sparse_embeddings", False))

        if not any(is_sparse):
            return self.dense_pooler(hidden_states, pooling_metadata)
        else:

            split_hidden_states = self.extract_states(hidden_states,
                                                      pooling_metadata)
            dense_hidden_states = []

            seq_groups: List[Tuple[List[int], PoolingParams]] = []
            prompt_lens: List[int] = []

            for i, t in enumerate(split_hidden_states):
                if not is_sparse[i]:
                    dense_hidden_states.append(t)
                    seq_groups.append(pooling_metadata.seq_groups[i])
                    prompt_lens.append(pooling_metadata.prompt_lens[i])

            dense_output = []

            if dense_hidden_states:
                dense_hidden_states = torch.cat(dense_hidden_states)
                dense_metadata = PoolingMetadata(
                    seq_groups=seq_groups,
                    seq_data=pooling_metadata.seq_data,
                    prompt_lens=prompt_lens)
                dense_output = self.dense_pooler(dense_hidden_states,
                                                 dense_metadata).outputs

            dense_it = iter(dense_output)

            pooled_outputs = []

            for i, hidden_states in enumerate(split_hidden_states):
                if is_sparse[i]:
                    pooled_data = torch.squeeze(
                        torch.relu(self.sparse_linear(hidden_states)))
                    token_ids = pooling_metadata.seq_data[
                        seq_ids[i]].prompt_token_ids
                    if token_ids[0] == self.bos_token_id:
                        pooled_data = pooled_data[1:]
                    if token_ids[-1] == self.eos_token_id:
                        pooled_data = pooled_data[:-1]
                    pooled_outputs.append(self.build_output(pooled_data))
                else:
                    pooled_outputs.append(next(dense_it))

            return PoolerOutput(outputs=pooled_outputs)


class BgeM3EmbeddingModel(RobertaEmbeddingModel):
    """A model that extends RobertaEmbeddingModel with sparse embeddings.

    This class supports loading an additional sparse_linear.pt file
    to create sparse embeddings as described in https://arxiv.org/abs/2402.03216
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

        self.hidden_size = vllm_config.model_config.hf_config.hidden_size

        self.bos_token_id = vllm_config.model_config.hf_config.bos_token_id
        self.eos_token_id = vllm_config.model_config.hf_config.eos_token_id

        super().__init__(vllm_config=vllm_config, prefix=prefix)
        self.secondary_weight_prefix = "sparse_linear."

        self.secondary_weights = [
            DefaultModelLoader.Source(
                model_or_path=vllm_config.model_config.model,
                revision=None,
                prefix=self.secondary_weight_prefix,
                allow_patterns_overrides=["sparse_linear.pt"])
        ]

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        dense_pooler = Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.CLS,
            normalize=True,
            softmax=False)
        self.sparse_linear = nn.Linear(self.hidden_size, 1)
        return M3SparsePooler(dense_pooler, self.sparse_linear,
                              self.bos_token_id, self.eos_token_id)

    def load_weights(self, all_weights: Iterable[Tuple[str, torch.Tensor]]):
        secondary, weights = filter_secondary_weights(
            all_weights, [self.secondary_weight_prefix])

        super().load_weights(weights)

        params_dict = dict(self.named_parameters())

        for name, loaded_weight in secondary:
            if name.startswith(self.secondary_weight_prefix):
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)


class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
                                       SupportsV0Only):
    """A model that uses Roberta to provide embedding functionalities.
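On the client side, the dense/sparse switch described in the M3SparsePooler docstring could presumably be exercised as below. LLM.encode, PoolingParams, and its additional_data field are existing vLLM APIs (the pooler above reads additional_data), but the exact call pattern for this PoC is an assumption rather than something documented in the PR.

# Illustration only, not part of this diff.
from vllm import LLM, PoolingParams

llm = LLM(model="BAAI/bge-m3",
          task="embed",
          hf_overrides={"architectures": ["BgeM3EmbeddingModel"]})

# Default behaviour: dense, normalized CLS embeddings.
dense = llm.encode(["What is BGE M3?"])

# Sparse lexical weights, opted into per request via additional_data,
# which is the flag M3SparsePooler checks for each sequence group.
sparse = llm.encode(
    ["What is BGE M3?"],
    pooling_params=PoolingParams(additional_data={"sparse_embeddings": True}))

# Each sparse result carries one weight per prompt token, with the
# BOS/EOS positions already removed by the pooler.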