diff --git a/mindnlp/transformers/modeling_utils.py b/mindnlp/transformers/modeling_utils.py
index 9fb31dc4d..b422e7577 100644
--- a/mindnlp/transformers/modeling_utils.py
+++ b/mindnlp/transformers/modeling_utils.py
@@ -184,8 +184,8 @@ def wrapper(
     pretrained_model_name_or_path,
     **kwargs,
 ):
-    device_map = kwargs.pop("device_map", None)
-    sharded_metadata = kwargs.pop("sharded_metadata", None)
+    device_map = kwargs.get("device_map", None)
+    sharded_metadata = kwargs.get("sharded_metadata", None)
 
     # if device_map is not None and not initialize distribute module, raise Error.
     if device_map is not None:
diff --git a/mindtorch/nn/functional.py b/mindtorch/nn/functional.py
index 67e5c2af0..01d0aa12d 100644
--- a/mindtorch/nn/functional.py
+++ b/mindtorch/nn/functional.py
@@ -1220,7 +1220,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
 
     attn_weight = query @ key.transpose(-2, -1) * scale_factor
     attn_weight += attn_bias
-    attn_weight = softmax(attn_weight, dim=-1, dtype=mindtorch.float32).to(query.dtype)
+    attn_weight = softmax(attn_weight, dim=-1)
     attn_weight = dropout(attn_weight, dropout_p, training=True)
     return attn_weight @ value
 
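
A minimal standalone sketch follows (plain Python, not part of the patch; load_with_pop and load_with_get are hypothetical names used only for illustration). It shows the behavioral difference the first hunk relies on: dict.pop consumes the key, while dict.get leaves it in kwargs for any later code that reads the same dict.

# Sketch only; load_with_pop / load_with_get do not exist in the repo.
def load_with_pop(**kwargs):
    device_map = kwargs.pop("device_map", None)   # old behavior: key is removed from kwargs
    return device_map, kwargs

def load_with_get(**kwargs):
    device_map = kwargs.get("device_map", None)   # new behavior: key stays in kwargs
    return device_map, kwargs

print(load_with_pop(device_map="auto", revision="main"))
# ('auto', {'revision': 'main'})                         -> later readers no longer see device_map
print(load_with_get(device_map="auto", revision="main"))
# ('auto', {'device_map': 'auto', 'revision': 'main'})   -> later readers still see it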