|
67 | 67 | from .integrations.flash_attention import flash_attention_forward |
68 | 68 | from .integrations.sdpa_attention import sdpa_attention_forward |
69 | 69 | from .loss.loss_utils import LOSS_MAPPING |
70 | | -from .mindspore_adapter import dtype_to_str |
| 70 | +from .mindspore_adapter import TORCH_TO_MINDSPORE_DTYPE_MAP, dtype_to_str |
71 | 71 | from .mindspore_utils import ( # noqa: F401 |
72 | 72 | Conv1D, |
73 | 73 | apply_chunking_to_forward, |
@@ -771,7 +771,12 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): |
771 | 771 | if not getattr(config, "_attn_implementation_autoset", False): |
772 | 772 | # config usually has a `mindspore_dtype` but we need the next line for the `no_super_init` tests |
773 | 773 | # TODO mindspore does not have get_default_dtype api |
774 | | - dtype = config.mindspore_dtype if hasattr(config, "mindspore_dtype") else ms.float32 |
| 774 | + dtype = ms.float32 |
| 775 | + if hasattr(config, "torch_dtype") and config.torch_dtype is not None: |
| 776 | + if isinstance(config.torch_dtype, str): |
| 777 | + dtype = getattr(ms, config.torch_dtype) |
| 778 | + else: |
| 779 | + dtype = TORCH_TO_MINDSPORE_DTYPE_MAP[str(config.torch_dtype)] |
775 | 780 | config = self._autoset_attn_implementation(config, mindspore_dtype=dtype) |
776 | 781 | # Save config and origin of the pretrained weights if given in model |
777 | 782 | self.config = config |
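For context on the hunk above: instead of reading a `mindspore_dtype` attribute, `__init__` now derives the dtype from `config.torch_dtype`, which can be either a plain string (e.g. `"float16"`) or a `torch.dtype` object; the object case is resolved through `TORCH_TO_MINDSPORE_DTYPE_MAP`, which this commit moves out of `_from_config` and into `.mindspore_adapter` (see the import change above and the removal in the next hunk). A minimal standalone sketch of the conversion, assuming the map holds only the three entries visible in the removed inline definition:

```python
import mindspore as ms

# Assumed contents, copied from the inline map this commit deletes from `_from_config`;
# the table imported from `.mindspore_adapter` may cover additional dtypes.
TORCH_TO_MINDSPORE_DTYPE_MAP = {
    "torch.float32": ms.float32,
    "torch.bfloat16": ms.bfloat16,
    "torch.float16": ms.float16,
}


def resolve_mindspore_dtype(torch_dtype):
    """Illustrative helper (not part of the commit): turn a `config.torch_dtype` value into a MindSpore dtype."""
    if torch_dtype is None:
        return ms.float32  # fallback default, matching the `dtype = ms.float32` line above
    if isinstance(torch_dtype, str):
        return getattr(ms, torch_dtype)  # e.g. "bfloat16" -> ms.bfloat16
    return TORCH_TO_MINDSPORE_DTYPE_MAP[str(torch_dtype)]  # e.g. torch.float16 -> "torch.float16" -> ms.float16
```

Either form a saved config may carry (`"float16"` as a string or a serialized `torch.float16` object) then resolves to the same `ms.float16` that `_autoset_attn_implementation` receives.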
@@ -1038,11 +1043,6 @@ def _from_config(cls, config, **kwargs): |
1038 | 1043 | if isinstance(mindspore_dtype, str): |
1039 | 1044 | mindspore_dtype = getattr(ms, mindspore_dtype) |
1040 | 1045 | elif mindspore_dtype is not None and not isinstance(mindspore_dtype, ms.Type): |
1041 | | - TORCH_TO_MINDSPORE_DTYPE_MAP = { |
1042 | | - "torch.float32": ms.float32, |
1043 | | - "torch.bfloat16": ms.bfloat16, |
1044 | | - "torch.float16": ms.float16, |
1045 | | - } |
1046 | 1046 | mindspore_dtype = str(mindspore_dtype) |
1047 | 1047 | mindspore_dtype = TORCH_TO_MINDSPORE_DTYPE_MAP[mindspore_dtype] |
1048 | 1048 |
|
@@ -2648,9 +2648,13 @@ def _load_pretrained_model( |
2648 | 2648 | loading_task_model_from_base_state_dict = not has_prefix_module and expects_prefix_module |
2649 | 2649 | loading_base_model_from_task_state_dict = has_prefix_module and not expects_prefix_module |
2650 | 2650 |
|
| 2651 | + # Map loaded_keys from PyTorch (pt) parameter names to MindSpore (ms) parameter names
| 2652 | + pt2ms_mappings = _get_pt2ms_mappings(model) |
| 2653 | + loaded_keys = _get_pt2ms_mapped_k(pt2ms_mappings, has_prefix_module, expects_prefix_module, loaded_keys, prefix) |
| 2654 | + |
2651 | 2655 | # Find the key names that the model expects from the serialized keys |
2652 | 2656 | key_renaming_mapping = model._get_key_renaming_mapping( |
2653 | | - original_checkpoint_keys, |
| 2657 | + loaded_keys, |
2654 | 2658 | key_mapping, |
2655 | 2659 | loading_base_model_from_task_state_dict, |
2656 | 2660 | loading_task_model_from_base_state_dict, |
|
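The three added lines remap `loaded_keys` from PyTorch parameter names to MindSpore parameter names before `_get_key_renaming_mapping` runs, so the base-model/task-model prefix handling operates on key names that actually exist on the MindSpore model. The real rules are built per module by `_get_pt2ms_mappings` and applied by `_get_pt2ms_mapped_k` (neither is shown in this diff); the sketch below is only a hypothetical illustration of the kind of renaming involved, assuming the usual MindSpore conventions where `nn.LayerNorm` exposes `gamma`/`beta` and `nn.Embedding` exposes `embedding_table`:

```python
# Hypothetical illustration only -- the actual mapping comes from `_get_pt2ms_mappings(model)`
# and is applied key by key via `_get_pt2ms_mapped_k`.
PT_TO_MS_SUFFIXES = {
    "LayerNorm.weight": "LayerNorm.gamma",
    "LayerNorm.bias": "LayerNorm.beta",
    "word_embeddings.weight": "word_embeddings.embedding_table",
}


def map_pt_key_to_ms(key: str) -> str:
    """Rename one PyTorch state-dict key to its assumed MindSpore counterpart (sketch only)."""
    for pt_suffix, ms_suffix in PT_TO_MS_SUFFIXES.items():
        if key.endswith(pt_suffix):
            return key[: -len(pt_suffix)] + ms_suffix
    return key  # most keys (Dense/Conv weights and biases) keep their names


# e.g. "bert.encoder.layer.0.attention.output.LayerNorm.weight"
#   -> "bert.encoder.layer.0.attention.output.LayerNorm.gamma"
```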