File tree Expand file tree Collapse file tree 1 file changed +8
-1
lines changed
vllm/model_executor/models Expand file tree Collapse file tree 1 file changed +8
-1
lines changed Original file line number Diff line number Diff line change @@ -258,14 +258,21 @@ def __init__(self, config: ModernBertConfig):
258
258
super ().__init__ ()
259
259
self .dense = nn .Linear (config .hidden_size , config .hidden_size ,
260
260
config .classifier_bias )
261
+ self .pooling_type = config .classifier_pooling
261
262
self .act = nn .GELU ()
262
263
self .norm = nn .LayerNorm (config .hidden_size ,
263
264
eps = config .norm_eps ,
264
265
bias = config .norm_bias )
265
266
266
267
def forward (self , hidden_states : torch .Tensor ) -> torch .Tensor :
267
268
pooled_output = hidden_states
268
- pooled_output = pooled_output .mean (dim = 0 , keepdim = False )
269
+ if self .pooling_type == "mean" :
270
+ pooled_output = pooled_output .mean (dim = 0 , keepdim = False )
271
+ elif self .pooling_type == "cls" :
272
+ pooled_output = pooled_output [0 , :]
273
+ else :
274
+ raise ValueError ("Pooling type should be either `cls` or `mean`, "
275
+ f"but got { self .pooling_type } " )
269
276
pooled_output = self .norm (self .act (self .dense (pooled_output )))
270
277
return pooled_output
271
278
You can’t perform that action at this time.
0 commit comments