Skip to content

Commit fb80177

Browse files
Add last token pooling support for ORT. (#664)
Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
1 parent 4c098cc commit fb80177

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

backends/ort/src/lib.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ impl OrtBackend {
3131
let pool = match model_type {
3232
ModelType::Classifier => Pool::Cls,
3333
ModelType::Embedding(pool) => match pool {
34-
Pool::Splade | Pool::LastToken => {
34+
Pool::Splade => {
3535
return Err(BackendError::Start(format!(
3636
"Pooling {pool} is not supported for this backend. Use `candle` backend instead."
3737
)));
@@ -204,8 +204,10 @@ impl Backend for OrtBackend {
204204
let pooled_embeddings = match self.pool {
205205
// CLS pooling
206206
Pool::Cls => outputs.slice(s![.., 0, ..]).into_owned().into_dyn(),
207-
// Last token pooling is not supported for this model
208-
Pool::LastToken => unreachable!(),
207+
Pool::LastToken => {
208+
let axis_len = outputs.len_of(Axis(1));
209+
outputs.slice(s![.., axis_len - 1, ..]).into_owned().into_dyn()
210+
},
209211
// Mean pooling
210212
Pool::Mean => {
211213
if masking {

0 commit comments

Comments
 (0)