Skip to content

Commit eb1497e

Browse files
author
Sigrid Jin (Sionic AI)
committed
refactor: use pooler utility functions to avoid duplicate code
Use build_output and get_prompt_lens from pooler.py instead of implementing duplicate logic: - Replace manual PoolerOutput construction with build_output - Replace prompt length extraction logic with get_prompt_lens - Remove unused imports (PoolingSequenceGroupOutput, PoolingTensors) This addresses the review feedback to avoid duplicate code. Signed-off-by: Sigrid Jin (Sionic AI) <sigrid@sionic.ai>
1 parent 8e0578a commit eb1497e

File tree

1 file changed

+10
-16
lines changed

1 file changed

+10
-16
lines changed

vllm/model_executor/models/jina_embeddings_v4.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@
1212
from vllm.config import VllmConfig
1313
from vllm.logger import init_logger
1414
from vllm.model_executor.layers.pooler import (HAS_TRITON, Pooler, PoolingTask,
15-
PoolingType,
16-
extract_vision_tokens_kernel)
15+
PoolingType, build_output,
16+
extract_vision_tokens_kernel,
17+
get_prompt_lens)
1718
# yapf: disable
1819
from vllm.model_executor.pooling_metadata import (
1920
PoolingMetadata as V0PoolingMetadata)
20-
from vllm.model_executor.pooling_metadata import PoolingTensors
2121
# yapf: enable
2222
from vllm.multimodal import MULTIMODAL_REGISTRY
2323
from vllm.pooling_params import PoolingParams
24-
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
24+
from vllm.sequence import PoolerOutput
2525
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
2626

2727
from .interfaces import SupportsCrossEncoding, SupportsMultiModal
@@ -84,7 +84,7 @@ def forward(
8484
# Validate inputs
8585
if hidden_states is None or hidden_states.numel() == 0:
8686
logger.warning("Empty hidden states received")
87-
return PoolerOutput(outputs=[])
87+
return build_output(torch.empty((0, 0)))
8888

8989
# Extract token IDs safely from metadata
9090
token_ids_list, seq_ids = self._extract_token_ids_safe(
@@ -95,12 +95,8 @@ def forward(
9595
# Fallback to base pooler
9696
return self._base_pooler(hidden_states, pooling_metadata)
9797

98-
# Get prompt lengths based on metadata type
99-
if isinstance(pooling_metadata, V1PoolingMetadata):
100-
prompt_lens = pooling_metadata.prompt_lens
101-
else:
102-
prompt_lens = PoolingTensors.from_pooling_metadata(
103-
pooling_metadata, hidden_states.device).prompt_lens
98+
# Get prompt lengths using utility function
99+
prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
104100

105101
# Validate lengths match
106102
assert len(token_ids_list) == len(prompt_lens), (
@@ -115,10 +111,8 @@ def forward(
115111
pooled_data = self._apply_vision_pooling_pytorch(
116112
hidden_states, token_ids_list, prompt_lens)
117113

118-
# Build output
119-
pooled_outputs = [
120-
PoolingSequenceGroupOutput(data) for data in pooled_data
121-
]
114+
# Stack pooled data into tensor for build_output
115+
pooled_tensor = torch.stack(pooled_data)
122116

123117
# Record metrics
124118
if self.observability_config:
@@ -130,7 +124,7 @@ def forward(
130124
avg_time = self._pooling_time_ms / self._pooling_count
131125
logger.debug("Average pooling time: %.2fms", avg_time)
132126

133-
return PoolerOutput(outputs=pooled_outputs)
127+
return build_output(pooled_tensor)
134128

135129
def _extract_token_ids_safe(
136130
self, pooling_metadata: PoolingMetadata

Comments (0)