35 changes: 27 additions & 8 deletions mindone/transformers/__init__.py
@@ -380,6 +380,13 @@
     DistilBertModel,
     DistilBertPreTrainedModel,
 )
+from .models.donut import (
+    DonutFeatureExtractor,
+    DonutImageProcessor,
+    DonutProcessor,
+    DonutSwinModel,
+    DonutSwinPreTrainedModel,
+)
 from .models.dpr import (
     DPRContextEncoder,
     DPRPretrainedContextEncoder,
@@ -445,6 +452,18 @@
     FlaubertPreTrainedModel,
     FlaubertWithLMHeadModel,
 )
+from .models.flava import (
+    FlavaFeatureExtractor,
+    FlavaForPreTraining,
+    FlavaImageCodebook,
+    FlavaImageModel,
+    FlavaImageProcessor,
+    FlavaModel,
+    FlavaMultimodalModel,
+    FlavaPreTrainedModel,
+    FlavaProcessor,
+    FlavaTextModel,
+)
 from .models.fsmt import FSMTForConditionalGeneration, FSMTModel, PretrainedFSMTModel
 from .models.funnel import (
     FunnelBaseModel,
@@ -505,6 +524,14 @@
     GPTBigCodeModel,
     GPTBigCodePreTrainedModel,
 )
+from .models.gpt_neo import (
+    GPTNeoForCausalLM,
+    GPTNeoForQuestionAnswering,
+    GPTNeoForSequenceClassification,
+    GPTNeoForTokenClassification,
+    GPTNeoModel,
+    GPTNeoPreTrainedModel,
+)
 from .models.gpt_neox import (
     GPTNeoXForCausalLM,
     GPTNeoXForQuestionAnswering,
@@ -527,14 +554,6 @@
     GPTJModel,
     GPTJPreTrainedModel,
 )
-from .models.gpt_neo import (
-    GPTNeoForCausalLM,
-    GPTNeoForQuestionAnswering,
-    GPTNeoForSequenceClassification,
-    GPTNeoForTokenClassification,
-    GPTNeoModel,
-    GPTNeoPreTrainedModel,
-)
 from .models.granite import GraniteForCausalLM, GraniteModel, GranitePreTrainedModel
 from .models.granitemoe import GraniteMoeForCausalLM, GraniteMoeModel, GraniteMoePreTrainedModel
 from .models.granitemoeshared import GraniteMoeSharedForCausalLM, GraniteMoeSharedModel, GraniteMoeSharedPreTrainedModel
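The new top-level exports can be smoke-tested with a plain import; a minimal sketch (assumes the package is installed, and imports only names this diff adds to mindone/transformers/__init__.py):

# Quick smoke check for the newly exported classes (sketch, not repo test code).
from mindone.transformers import (
    DonutImageProcessor,
    DonutProcessor,
    DonutSwinModel,
    FlavaModel,
    FlavaProcessor,
    GPTNeoForCausalLM,
)

print(DonutSwinModel.__name__, FlavaModel.__name__, GPTNeoForCausalLM.__name__)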
2 changes: 1 addition & 1 deletion mindone/transformers/modeling_utils.py
@@ -361,7 +361,7 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, is_sharded
             pass  # unexpected key keeps original dtype
     cm = silence_mindspore_logger() if is_sharded else nullcontext()
     with cm:
-        ms.load_param_into_net(model_to_load, state_dict, strict_load=True)
+        model_to_load.load_state_dict(state_dict, strict=False)
Collaborator:
why change strict to False?

Collaborator (author):
In the case of tied weights, there may be extra or missing parameters in a Hugging Face Transformers checkpoint. Using strict=True would raise an error, so we follow the same design as the Transformers repo and set strict=False.
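A minimal, framework-free sketch of the bookkeeping involved (hypothetical key names, not code from this repo):

# Why strict loading fails with tied weights: the in-memory model exposes two
# names for one shared tensor, while the checkpoint stores it under one name.
model_keys = {"embed_tokens.weight", "lm_head.weight"}  # lm_head tied to embed_tokens
ckpt_keys = {"embed_tokens.weight"}  # the shared tensor is serialized once

missing = model_keys - ckpt_keys  # {"lm_head.weight"} -> strict=True errors here
unexpected = ckpt_keys - model_keys  # set()
print(missing, unexpected)
# With strict=False the load proceeds: the tied parameter already receives its
# values through embed_tokens.weight, and the missing key is only reported.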


     # remove prefix from the name of parameters
     if len(start_prefix) > 0:
4 changes: 3 additions & 1 deletion mindone/transformers/models/__init__.py
@@ -55,6 +55,7 @@
     depth_pro,
     dinov2,
     distilbert,
+    donut,
     dpr,
     dpt,
     efficientnet,
@@ -65,6 +66,7 @@
     falcon,
     fastspeech2_conformer,
     flaubert,
+    flava,
     fsmt,
     funnel,
     fuyu,
@@ -77,10 +79,10 @@
     got_ocr2,
     gpt2,
     gpt_bigcode,
+    gpt_neo,
     gpt_neox,
     gpt_neox_japanese,
     gptj,
-    gpt_neo,
     granite,
     granitemoe,
     granitemoeshared,
5 changes: 5 additions & 0 deletions mindone/transformers/models/auto/configuration_auto.py
@@ -78,6 +78,7 @@
("dinov2", "Dinov2Config"),
("deit", "DeiTConfig"),
("distilbert", "DistilBertConfig"),
("donut-swin", "DonutSwinConfig"),
("dpr", "DPRConfig"),
("dpt", "DPTConfig"),
("efficientnet", "EfficientNetConfig"),
@@ -86,6 +87,7 @@
("encodec", "EncodecConfig"),
("falcon", "FalconConfig"),
("fastspeech2_conformer", "FastSpeech2ConformerConfig"),
("flava", "FlavaConfig"),
("funnel", "FunnelConfig"),
("gemma", "GemmaConfig"),
("granite", "GraniteConfig"),
@@ -129,6 +131,7 @@
("luke", "LukeConfig"),
("mamba", "MambaConfig"),
("mamba2", "Mamba2Config"),
("mbart", "MBartConfig"),
("mimi", "MimiConfig"),
("mistral", "MistralConfig"),
("mistral3", "Mistral3Config"),
@@ -272,6 +275,7 @@
("detr", "DETR"),
("dinov2", "DINOv2"),
("distilbert", "DistilBERT"),
("donut-swin", "DonutSwin"),
("dpr", "DPR"),
("dpt", "DPT"),
("efficientnet", "EfficientNet"),
@@ -280,6 +284,7 @@
("encodec", "Encodec"),
("falcon", "Falcon"),
("fastspeech2_conformer", "FastSpeech2Conformer"),
("flava", "FLAVA"),
("fsmt", "FairSeq Machine-Translation"),
("funnel", "Funnel Transformer"),
("gemma", "Gemma"),
2 changes: 2 additions & 0 deletions mindone/transformers/models/auto/feature_extraction_auto.py
@@ -45,6 +45,8 @@
     [
         ("chinese_clip", "ChineseCLIPFeatureExtractor"),
         ("convnext", "ConvNextFeatureExtractor"),
+        ("donut-swin", "DonutFeatureExtractor"),
+        ("flava", "FlavaFeatureExtractor"),
         ("cvt", "ConvNextFeatureExtractor"),
     ]
 )
2 changes: 2 additions & 0 deletions mindone/transformers/models/auto/image_processing_auto.py
@@ -59,7 +59,9 @@
("depth_anything", ("DPTImageProcessor",)),
("depth_pro", ("DepthProImageProcessor",)),
("dinov2", ("BitImageProcessor",)),
("donut-swin", ("DonutImageProcessor",)),
("dpt", ("DPTImageProcessor",)),
("flava", ("FlavaImageProcessor",)),
("efficientnet", ("EfficientNetImageProcessor",)),
("llava_next", ("LlavaNextImageProcessor",)),
("llava_next_video", ("LlavaNextVideoImageProcessor",)),
3 changes: 3 additions & 0 deletions mindone/transformers/models/auto/modeling_auto.py
@@ -77,13 +77,15 @@
("detr", "DetrModel"),
("dinov2", "Dinov2Model"),
("distilbert", "DistilBertModel"),
("donut-swin", "DonutSwinModel"),
("dpr", "DPRQuestionEncoder"),
("dpt", "DPTModel"),
("efficientnet", "EfficientNetModel"),
("electra", "ElectraModel"),
("encodec", "EncodecModel"),
("falcon", "FalconModel"),
("fastspeech2_conformer", "FastSpeech2ConformerModel"),
("flava", "FlavaModel"),
("fsmt", "FSMTModel"),
("funnel", ("FunnelModel", "FunnelBaseModel")),
("gemma", "GemmaModel"),
@@ -210,6 +212,7 @@
("data2vec-text", "Data2VecTextForMaskedLM"),
("distilbert", "DistilBertForMaskedLM"),
("electra", "ElectraForPreTraining"),
("flava", "FlavaForPreTraining"),
("fsmt", "FSMTForConditionalGeneration"),
("funnel", "FunnelForPreTraining"),
("gpt2", "GPT2LMHeadModel"),
2 changes: 2 additions & 0 deletions mindone/transformers/models/auto/processing_auto.py
@@ -53,6 +53,8 @@
("chameleon", "ChameleonProcessor"),
("chinese_clip", "ChineseCLIPProcessor"),
("colpali", "ColPaliProcessor"),
("donut", "DonutProcessor"),
("flava", "FlavaProcessor"),
("idefics", "IdeficsProcessor"),
("instructblip", "InstructBlipProcessor"),
("llava_next", "LlavaNextProcessor"),
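With the registry entries above in place, the Auto classes can resolve the new model types from a checkpoint's config. A minimal usage sketch (assumes the Auto entry points are exported as in upstream transformers; the checkpoint id is illustrative):

from mindone.transformers import AutoModel, AutoProcessor

# "flava" now resolves through the mappings added in this diff.
processor = AutoProcessor.from_pretrained("facebook/flava-full")  # -> FlavaProcessor
model = AutoModel.from_pretrained("facebook/flava-full")          # -> FlavaModel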
20 changes: 20 additions & 0 deletions mindone/transformers/models/donut/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# This code is adapted from https://github.com/huggingface/transformers
+# with modifications to run transformers on mindspore.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .feature_extraction_donut import *
+from .image_processing_donut import *
+from .modeling_donut_swin import *
+from .processing_donut import *
38 changes: 38 additions & 0 deletions mindone/transformers/models/donut/feature_extraction_donut.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# This code is adapted from https://github.com/huggingface/transformers
+# with modifications to run transformers on mindspore.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for Donut."""
+
+import warnings
+
+from ...utils import logging
+from .image_processing_donut import DonutImageProcessor
+
+logger = logging.get_logger(__name__)
+
+
+class DonutFeatureExtractor(DonutImageProcessor):
+    def __init__(self, *args, **kwargs) -> None:
+        warnings.warn(
+            "The class DonutFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
+            " use DonutImageProcessor instead.",
+            FutureWarning,
+        )
+        super().__init__(*args, **kwargs)
+
+
+__all__ = ["DonutFeatureExtractor"]
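The shim keeps old callers working while steering them to the new class. A small sketch of the expected behavior (assumes default construction is valid, as it is for the upstream DonutImageProcessor):

import warnings

from mindone.transformers import DonutFeatureExtractor, DonutImageProcessor

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    fe = DonutFeatureExtractor()  # emits FutureWarning, otherwise identical behavior

assert isinstance(fe, DonutImageProcessor)
assert any(issubclass(w.category, FutureWarning) for w in caught)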