
Commit 6b2e5a7

fix(transformers): supplement patch for fa check/fix unexpected missing_keys warning (#1371)
* fix(transformers): supplement patch for fa check
* fix(transformers): fix unexpected missing_keys warning
* fix bugs
* fix bugs
1 parent b1ef8c2 commit 6b2e5a7

File tree: 2 files changed (+18, -8 lines)

mindone/transformers/mindspore_adapter/utils.py

Lines changed: 6 additions & 0 deletions
@@ -54,6 +54,12 @@
     ms.bfloat16: _MAX_BF16,
 }
 
+TORCH_TO_MINDSPORE_DTYPE_MAP = {
+    "torch.float32": ms.float32,
+    "torch.bfloat16": ms.bfloat16,
+    "torch.float16": ms.float16,
+}
+
 
 def dtype_to_min(dtype):
     return _DTYPE_2_MIN.get(dtype, "others dtype")
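For reference, a minimal sketch of how the new lookup table turns a stringified torch dtype into its MindSpore counterpart; the resolve_ms_dtype helper below is hypothetical and not part of this patch, and the package-level import assumes the adapter package re-exports the map, as the patched import in modeling_utils.py does:

import mindspore as ms
from mindone.transformers.mindspore_adapter import TORCH_TO_MINDSPORE_DTYPE_MAP

def resolve_ms_dtype(torch_dtype):
    # Hypothetical helper: str(torch.bfloat16) == "torch.bfloat16", which is
    # exactly the key format the new map expects.
    return TORCH_TO_MINDSPORE_DTYPE_MAP[str(torch_dtype)]

# e.g. resolve_ms_dtype("torch.float16") -> ms.float16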

mindone/transformers/modeling_utils.py

Lines changed: 12 additions & 8 deletions
@@ -67,7 +67,7 @@
 from .integrations.flash_attention import flash_attention_forward
 from .integrations.sdpa_attention import sdpa_attention_forward
 from .loss.loss_utils import LOSS_MAPPING
-from .mindspore_adapter import dtype_to_str
+from .mindspore_adapter import TORCH_TO_MINDSPORE_DTYPE_MAP, dtype_to_str
 from .mindspore_utils import ( # noqa: F401
     Conv1D,
     apply_chunking_to_forward,
@@ -771,7 +771,12 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
         if not getattr(config, "_attn_implementation_autoset", False):
             # config usually has a `mindspore_dtype` but we need the next line for the `no_super_init` tests
             # TODO mindspore does not have get_default_dtype api
-            dtype = config.mindspore_dtype if hasattr(config, "mindspore_dtype") else ms.float32
+            dtype = ms.float32
+            if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
+                if isinstance(config.torch_dtype, str):
+                    dtype = getattr(ms, config.torch_dtype)
+                else:
+                    dtype = TORCH_TO_MINDSPORE_DTYPE_MAP[str(config.torch_dtype)]
             config = self._autoset_attn_implementation(config, mindspore_dtype=dtype)
         # Save config and origin of the pretrained weights if given in model
         self.config = config
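Taken on its own, the new logic defaults to ms.float32 and prefers config.torch_dtype when it is set, handling both the string form ("bfloat16") and the torch dtype object form (torch.bfloat16). A minimal standalone sketch of that resolution (the function name below is hypothetical, not part of the patch):

import mindspore as ms
from mindone.transformers.mindspore_adapter import TORCH_TO_MINDSPORE_DTYPE_MAP

def pick_autoset_dtype(config):
    # Hypothetical standalone version of the logic patched into __init__ above.
    dtype = ms.float32  # fallback; MindSpore has no get_default_dtype() API
    torch_dtype = getattr(config, "torch_dtype", None)
    if torch_dtype is not None:
        if isinstance(torch_dtype, str):
            dtype = getattr(ms, torch_dtype)  # e.g. "bfloat16" -> ms.bfloat16
        else:
            dtype = TORCH_TO_MINDSPORE_DTYPE_MAP[str(torch_dtype)]  # e.g. torch.bfloat16
    return dtype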
@@ -1038,11 +1043,6 @@ def _from_config(cls, config, **kwargs):
         if isinstance(mindspore_dtype, str):
             mindspore_dtype = getattr(ms, mindspore_dtype)
         elif mindspore_dtype is not None and not isinstance(mindspore_dtype, ms.Type):
-            TORCH_TO_MINDSPORE_DTYPE_MAP = {
-                "torch.float32": ms.float32,
-                "torch.bfloat16": ms.bfloat16,
-                "torch.float16": ms.float16,
-            }
             mindspore_dtype = str(mindspore_dtype)
             mindspore_dtype = TORCH_TO_MINDSPORE_DTYPE_MAP[mindspore_dtype]
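With the local copy of the dict removed, _from_config now normalizes its mindspore_dtype argument through the shared map. A simplified sketch of that normalization, with an assumed helper name and the surrounding method body omitted:

import mindspore as ms
from mindone.transformers.mindspore_adapter import TORCH_TO_MINDSPORE_DTYPE_MAP

def normalize_mindspore_dtype(mindspore_dtype):
    # String form, e.g. "float16" -> ms.float16
    if isinstance(mindspore_dtype, str):
        return getattr(ms, mindspore_dtype)
    # Not already a MindSpore dtype (e.g. torch.float16) -> go through the shared map
    if mindspore_dtype is not None and not isinstance(mindspore_dtype, ms.Type):
        return TORCH_TO_MINDSPORE_DTYPE_MAP[str(mindspore_dtype)]
    return mindspore_dtype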

@@ -2648,9 +2648,13 @@ def _load_pretrained_model(
         loading_task_model_from_base_state_dict = not has_prefix_module and expects_prefix_module
         loading_base_model_from_task_state_dict = has_prefix_module and not expects_prefix_module
 
+        # Mapping loaded_keys from pt to ms
+        pt2ms_mappings = _get_pt2ms_mappings(model)
+        loaded_keys = _get_pt2ms_mapped_k(pt2ms_mappings, has_prefix_module, expects_prefix_module, loaded_keys, prefix)
+
         # Find the key names that the model expects from the serialized keys
         key_renaming_mapping = model._get_key_renaming_mapping(
-            original_checkpoint_keys,
+            loaded_keys,
             key_mapping,
             loading_base_model_from_task_state_dict,
             loading_task_model_from_base_state_dict,
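This half of the fix remaps the serialized (PyTorch-style) parameter names to MindSpore-style names before the key renaming mapping is computed, so parameters that differ only by naming convention are no longer reported in missing_keys. A toy illustration of the idea; the helper and the concrete renames below are assumptions for illustration, not the patch's _get_pt2ms_mappings / _get_pt2ms_mapped_k implementation:

def map_pt_keys_to_ms(loaded_keys, renames):
    # Rewrite each checkpoint key to the name the MindSpore model actually expects.
    mapped = []
    for key in loaded_keys:
        for pt_suffix, ms_suffix in renames.items():
            if key.endswith(pt_suffix):
                key = key[: -len(pt_suffix)] + ms_suffix
                break
        mapped.append(key)
    return mapped

# Assumed example renames; the real mappings are derived from the model.
renames = {"norm.weight": "norm.gamma", "norm.bias": "norm.beta"}
print(map_pt_keys_to_ms(["encoder.norm.weight", "encoder.norm.bias"], renames))
# -> ['encoder.norm.gamma', 'encoder.norm.beta']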
