
Commit de70b7d

Fix Qwen3 Embedding Float16 DType (#663)
1 parent fb80177 commit de70b7d

File tree

2 files changed: +10 −7 lines


backends/candle/src/models/qwen3.rs

Lines changed: 8 additions & 2 deletions
@@ -455,8 +455,13 @@ impl Qwen3Model {
         let causal_mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
         let causal_mask = causal_mask.expand(&[bs, dim, seq_len, seq_len])?;
 
+        let min_value = match self.dtype {
+            DType::F32 => f32::MIN,
+            _ => -65504.0, // f16 minimum value
+        };
+
         let negatives =
-            Tensor::full(f32::MIN, attention_bias.shape(), device)?.to_dtype(self.dtype)?;
+            Tensor::full(min_value, attention_bias.shape(), device)?.to_dtype(self.dtype)?;
         let zeros = Tensor::zeros_like(&attention_bias)?.to_dtype(self.dtype)?;
 
         let causal_mask = causal_mask

@@ -514,7 +519,8 @@ impl Qwen3Model {
 
         let attention_bias = if masking {
             let attention_bias =
-                Tensor::from_vec(attention_bias, (batch_size, 1, 1, max_length), &self.device)?;
+                Tensor::from_vec(attention_bias, (batch_size, 1, 1, max_length), &self.device)?
+                    .to_dtype(self.dtype)?;
             // Broadcast once instead of at every layer
             let attention_bias = attention_bias
                 .broadcast_as((batch_size, self.num_attention_heads, max_length, max_length))?
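
The match on self.dtype matters because f32::MIN (about -3.4e38) lies far outside the representable range of f16, so casting the old bias value to half precision turns it into negative infinity, which can propagate NaNs through the softmax; the second hunk likewise casts the bias tensor to the model dtype up front. A minimal standalone sketch, not part of this commit and assuming the half crate, demonstrating the overflow and the choice of -65504.0:

use half::f16;

fn main() {
    // f32::MIN overflows the f16 range and rounds to negative infinity.
    let overflowed = f16::from_f32(f32::MIN);
    assert!(overflowed.is_infinite());

    // -65504.0 is the smallest finite f16 value, so it survives the cast
    // and keeps the attention bias finite.
    let clamped = f16::from_f32(-65504.0);
    assert!(clamped.is_finite());
    assert_eq!(clamped, f16::MIN);
    println!("overflowed = {overflowed}, clamped = {clamped}");
}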

backends/src/lib.rs

Lines changed: 2 additions & 5 deletions
@@ -150,11 +150,8 @@ impl Backend {
         }
 
         max_input_length = std::cmp::min(max_input_length, max_warmup_length);
-        let mut seq_lengths: Vec<usize> = generate_bucket_sizes(
-            seq_bucket_size,
-            max_input_length,
-            seq_len_exp_base,
-        );
+        let mut seq_lengths: Vec<usize> =
+            generate_bucket_sizes(seq_bucket_size, max_input_length, seq_len_exp_base);
         if let Some(&last) = seq_lengths.last() {
             if last < max_input_length {
                 seq_lengths.push(max_input_length);
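
For context, this call builds the warmup sequence lengths from exponentially growing buckets, and the guard that follows it appends the true maximum when the last bucket falls short. A hypothetical sketch, not the actual generate_bucket_sizes implementation, of how such a helper and that guard fit together:

// Hypothetical illustration only: a bucket generator growing geometrically
// from a starting bucket size up to the maximum input length.
fn generate_bucket_sizes(start: usize, max: usize, exp_base: usize) -> Vec<usize> {
    let mut sizes = Vec::new();
    let mut current = start.max(1);
    while current <= max {
        sizes.push(current);
        current = current.saturating_mul(exp_base.max(2));
    }
    sizes
}

fn main() {
    let mut seq_lengths = generate_bucket_sizes(16, 500, 2); // [16, 32, 64, 128, 256]
    // Mirror the caller's guard: make sure the true maximum is always warmed up.
    if let Some(&last) = seq_lengths.last() {
        if last < 500 {
            seq_lengths.push(500);
        }
    }
    println!("{seq_lengths:?}"); // [16, 32, 64, 128, 256, 500]
}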
