@@ -280,7 +280,7 @@ def _make_causal_mask(
     """
     bsz, tgt_len = input_ids_shape
     mask = ops.full(
-        (tgt_len, tgt_len), float(ops.finfo(dtype).min), dtype)
+        (tgt_len, tgt_len), float(ops.finfo(dtype).min), dtype=dtype)
     mask_cond = ops.arange(mask.shape[-1])
     mask = ops.masked_fill(mask, mask_cond < (mask_cond + 1).view(mask.shape[-1], 1), 0.)
     mask = mask.to(dtype)
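
Note on this hunk: the only change is passing dtype to ops.full as a keyword argument rather than positionally, so the fill value and the target dtype can no longer be confused. A minimal sketch of the fixed call, reusing the names from the hunk with illustrative values:

    import mindspore
    # illustrative values; `ops` is the module already imported by this file (used throughout the diff)
    tgt_len, dtype = 4, mindspore.float16
    # fill a (tgt_len, tgt_len) matrix with the smallest value representable in `dtype`
    mask = ops.full((tgt_len, tgt_len), float(ops.finfo(dtype).min), dtype=dtype)
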
@@ -309,7 +309,7 @@ def _expand_mask(mask: Tensor, dtype: mstype, tgt_len: Optional[int] = None):
 
     return inverted_mask.masked_fill(
         inverted_mask.to(mindspore.bool_),
-        ops.finfo(dtype).min)
+        float(ops.finfo(dtype).min))
 
 def _get_interleave(n):
     """
@@ -688,8 +688,7 @@ def forward(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}"
                 )
             attn_weights = attn_weights + attention_mask
-            attn_weights = ops.maximum(attn_weights,
-                                       Tensor(np.finfo(mindspore.dtype_to_nptype(attn_weights.dtype)).min))
+            attn_weights = ops.maximum(attn_weights, float(ops.finfo(attn_weights.dtype).min))
 
         # upcast attention to fp32
         attn_weights = F.softmax(attn_weights, dim=-1).astype(query_states.dtype)
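
Note on this hunk: both versions clamp the masked attention scores from below at the smallest value representable in the scores' dtype; the new line just obtains that scalar directly from ops.finfo instead of round-tripping through NumPy (np.finfo plus mindspore.dtype_to_nptype) and wrapping it in a Tensor. A rough sketch of the equivalence, assuming ops.finfo reports the same limits as np.finfo for the dtype in question:

    # old path: NumPy limit, wrapped in a Tensor
    old_min = Tensor(np.finfo(mindspore.dtype_to_nptype(attn_weights.dtype)).min)
    # new path: the same limit as a plain Python float
    new_min = float(ops.finfo(attn_weights.dtype).min)
    attn_weights = ops.maximum(attn_weights, new_min)
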
@@ -882,7 +881,7 @@ def forward(
                 else:
                     attention_mask = attention_mask[:, -1:, :]
                 attn_weights = attn_weights + attention_mask.astype(attn_weights.dtype)
-                attn_weights = ops.maximum(attn_weights, mindspore.tensor(np.finfo(mindspore.dtype_to_nptype(attn_weights.dtype)).min))
+                attn_weights = ops.maximum(attn_weights, float(ops.finfo(attn_weights.dtype).min))
 
             attn_weights = F.softmax(attn_weights, dim=-1)
 
@@ -1561,7 +1560,7 @@ def forward(
            src_len, tgt_len = alibi_mask.shape[-2:]
            expanded_mask = expanded_mask.unsqueeze(1).broadcast_to((bsz, 1, src_len, tgt_len)).to(alibi_mask.dtype)
            inverted_mask = 1.0 - expanded_mask
-            inverted_mask = inverted_mask.masked_fill(inverted_mask.to(mindspore.bool_), np.finfo(mindspore.dtype_to_nptype(alibi_mask.dtype)).min)
+            inverted_mask = inverted_mask.masked_fill(inverted_mask.to(mindspore.bool_), float(ops.finfo(alibi_mask.dtype).min))
            attention_mask = inverted_mask + alibi_mask.unsqueeze(0)
        else:
            attention_mask = alibi_mask
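
Note on this hunk: the same substitution, applied to the additive mask that is combined with the ALiBi bias. Positions to be ignored are filled with the dtype's minimum, so after softmax their attention weight collapses to (approximately) zero. A tiny illustration with made-up scores, reusing the Tensor and F.softmax names already used in this diff:

    scores = Tensor([0.5, float(ops.finfo(mindspore.float32).min)], mindspore.float32)
    probs = F.softmax(scores, dim=-1)  # roughly [1.0, 0.0]: the masked slot contributes nothing
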
@@ -1854,7 +1853,7 @@ def prepare_inputs_for_generation(
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids = attention_mask.int().cumsum(-1) - 1
             position_ids = position_ids.masked_fill(attention_mask == 0, 1)
             if past_key_values:
                 position_ids = position_ids[:, -1].unsqueeze(-1)
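
Note on this hunk: the dtype of the running position index changes from long to int, but the recipe is unchanged: a cumulative sum over the 0/1 attention mask yields each real token's position, and padded slots are then overwritten with a harmless value. A worked example with an illustrative left-padded mask:

    # attention_mask           = [[0, 0, 1, 1, 1]]
    # cumsum(-1) - 1           = [[-1, -1, 0, 1, 2]]
    # masked_fill(mask == 0, 1) = [[1, 1, 0, 1, 2]]
    position_ids = attention_mask.int().cumsum(-1) - 1
    position_ids = position_ids.masked_fill(attention_mask == 0, 1)
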