@@ -455,8 +455,13 @@ impl Qwen3Model {
455
455
let causal_mask = Tensor :: from_slice ( & mask, ( seq_len, seq_len) , device) ?;
456
456
let causal_mask = causal_mask. expand ( & [ bs, dim, seq_len, seq_len] ) ?;
457
457
458
+ let min_value = match self . dtype {
459
+ DType :: F32 => f32:: MIN ,
460
+ _ => -65504.0 , // f16 minimum value
461
+ } ;
462
+
458
463
let negatives =
459
- Tensor :: full ( f32 :: MIN , attention_bias. shape ( ) , device) ?. to_dtype ( self . dtype ) ?;
464
+ Tensor :: full ( min_value , attention_bias. shape ( ) , device) ?. to_dtype ( self . dtype ) ?;
460
465
let zeros = Tensor :: zeros_like ( & attention_bias) ?. to_dtype ( self . dtype ) ?;
461
466
462
467
let causal_mask = causal_mask
@@ -514,7 +519,8 @@ impl Qwen3Model {
514
519
515
520
let attention_bias = if masking {
516
521
let attention_bias =
517
- Tensor :: from_vec ( attention_bias, ( batch_size, 1 , 1 , max_length) , & self . device ) ?;
522
+ Tensor :: from_vec ( attention_bias, ( batch_size, 1 , 1 , max_length) , & self . device ) ?
523
+ . to_dtype ( self . dtype ) ?;
518
524
// Broadcast once instead of at every layer
519
525
let attention_bias = attention_bias
520
526
. broadcast_as ( ( batch_size, self . num_attention_heads , max_length, max_length) ) ?
0 commit comments