@@ -32,6 +32,7 @@ class PoolingType(IntEnum):
32
32
CLS = 2
33
33
STEP = 3
34
34
MEAN = 4
35
+ VISION = 5
35
36
36
37
37
38
@dataclass (frozen = True )
@@ -91,6 +92,8 @@ def from_config_with_defaults(
91
92
92
93
if pooling_type == PoolingType .STEP :
93
94
return StepPooler .from_config (resolved_config )
95
+ if pooling_type == PoolingType .VISION :
96
+ return VisionPooler .from_config (resolved_config )
94
97
95
98
return SimplePooler .from_config (resolved_config )
96
99
@@ -622,6 +625,86 @@ def forward(
622
625
ClassifierFn = Callable [[torch .Tensor ], torch .Tensor ]
623
626
624
627
628
# NOTE(review): VisionPooler appears to be defined twice in this module; the
# later definition is the one that stays bound to the name. Confirm and drop
# one of them.
class VisionPooler(Pooler):
    """Pooler that mean-pools the hidden states of a prompt's vision span.

    For each sequence in the batch, the span runs from the last occurrence of
    ``hf_config.vision_start_token_id`` to the last occurrence of
    ``hf_config.vision_end_token_id`` (both inclusive). The hidden states of
    that span are averaged into a single vector per sequence.
    """

    @classmethod
    def from_config(cls, model_config: ModelConfig) -> "VisionPooler":
        """Build a pooler from ``model_config``."""
        return cls(model_config)

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

    def get_pooling_params(self,
                           task: PoolingTask) -> Optional[PoolingParams]:
        """Return pooling params for the "embed" task, ``None`` otherwise."""
        if task == "embed":
            return PoolingParams(pooling_type="vision",
                                 logits_processing_needs_token_ids=True)
        return None

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Mean-pool the vision span of every sequence in the batch.

        Args:
            hidden_states: Hidden states of all sequences. Assumed to be laid
                out back to back along dim 0, in ``prompt_lens`` order, with
                the hidden dimension last - TODO confirm against the caller.
            pooling_metadata: Must be a ``V1PoolingMetadata``; its
                ``prompt_lens`` and ``prompt_token_ids`` are read.

        Returns:
            ``build_output`` applied to the stacked per-sequence means (one
            row per sequence).

        Raises:
            ValueError: If a sequence lacks a vision start or end token, or
                its last end token precedes its last start token.
        """
        assert isinstance(pooling_metadata, V1PoolingMetadata)

        hf_config = self.config.hf_config
        start_id = hf_config.vision_start_token_id
        end_id = hf_config.vision_end_token_id

        pooled_outputs = []
        # Running offset of the current sequence inside ``hidden_states``;
        # kept as a plain int instead of rebuilding a cumsum per sequence.
        seq_start = 0
        for i, seq_len in enumerate(pooling_metadata.prompt_lens.tolist()):
            # Only look at this sequence's own tokens, in case the row holds
            # more entries than ``seq_len`` (e.g. padding).
            token_ids = pooling_metadata.prompt_token_ids[i][:seq_len]
            start_hits = (token_ids == start_id).nonzero()
            end_hits = (token_ids == end_id).nonzero()
            if start_hits.numel() == 0 or end_hits.numel() == 0:
                raise ValueError(
                    f"Sequence {i} has no vision span: expected both "
                    f"vision_start_token_id={start_id} and "
                    f"vision_end_token_id={end_id} in its prompt.")

            start_pos = int(start_hits[-1].item())
            end_pos = int(end_hits[-1].item())
            if end_pos < start_pos:
                raise ValueError(
                    f"Sequence {i} has its last vision end token (position "
                    f"{end_pos}) before its last vision start token "
                    f"(position {start_pos}).")

            # Plain torch ops are used so the pooler does not depend on
            # Triton being available. Accumulate in float32, then cast back
            # to the input dtype.
            span = hidden_states[seq_start + start_pos:seq_start + end_pos +
                                 1]
            pooled = span.to(torch.float32).mean(dim=0)
            pooled_outputs.append(pooled.to(hidden_states.dtype))

            seq_start += seq_len

        return build_output(torch.stack(pooled_outputs))
676
+
677
+
678
if HAS_TRITON:

    # NOTE(review): this kernel appears to be defined twice in this module;
    # the later definition is the one that stays bound to the name.
    @triton.jit
    def mean_pool_with_position_kernel(
        hidden_states_ptr,
        output_ptr,
        seq_start,
        seq_len,
        hidden_size,
        pool_start,
        pool_end,
        BLOCK_SIZE: tl.constexpr = 1,
    ):
        """Triton kernel to perform mean pooling over a specified token range.

        Each program handles one hidden-dimension index
        ``pid = tl.program_id(0)`` and averages the values at
        ``hidden_states_ptr + (seq_start + i) * hidden_size + pid`` for
        ``i`` in ``[pool_start, pool_end)``, writing the mean to
        ``output_ptr + pid``.

        Args:
            hidden_states_ptr: Base pointer of the hidden states.
            output_ptr: Base pointer of the output vector.
            seq_start: Row offset of the sequence; must be an integer, not a
                tensor, because it is used in pointer arithmetic.
            seq_len: Unused; kept so existing launches keep working.
            hidden_size: Row stride and number of output elements.
            pool_start: First pooled position, relative to ``seq_start``.
            pool_end: One past the last pooled position. Must be greater than
                ``pool_start``; an empty range divides by zero.
            BLOCK_SIZE: Unused compile-time constant. It has a default so
                launches that omit it still bind all parameters.
        """
        pid = tl.program_id(0)

        # Programs launched beyond the hidden dimension have nothing to do.
        if pid >= hidden_size:
            return

        accumulator = 0.0
        for i in range(pool_start, pool_end):
            hidden_val = tl.load(hidden_states_ptr +
                                 (seq_start + i) * hidden_size + pid)
            # Accumulate in float32 whatever the input dtype is, so the
            # loop-carried accumulator keeps a single type.
            accumulator += hidden_val.to(tl.float32)

        # Store mean pooled result
        result = accumulator / (pool_end - pool_start)
        tl.store(output_ptr + pid, result)
706
+
707
+
625
708
class ClassifierPooler (nn .Module ):
626
709
"""A pooling layer for classification tasks.
627
710
@@ -709,39 +792,81 @@ def forward(
709
792
return build_output (scores )
710
793
711
794
795
# NOTE(review): VisionPooler appears to be defined twice in this module; the
# later definition is the one that stays bound to the name. Confirm and drop
# one of them.
class VisionPooler(Pooler):
    """Pooler that mean-pools the hidden states of a prompt's vision span.

    For each sequence in the batch, the span runs from the last occurrence of
    ``hf_config.vision_start_token_id`` to the last occurrence of
    ``hf_config.vision_end_token_id`` (both inclusive). The hidden states of
    that span are averaged into a single vector per sequence.
    """

    @classmethod
    def from_config(cls, model_config: ModelConfig) -> "VisionPooler":
        """Build a pooler from ``model_config``."""
        return cls(model_config)

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

    def get_pooling_params(self,
                           task: PoolingTask) -> Optional[PoolingParams]:
        """Return pooling params for the "embed" task, ``None`` otherwise."""
        if task == "embed":
            return PoolingParams(pooling_type="vision",
                                 logits_processing_needs_token_ids=True)
        return None

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Mean-pool the vision span of every sequence in the batch.

        Args:
            hidden_states: Hidden states of all sequences. Assumed to be laid
                out back to back along dim 0, in ``prompt_lens`` order, with
                the hidden dimension last - TODO confirm against the caller.
            pooling_metadata: Must be a ``V1PoolingMetadata``; its
                ``prompt_lens`` and ``prompt_token_ids`` are read.

        Returns:
            ``build_output`` applied to the stacked per-sequence means (one
            row per sequence).

        Raises:
            ValueError: If a sequence lacks a vision start or end token, or
                its last end token precedes its last start token.
        """
        assert isinstance(pooling_metadata, V1PoolingMetadata)

        hf_config = self.config.hf_config
        start_id = hf_config.vision_start_token_id
        end_id = hf_config.vision_end_token_id

        pooled_outputs = []
        # Running offset of the current sequence inside ``hidden_states``;
        # kept as a plain int instead of rebuilding a cumsum per sequence.
        seq_start = 0
        for i, seq_len in enumerate(pooling_metadata.prompt_lens.tolist()):
            # Only look at this sequence's own tokens, in case the row holds
            # more entries than ``seq_len`` (e.g. padding).
            token_ids = pooling_metadata.prompt_token_ids[i][:seq_len]
            start_hits = (token_ids == start_id).nonzero()
            end_hits = (token_ids == end_id).nonzero()
            if start_hits.numel() == 0 or end_hits.numel() == 0:
                raise ValueError(
                    f"Sequence {i} has no vision span: expected both "
                    f"vision_start_token_id={start_id} and "
                    f"vision_end_token_id={end_id} in its prompt.")

            start_pos = int(start_hits[-1].item())
            end_pos = int(end_hits[-1].item())
            if end_pos < start_pos:
                raise ValueError(
                    f"Sequence {i} has its last vision end token (position "
                    f"{end_pos}) before its last vision start token "
                    f"(position {start_pos}).")

            # Plain torch ops are used so the pooler does not depend on
            # Triton being available. Accumulate in float32, then cast back
            # to the input dtype.
            span = hidden_states[seq_start + start_pos:seq_start + end_pos +
                                 1]
            pooled = span.to(torch.float32).mean(dim=0)
            pooled_outputs.append(pooled.to(hidden_states.dtype))

            seq_start += seq_len

        return build_output(torch.stack(pooled_outputs))
843
+
844
+
712
845
if HAS_TRITON:

    # NOTE(review): this kernel appears to be defined twice in this module;
    # the later definition is the one that stays bound to the name.
    @triton.jit
    def mean_pool_with_position_kernel(
        hidden_states_ptr,
        output_ptr,
        seq_start,
        seq_len,
        hidden_size,
        pool_start,
        pool_end,
        BLOCK_SIZE: tl.constexpr = 1,
    ):
        """Triton kernel to perform mean pooling over a specified token range.

        Each program handles one hidden-dimension index
        ``pid = tl.program_id(0)`` and averages the values at
        ``hidden_states_ptr + (seq_start + i) * hidden_size + pid`` for
        ``i`` in ``[pool_start, pool_end)``, writing the mean to
        ``output_ptr + pid``.

        Args:
            hidden_states_ptr: Base pointer of the hidden states.
            output_ptr: Base pointer of the output vector.
            seq_start: Row offset of the sequence; must be an integer, not a
                tensor, because it is used in pointer arithmetic.
            seq_len: Unused; kept so existing launches keep working.
            hidden_size: Row stride and number of output elements.
            pool_start: First pooled position, relative to ``seq_start``.
            pool_end: One past the last pooled position. Must be greater than
                ``pool_start``; an empty range divides by zero.
            BLOCK_SIZE: Unused compile-time constant. It has a default so
                launches that omit it still bind all parameters.
        """
        pid = tl.program_id(0)

        # Programs launched beyond the hidden dimension have nothing to do.
        if pid >= hidden_size:
            return

        accumulator = 0.0
        for i in range(pool_start, pool_end):
            hidden_val = tl.load(hidden_states_ptr +
                                 (seq_start + i) * hidden_size + pid)
            # Accumulate in float32 whatever the input dtype is, so the
            # loop-carried accumulator keeps a single type.
            accumulator += hidden_val.to(tl.float32)

        # Store mean pooled result
        result = accumulator / (pool_end - pool_start)
        tl.store(output_ptr + pid, result)
0 commit comments