
Commit 4c855bc

Merge branch 'main' into add-dense

2 parents: a368460 + 6b79f20

1 file changed: backends/candle/src/models/qwen3.rs (8 additions, 2 deletions)
```diff
@@ -455,8 +455,13 @@ impl Qwen3Model {
         let causal_mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
         let causal_mask = causal_mask.expand(&[bs, dim, seq_len, seq_len])?;
 
+        let min_value = match self.dtype {
+            DType::F32 => f32::MIN,
+            _ => -65504.0, // f16 minimum value
+        };
+
         let negatives =
-            Tensor::full(f32::MIN, attention_bias.shape(), device)?.to_dtype(self.dtype)?;
+            Tensor::full(min_value, attention_bias.shape(), device)?.to_dtype(self.dtype)?;
         let zeros = Tensor::zeros_like(&attention_bias)?.to_dtype(self.dtype)?;
 
         let causal_mask = causal_mask
```
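The hunk above guards the mask fill value against dtype overflow: f32::MIN (about -3.4e38) lies far outside f16's finite range (largest magnitude 65504), so casting a mask filled with it to F16 overflows to -inf, and a fully masked softmax row of -inf entries can then produce NaNs. A minimal standalone sketch of the failure and the fix, assuming candle-core on the CPU device; the printed values in the comments are illustrative:

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;

    // f32::MIN cannot survive a cast to f16: the conversion overflows
    // to -inf rather than staying a large finite negative value.
    let overflow = Tensor::full(f32::MIN, (1,), &device)?.to_dtype(DType::F16)?;
    println!("{overflow}"); // -inf

    // The commit's approach: pick the fill value per dtype so the
    // cast stays finite.
    let dtype = DType::F16; // assumption: an f16 model
    let min_value = match dtype {
        DType::F32 => f32::MIN,
        _ => -65504.0, // smallest finite f16 value
    };
    let negatives = Tensor::full(min_value, (1,), &device)?.to_dtype(dtype)?;
    println!("{negatives}"); // -65504

    Ok(())
}
```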
```diff
@@ -514,7 +519,8 @@ impl Qwen3Model {
 
         let attention_bias = if masking {
             let attention_bias =
-                Tensor::from_vec(attention_bias, (batch_size, 1, 1, max_length), &self.device)?;
+                Tensor::from_vec(attention_bias, (batch_size, 1, 1, max_length), &self.device)?
+                    .to_dtype(self.dtype)?;
             // Broadcast once instead of at every layer
             let attention_bias = attention_bias
                 .broadcast_as((batch_size, self.num_attention_heads, max_length, max_length))?
```
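The second hunk casts the padding bias to the model dtype as soon as it is built, so the broadcast that follows (done once rather than at every layer, per the existing comment) already yields a tensor in the right dtype. A small self-contained sketch of the same pattern, with hypothetical shapes and an assumed f16 model dtype:

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    let (batch_size, num_attention_heads, max_length) = (2, 4, 8);

    // Hypothetical padding bias, built in f32 on the host: all zeros here,
    // where the model would use large negatives for padded positions.
    let attention_bias: Vec<f32> = vec![0.0; batch_size * max_length];

    // Mirror the commit: cast to the model dtype at construction, then
    // broadcast once instead of at every layer.
    let model_dtype = DType::F16; // assumption: an f16 model
    let attention_bias =
        Tensor::from_vec(attention_bias, (batch_size, 1, 1, max_length), &device)?
            .to_dtype(model_dtype)?;
    let attention_bias = attention_bias
        .broadcast_as((batch_size, num_attention_heads, max_length, max_length))?;

    assert_eq!(attention_bias.dtype(), model_dtype);
    assert_eq!(
        attention_bias.dims(),
        &[batch_size, num_attention_heads, max_length, max_length]
    );
    Ok(())
}
```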
