
Commit 7fd83f5

fix baichuan finfo error
1 parent 0fb2d1c

File tree

1 file changed (+6, -7 lines)

mindnlp/transformers/models/baichuan/modeling_baichuan.py

Lines changed: 6 additions & 7 deletions
@@ -280,7 +280,7 @@ def _make_causal_mask(
     """
     bsz, tgt_len = input_ids_shape
     mask = ops.full(
-        (tgt_len, tgt_len), float(ops.finfo(dtype).min), dtype)
+        (tgt_len, tgt_len), float(ops.finfo(dtype).min), dtype=dtype)
     mask_cond = ops.arange(mask.shape[-1])
     mask = ops.masked_fill(mask, mask_cond < (mask_cond + 1).view(mask.shape[-1], 1), 0.)
     mask = mask.to(dtype)
@@ -309,7 +309,7 @@ def _expand_mask(mask: Tensor, dtype: mstype, tgt_len: Optional[int] = None):

     return inverted_mask.masked_fill(
         inverted_mask.to(mindspore.bool_),
-        ops.finfo(dtype).min)
+        float(ops.finfo(dtype).min))


 def _get_interleave(n):
     """
@@ -688,8 +688,7 @@ def forward(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}"
                 )
             attn_weights = attn_weights + attention_mask
-            attn_weights = ops.maximum(attn_weights,
-                                       Tensor(np.finfo(mindspore.dtype_to_nptype(attn_weights.dtype)).min))
+            attn_weights = ops.maximum(attn_weights, float(ops.finfo(attn_weights.dtype).min))

         # upcast attention to fp32
         attn_weights = F.softmax(attn_weights, dim=-1).astype(query_states.dtype)
@@ -882,7 +881,7 @@ def forward(
                 else:
                     attention_mask = attention_mask[:, -1:, :]
             attn_weights = attn_weights + attention_mask.astype(attn_weights.dtype)
-            attn_weights = ops.maximum(attn_weights, mindspore.tensor(np.finfo(mindspore.dtype_to_nptype(attn_weights.dtype)).min))
+            attn_weights = ops.maximum(attn_weights, float(ops.finfo(attn_weights.dtype).min))

         attn_weights = F.softmax(attn_weights, dim=-1)
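
The two attention forwards above apply the same idea after the additive mask: clamp the logits from below with the dtype's own minimum so that mask plus logits cannot drop past what the dtype can represent before the softmax. A minimal sketch of that clamping step, with made-up tensor shapes and the same assumed ops import as in the earlier sketch:

import mindspore
from mindnlp.core import ops  # assumption: torch-like ops module, as in the sketch above

# Hypothetical fp16 logits and an additive mask whose blocked entries already sit at the dtype minimum.
attn_weights = ops.full((1, 1, 4, 4), 0.0, dtype=mindspore.float16)
attention_mask = ops.full((1, 1, 4, 4), float(ops.finfo(mindspore.float16).min), dtype=mindspore.float16)

attn_weights = attn_weights + attention_mask
# Without the clamp, adding the mask to already-negative logits can overflow to -inf in fp16;
# ops.maximum pins every entry at the representable minimum instead.
attn_weights = ops.maximum(attn_weights, float(ops.finfo(attn_weights.dtype).min))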

@@ -1561,7 +1560,7 @@ def forward(
             src_len, tgt_len = alibi_mask.shape[-2:]
             expanded_mask = expanded_mask.unsqueeze(1).broadcast_to((bsz, 1, src_len, tgt_len)).to(alibi_mask.dtype)
             inverted_mask = 1.0 - expanded_mask
-            inverted_mask = inverted_mask.masked_fill(inverted_mask.to(mindspore.bool_), np.finfo(mindspore.dtype_to_nptype(alibi_mask.dtype)).min)
+            inverted_mask = inverted_mask.masked_fill(inverted_mask.to(mindspore.bool_), float(ops.finfo(alibi_mask.dtype).min))
             attention_mask = inverted_mask + alibi_mask.unsqueeze(0)
         else:
             attention_mask = alibi_mask
@@ -1854,7 +1853,7 @@ def prepare_inputs_for_generation(
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids = attention_mask.int().cumsum(-1) - 1
             position_ids = position_ids.masked_fill(attention_mask == 0, 1)
             if past_key_values:
                 position_ids = position_ids[:, -1].unsqueeze(-1)
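
The last hunk switches the cumulative-sum cast from .long() to .int(); the recipe itself is unchanged: count visible tokens with cumsum, subtract one to get zero-based positions, and pin padded slots to a dummy index. A small self-contained sketch of that recipe (the example mask values are invented, not taken from the repository):

import mindspore

# Hypothetical batch: the first sequence is left-padded with two zeros.
attention_mask = mindspore.Tensor([[0, 0, 1, 1, 1],
                                   [1, 1, 1, 1, 1]], dtype=mindspore.int32)

# Running count of visible tokens, minus one, gives zero-based positions.
position_ids = attention_mask.int().cumsum(-1) - 1
# Padded slots get an arbitrary valid index; they are masked out during attention anyway.
position_ids = position_ids.masked_fill(attention_mask == 0, 1)
# position_ids is now [[1, 1, 0, 1, 2],
#                      [0, 1, 2, 3, 4]]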
