@@ -625,56 +625,6 @@ def forward(
 ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
 
 
-class VisionPooler(Pooler):
-
-    @classmethod
-    def from_config(cls, model_config: ModelConfig) -> "VisionPooler":
-        return cls(model_config)
-
-    def __init__(self, config: ModelConfig):
-        super().__init__()
-        self.config = config
-
-    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
-        if task == "embed":
-            return PoolingParams(pooling_type="vision",
-                                 logits_processing_needs_token_ids=True)
-        return None
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        pooling_metadata: PoolingMetadata,
-    ) -> PoolerOutput:
-        assert isinstance(pooling_metadata, V1PoolingMetadata)
-
-        pooled_outputs = []
-        for i in range(len(pooling_metadata.prompt_lens)):
-            start_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
-                         hf_config.vision_start_token_id).nonzero()[-1].item()
-            end_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
-                       hf_config.vision_end_token_id).nonzero()[-1].item()
-
-            seq_start = torch.cumsum(
-                torch.tensor([0] + pooling_metadata.prompt_lens.tolist()),
-                dim=0)[i]
-            seq_len = pooling_metadata.prompt_lens[i]
-
-            output = torch.empty(self.config.hidden_size,
-                                 device=hidden_states.device,
-                                 dtype=hidden_states.dtype)
-
-            grid = lambda meta: (self.config.hidden_size, )
-            mean_pool_with_position_kernel[grid](hidden_states, output,
-                                                 seq_start, seq_len,
-                                                 self.config.hidden_size,
-                                                 start_pos, end_pos + 1)
-
-            pooled_outputs.append(output)
-
-        return build_output(torch.stack(pooled_outputs))
-
-
 if HAS_TRITON:
 
     @triton.jit
@@ -817,10 +767,12 @@ def forward(
 
         pooled_outputs = []
         for i in range(len(pooling_metadata.prompt_lens)):
-            start_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
-                         hf_config.vision_start_token_id).nonzero()[-1].item()
-            end_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
-                       hf_config.vision_end_token_id).nonzero()[-1].item()
+            start_pos = (pooling_metadata.prompt_token_ids[i] ==
+                         self.config.hf_config.vision_start_token_id
+                         ).nonzero()[-1].item()
+            end_pos = (pooling_metadata.prompt_token_ids[i] ==
+                       self.config.hf_config.vision_end_token_id
+                       ).nonzero()[-1].item()
 
             seq_start = torch.cumsum(
                 torch.tensor([0] + pooling_metadata.prompt_lens.tolist()),
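
Side note (not part of the diff): the span detection above can be reproduced with a minimal, self-contained sketch. The marker token ids below are made-up placeholders for the values the real code reads from self.config.hf_config.

# Illustrative sketch only; VISION_START_ID / VISION_END_ID are placeholders.
import torch

VISION_START_ID = 1001  # stands in for hf_config.vision_start_token_id
VISION_END_ID = 1002    # stands in for hf_config.vision_end_token_id

prompt_token_ids = torch.tensor(
    [7, 3, VISION_START_ID, 11, 12, 13, VISION_END_ID, 5])

# Same pattern as the diff: position of the *last* occurrence of each marker.
start_pos = (prompt_token_ids == VISION_START_ID).nonzero()[-1].item()
end_pos = (prompt_token_ids == VISION_END_ID).nonzero()[-1].item()

assert (start_pos, end_pos) == (2, 6)
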
@@ -832,41 +784,18 @@ def forward(
                                  dtype=hidden_states.dtype)
 
             grid = lambda meta: (self.config.hidden_size, )
-            mean_pool_with_position_kernel[grid](hidden_states, output,
-                                                 seq_start, seq_len,
-                                                 self.config.hidden_size,
-                                                 start_pos, end_pos + 1)
+            if HAS_TRITON:
+                mean_pool_with_position_kernel[grid](hidden_states, output,
+                                                     seq_start, seq_len,
+                                                     self.config.hidden_size,
+                                                     start_pos, end_pos + 1)
+            else:
+                # Fallback to PyTorch implementation if Triton is not available
+                vision_tokens_range = hidden_states[seq_start + start_pos:seq_start + end_pos + 1]
+                output = vision_tokens_range.mean(dim=0)
 
             pooled_outputs.append(output)
 
         return build_output(torch.stack(pooled_outputs))
 
 
-if HAS_TRITON:
-
-    @triton.jit
-    def mean_pool_with_position_kernel(
-        hidden_states_ptr,
-        output_ptr,
-        seq_start,
-        seq_len,
-        hidden_size,
-        pool_start,
-        pool_end,
-        BLOCK_SIZE: tl.constexpr,
-    ):
-        """Triton kernel to perform mean pooling over a specified token range."""
-        pid = tl.program_id(0)
-
-        if pid >= hidden_size:
-            return
-
-        accumulator = 0.0
-        for i in range(pool_start, pool_end):
-            hidden_val = tl.load(hidden_states_ptr +
-                                 (seq_start + i) * hidden_size + pid)
-            accumulator += hidden_val
-
-        # Store mean pooled result
-        result = accumulator / (pool_end - pool_start)
-        tl.store(output_ptr + pid, result)
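
For reference, what the Triton kernel and its PyTorch fallback compute is a plain mean over the vision-token slice of the flattened hidden states. Below is a small PyTorch-only sketch (not part of the diff), assuming hidden_states is flattened to [total_tokens, hidden_size] across the batch and using made-up shapes.

# Illustrative sketch only; shapes, lengths, and the span below are invented.
import torch

hidden_size = 8
prompt_lens = torch.tensor([5, 7])  # two sequences with made-up lengths
hidden_states = torch.randn(int(prompt_lens.sum()), hidden_size)

# Per-sequence offsets into the flattened tensor, as in the diff's cumsum.
seq_starts = torch.cumsum(torch.tensor([0] + prompt_lens.tolist()), dim=0)

# Suppose sequence 0's vision span covers tokens 1..3 inclusive
# (start_pos=1, end_pos=3 in the diff's terms).
i, start_pos, end_pos = 0, 1, 3
seq_start = seq_starts[i]

# Equivalent of mean_pool_with_position_kernel / the fallback branch.
pooled = hidden_states[seq_start + start_pos:seq_start + end_pos + 1].mean(dim=0)
assert pooled.shape == (hidden_size,)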