File tree Expand file tree Collapse file tree 3 files changed +8
-3
lines changed Expand file tree Collapse file tree 3 files changed +8
-3
lines changed Original file line number Diff line number Diff line change @@ -804,7 +804,7 @@ impl BertModel {
804
804
let input_ids = Tensor :: from_vec ( input_ids, shape, & self . device ) ?;
805
805
let type_ids = Tensor :: from_vec ( type_ids, shape, & self . device ) ?;
806
806
let position_ids = Tensor :: from_vec ( position_ids, shape, & self . device ) ?;
807
- let input_lengths =
807
+ let mut input_lengths =
808
808
Tensor :: from_vec ( input_lengths, ( batch_size, 1 ) , & self . device ) ?. to_dtype ( self . dtype ) ?;
809
809
810
810
let embedding_output = self
@@ -847,6 +847,7 @@ impl BertModel {
847
847
if let Some ( pooled_indices) = pooled_indices {
848
848
// Select values in the batch
849
849
attention_mask = attention_mask. index_select ( & pooled_indices, 0 ) ?;
850
+ input_lengths = input_lengths. index_select ( & pooled_indices, 0 ) ?;
850
851
} ;
851
852
852
853
// Mask padded values
Original file line number Diff line number Diff line change @@ -209,18 +209,22 @@ impl Backend for OrtBackend {
209
209
Pool :: Mean => {
210
210
if masking {
211
211
let mut attention_mask = attention_mask;
212
+ let mut input_lengths = input_lengths;
212
213
213
214
if let Some ( indices) = indices {
214
215
// Select values in the batch
215
216
attention_mask = attention_mask. select ( Axis ( 0 ) , & indices) ;
217
+ input_lengths = input_lengths. select ( Axis ( 0 ) , & indices) ;
216
218
} ;
217
219
218
220
// Cast and reshape
219
221
let attention_mask = attention_mask. mapv ( |x| x as f32 ) . insert_axis ( Axis ( 2 ) ) ;
220
222
221
223
// Mask padded values
222
224
outputs = outputs. mul ( attention_mask) ;
223
- outputs. sum_axis ( Axis ( 1 ) ) . div ( input_lengths)
225
+ outputs
226
+ . sum_axis ( Axis ( 1 ) )
227
+ . div ( input_lengths. insert_axis ( Axis ( 1 ) ) )
224
228
} else {
225
229
outputs. mean_axis ( Axis ( 1 ) ) . unwrap ( )
226
230
}
Original file line number Diff line number Diff line change @@ -27,7 +27,7 @@ export const options = {
27
27
executor : 'constant-arrival-rate' ,
28
28
duration : '30s' ,
29
29
preAllocatedVUs : 5000 ,
30
- rate : 10 ,
30
+ rate : 50 ,
31
31
timeUnit : '1s' ,
32
32
gracefulStop : '1s' ,
33
33
} ,
You can’t perform that action at this time.
0 commit comments