-
-
Notifications
You must be signed in to change notification settings - Fork 8.9k
[Model] Add support for Jina Embeddings V4 #20802
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
9fbc0e9
eea8462
5e247e9
9be40b2
64c06c7
56b7409
bef3df2
efa8b04
0fe30f8
062a156
edfe91a
5d12bd4
27b28f7
3bdbd17
fafd668
0c3f1bd
5c45015
9d34781
8e0578a
eb1497e
1b4f405
702fd16
5114a3c
6b501b2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,6 +32,7 @@ | |
CLS = 2 | ||
STEP = 3 | ||
MEAN = 4 | ||
VISION = 5 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have created this new type of |
||
|
||
|
||
@dataclass(frozen=True) | ||
|
@@ -91,6 +92,8 @@ | |
|
||
if pooling_type == PoolingType.STEP: | ||
return StepPooler.from_config(resolved_config) | ||
if pooling_type == PoolingType.VISION: | ||
return VisionPooler.from_config(resolved_config) | ||
|
||
return SimplePooler.from_config(resolved_config) | ||
|
||
|
@@ -622,6 +625,86 @@ | |
ClassifierFn = Callable[[torch.Tensor], torch.Tensor] | ||
|
||
|
||
class VisionPooler(Pooler): | ||
|
||
@classmethod | ||
def from_config(cls, model_config: ModelConfig) -> "VisionPooler": | ||
return cls(model_config) | ||
|
||
def __init__(self, config: ModelConfig): | ||
super().__init__() | ||
self.config = config | ||
|
||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: | ||
if task == "embed": | ||
return PoolingParams(pooling_type="vision", | ||
logits_processing_needs_token_ids=True) | ||
return None | ||
|
||
def forward( | ||
self, | ||
hidden_states: torch.Tensor, | ||
pooling_metadata: PoolingMetadata, | ||
) -> PoolerOutput: | ||
assert isinstance(pooling_metadata, V1PoolingMetadata) | ||
|
||
pooled_outputs = [] | ||
for i in range(len(pooling_metadata.prompt_lens)): | ||
start_pos = (pooling_metadata.prompt_token_ids[i] == self.config. | ||
hf_config.vision_start_token_id).nonzero()[-1].item() | ||
end_pos = (pooling_metadata.prompt_token_ids[i] == self.config. | ||
hf_config.vision_end_token_id).nonzero()[-1].item() | ||
|
||
seq_start = torch.cumsum( | ||
torch.tensor([0] + pooling_metadata.prompt_lens.tolist()), | ||
dim=0)[i] | ||
seq_len = pooling_metadata.prompt_lens[i] | ||
|
||
output = torch.empty(self.config.hidden_size, | ||
device=hidden_states.device, | ||
dtype=hidden_states.dtype) | ||
|
||
grid = lambda meta: (self.config.hidden_size, ) | ||
mean_pool_with_position_kernel[grid](hidden_states, output, | ||
seq_start, seq_len, | ||
self.config.hidden_size, | ||
start_pos, end_pos + 1) | ||
|
||
pooled_outputs.append(output) | ||
|
||
return build_output(torch.stack(pooled_outputs)) | ||
|
||
|
||
if HAS_TRITON: | ||
|
||
@triton.jit | ||
def mean_pool_with_position_kernel( | ||
hidden_states_ptr, | ||
output_ptr, | ||
seq_start, | ||
seq_len, | ||
hidden_size, | ||
pool_start, | ||
pool_end, | ||
BLOCK_SIZE: tl.constexpr, | ||
): | ||
"""Triton kernel to perform mean pooling over a specified token range.""" | ||
pid = tl.program_id(0) | ||
|
||
if pid >= hidden_size: | ||
return | ||
|
||
accumulator = 0.0 | ||
for i in range(pool_start, pool_end): | ||
hidden_val = tl.load(hidden_states_ptr + | ||
(seq_start + i) * hidden_size + pid) | ||
accumulator += hidden_val | ||
|
||
# Store mean pooled result | ||
result = accumulator / (pool_end - pool_start) | ||
tl.store(output_ptr + pid, result) | ||
|
||
|
||
class ClassifierPooler(nn.Module): | ||
"""A pooling layer for classification tasks. | ||
|
||
|
@@ -640,7 +723,7 @@ | |
pooling: PoolingFn, | ||
classifier: ClassifierFn, | ||
act_fn: Optional[PoolerActivation] = None, | ||
) -> None: | ||
super().__init__() | ||
|
||
self.pooling = pooling | ||
|
@@ -709,39 +792,81 @@ | |
return build_output(scores) | ||
|
||
|
||
class VisionPooler(Pooler): | ||
sigridjineth marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@classmethod | ||
def from_config(cls, model_config: ModelConfig) -> "VisionPooler": | ||
return cls(model_config) | ||
|
||
def __init__(self, config: ModelConfig): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we pass in the token IDs and hidden size explicitly? In case other models store those attributes in different locations |
||
super().__init__() | ||
self.config = config | ||
|
||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: | ||
if task == "embed": | ||
return PoolingParams(pooling_type="vision", | ||
logits_processing_needs_token_ids=True) | ||
return None | ||
|
||
def forward( | ||
self, | ||
hidden_states: torch.Tensor, | ||
pooling_metadata: PoolingMetadata, | ||
) -> PoolerOutput: | ||
assert isinstance(pooling_metadata, V1PoolingMetadata) | ||
|
||
pooled_outputs = [] | ||
for i in range(len(pooling_metadata.prompt_lens)): | ||
start_pos = (pooling_metadata.prompt_token_ids[i] == self.config. | ||
hf_config.vision_start_token_id).nonzero()[-1].item() | ||
end_pos = (pooling_metadata.prompt_token_ids[i] == self.config. | ||
hf_config.vision_end_token_id).nonzero()[-1].item() | ||
|
||
seq_start = torch.cumsum( | ||
torch.tensor([0] + pooling_metadata.prompt_lens.tolist()), | ||
dim=0)[i] | ||
seq_len = pooling_metadata.prompt_lens[i] | ||
|
||
output = torch.empty(self.config.hidden_size, | ||
device=hidden_states.device, | ||
dtype=hidden_states.dtype) | ||
|
||
grid = lambda meta: (self.config.hidden_size, ) | ||
mean_pool_with_position_kernel[grid](hidden_states, output, | ||
seq_start, seq_len, | ||
Check failure on line 836 in vllm/model_executor/layers/pooler.py
|
||
self.config.hidden_size, | ||
start_pos, end_pos + 1) | ||
|
||
pooled_outputs.append(output) | ||
|
||
return build_output(torch.stack(pooled_outputs)) | ||
|
||
|
||
if HAS_TRITON: | ||
|
||
@triton.jit | ||
def extract_vision_tokens_kernel( | ||
def mean_pool_with_position_kernel( | ||
hidden_states_ptr, | ||
token_ids_ptr, | ||
output_ptr, | ||
seq_start, | ||
seq_len, | ||
hidden_size, | ||
vision_start_id: tl.constexpr, | ||
vision_end_id: tl.constexpr, | ||
pool_start, | ||
pool_end, | ||
BLOCK_SIZE: tl.constexpr, | ||
): | ||
"""Triton kernel to extract and pool vision tokens efficiently.""" | ||
"""Triton kernel to perform mean pooling over a specified token range.""" | ||
pid = tl.program_id(0) | ||
|
||
if pid >= hidden_size: | ||
return | ||
|
||
# Find vision token range | ||
vision_count = 0 | ||
accumulator = 0.0 | ||
|
||
for i in range(seq_len): | ||
token_id = tl.load(token_ids_ptr + seq_start + i) | ||
if token_id >= vision_start_id and token_id <= vision_end_id: | ||
hidden_val = tl.load(hidden_states_ptr + | ||
(seq_start + i) * hidden_size + pid) | ||
accumulator += hidden_val | ||
vision_count += 1 | ||
for i in range(pool_start, pool_end): | ||
hidden_val = tl.load(hidden_states_ptr + | ||
(seq_start + i) * hidden_size + pid) | ||
accumulator += hidden_val | ||
|
||
# Store mean pooled result | ||
result = accumulator / vision_count if vision_count > 0 else 0.0 | ||
|
||
result = accumulator / (pool_end - pool_start) | ||
tl.store(output_ptr + pid, result) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually the pooling type here is supposed to be upper case