Skip to content

Commit 9c3b58d

Browse files
authored
Handle deprecated transformer classes (#12517)
* update * update * update
1 parent 74b5fed commit 9c3b58d

File tree

3 files changed

+64
-1
lines changed

3 files changed

+64
-1
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
ONNX_WEIGHTS_NAME,
3434
SAFETENSORS_WEIGHTS_NAME,
3535
WEIGHTS_NAME,
36+
_maybe_remap_transformers_class,
3637
deprecate,
3738
get_class_from_dynamic_module,
3839
is_accelerate_available,
@@ -356,6 +357,11 @@ def maybe_raise_or_warn(
356357
"""Simple helper method to raise or warn in case incorrect module has been passed"""
357358
if not is_pipeline_module:
358359
library = importlib.import_module(library_name)
360+
361+
# Handle deprecated Transformers classes
362+
if library_name == "transformers":
363+
class_name = _maybe_remap_transformers_class(class_name) or class_name
364+
359365
class_obj = getattr(library, class_name)
360366
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
361367

@@ -390,6 +396,11 @@ def simple_get_class_obj(library_name, class_name):
390396
class_obj = getattr(pipeline_module, class_name)
391397
else:
392398
library = importlib.import_module(library_name)
399+
400+
# Handle deprecated Transformers classes
401+
if library_name == "transformers":
402+
class_name = _maybe_remap_transformers_class(class_name) or class_name
403+
393404
class_obj = getattr(library, class_name)
394405

395406
return class_obj
@@ -416,6 +427,10 @@ def get_class_obj_and_candidates(
416427
# else we just import it from the library.
417428
library = importlib.import_module(library_name)
418429

430+
# Handle deprecated Transformers classes
431+
if library_name == "transformers":
432+
class_name = _maybe_remap_transformers_class(class_name) or class_name
433+
419434
class_obj = getattr(library, class_name)
420435
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
421436

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
WEIGHTS_INDEX_NAME,
3939
WEIGHTS_NAME,
4040
)
41-
from .deprecation_utils import deprecate
41+
from .deprecation_utils import _maybe_remap_transformers_class, deprecate
4242
from .doc_utils import replace_example_docstring
4343
from .dynamic_modules_utils import get_class_from_dynamic_module
4444
from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video

src/diffusers/utils/deprecation_utils.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,54 @@
44

55
from packaging import version
66

7+
from ..utils import logging
8+
9+
10+
logger = logging.get_logger(__name__)
11+
12+
# Mapping for deprecated Transformers classes to their replacements
13+
# This is used to handle models that reference deprecated class names in their configs
14+
# Reference: https://github.com/huggingface/transformers/issues/40822
15+
# Format: {
16+
# "DeprecatedClassName": {
17+
# "new_class": "NewClassName",
18+
# "transformers_version": (">=", "5.0.0"), # (operation, version) tuple
19+
# }
20+
# }
21+
_TRANSFORMERS_CLASS_REMAPPING = {
22+
"CLIPFeatureExtractor": {
23+
"new_class": "CLIPImageProcessor",
24+
"transformers_version": (">", "4.57.0"),
25+
},
26+
}
27+
28+
29+
def _maybe_remap_transformers_class(class_name: str) -> Optional[str]:
30+
"""
31+
Check if a Transformers class should be remapped to a newer version.
32+
33+
Args:
34+
class_name: The name of the class to check
35+
36+
Returns:
37+
The new class name if remapping should occur, None otherwise
38+
"""
39+
if class_name not in _TRANSFORMERS_CLASS_REMAPPING:
40+
return None
41+
42+
from .import_utils import is_transformers_version
43+
44+
mapping = _TRANSFORMERS_CLASS_REMAPPING[class_name]
45+
operation, required_version = mapping["transformers_version"]
46+
47+
# Only remap if the transformers version meets the requirement
48+
if is_transformers_version(operation, required_version):
49+
new_class = mapping["new_class"]
50+
logger.warning(f"{class_name} appears to have been deprecated in transformers. Using {new_class} instead.")
51+
return mapping["new_class"]
52+
53+
return None
54+
755

856
def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
957
from .. import __version__

0 commit comments

Comments
 (0)