Skip to content

Commit 4cc38bd

Browse files
fix(ort): fix mean pooling (#332)
1 parent e496fe7 commit 4cc38bd

File tree

3 files changed: +8 -3 lines changed

backends/candle/src/models/bert.rs

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -804,7 +804,7 @@ impl BertModel {
804 804
let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?;
805 805
let type_ids = Tensor::from_vec(type_ids, shape, &self.device)?;
806 806
let position_ids = Tensor::from_vec(position_ids, shape, &self.device)?;
807 -
let input_lengths =
807 +
let mut input_lengths =
808 808
Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?;
809 809

810 810
let embedding_output = self
@@ -847,6 +847,7 @@ impl BertModel {
847 847
if let Some(pooled_indices) = pooled_indices {
848 848
// Select values in the batch
849 849
attention_mask = attention_mask.index_select(&pooled_indices, 0)?;
850 +
input_lengths = input_lengths.index_select(&pooled_indices, 0)?;
850 851
};
851 852

852 853
// Mask padded values

backends/ort/src/lib.rs

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -209,18 +209,22 @@ impl Backend for OrtBackend {
209 209
Pool::Mean => {
210 210
if masking {
211 211
let mut attention_mask = attention_mask;
212 +
let mut input_lengths = input_lengths;
212 213

213 214
if let Some(indices) = indices {
214 215
// Select values in the batch
215 216
attention_mask = attention_mask.select(Axis(0), &indices);
217 +
input_lengths = input_lengths.select(Axis(0), &indices);
216 218
};
217 219

218 220
// Cast and reshape
219 221
let attention_mask = attention_mask.mapv(|x| x as f32).insert_axis(Axis(2));
220 222

221 223
// Mask padded values
222 224
outputs = outputs.mul(attention_mask);
223 -
outputs.sum_axis(Axis(1)).div(input_lengths)
225 +
outputs
226 +
.sum_axis(Axis(1))
227 +
.div(input_lengths.insert_axis(Axis(1)))
224 228
} else {
225 229
outputs.mean_axis(Axis(1)).unwrap()
226 230
}

load_tests/load.js

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ export const options = {
27 27
executor: 'constant-arrival-rate',
28 28
duration: '30s',
29 29
preAllocatedVUs: 5000,
30 -
rate: 10,
30 +
rate: 50,
31 31
timeUnit: '1s',
32 32
gracefulStop: '1s',
33 33
},
0 commit comments

Comments
 (0)