File tree Expand file tree Collapse file tree 1 file changed +5
-3
lines changed Expand file tree Collapse file tree 1 file changed +5
-3
lines changed Original file line number Diff line number Diff line change @@ -31,7 +31,7 @@ impl OrtBackend {
31
31
let pool = match model_type {
32
32
ModelType :: Classifier => Pool :: Cls ,
33
33
ModelType :: Embedding ( pool) => match pool {
34
- Pool :: Splade | Pool :: LastToken => {
34
+ Pool :: Splade => {
35
35
return Err ( BackendError :: Start ( format ! (
36
36
"Pooling {pool} is not supported for this backend. Use `candle` backend instead."
37
37
) ) ) ;
@@ -204,8 +204,10 @@ impl Backend for OrtBackend {
204
204
let pooled_embeddings = match self . pool {
205
205
// CLS pooling
206
206
Pool :: Cls => outputs. slice ( s ! [ .., 0 , ..] ) . into_owned ( ) . into_dyn ( ) ,
207
- // Last token pooling is not supported for this model
208
- Pool :: LastToken => unreachable ! ( ) ,
207
+ Pool :: LastToken => {
208
+ let axis_len = outputs. len_of ( Axis ( 1 ) ) ;
209
+ outputs. slice ( s ! [ .., axis_len - 1 , ..] ) . into_owned ( ) . into_dyn ( )
210
+ } ,
209
211
// Mean pooling
210
212
Pool :: Mean => {
211
213
if masking {
You can’t perform that action at this time.
0 commit comments