@@ -48,6 +48,7 @@
     format_inputs_short,
     format_intermediates_short,
     make_doc_string,
+    InsertableOrderedDict
 )
 from .components_manager import ComponentsManager
 from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
@@ -66,6 +67,7 @@
 )
 
 
+
 @dataclass
 class PipelineState:
     """
@@ -622,7 +624,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
     block_trigger_inputs = []
 
     def __init__(self):
-        blocks = OrderedDict()
+        blocks = InsertableOrderedDict()
         for block_name, block_cls in zip(self.block_names, self.block_classes):
            blocks[block_name] = block_cls()
         self.blocks = blocks
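(Note: this diff only imports `InsertableOrderedDict` and uses it as a drop-in replacement for `OrderedDict`; its definition is not part of these hunks. A minimal sketch of the idea, assuming the point is positional insertion of blocks, with an illustrative `insert` API that may not match the real class:)

```python
from collections import OrderedDict


class InsertableOrderedDict(OrderedDict):
    """Hypothetical sketch: an ordered mapping that also supports
    inserting a key at an arbitrary position."""

    def insert(self, index: int, key, value):
        # Splice (key, value) into the ordering at `index` by rebuilding.
        items = list(self.items())
        items.insert(index, (key, value))
        self.clear()
        self.update(items)


blocks = InsertableOrderedDict(text_encoder="enc", unet="unet")
blocks.insert(1, "controlnet", "cnet")  # order is now: text_encoder, controlnet, unet
```

Stored this way, a new block can be spliced between existing ones without rebuilding the whole mapping by hand.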
@@ -958,7 +960,7 @@ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlo
         return instance
 
     def __init__(self):
-        blocks = OrderedDict()
+        blocks = InsertableOrderedDict()
         for block_name, block_cls in zip(self.block_names, self.block_classes):
             blocks[block_name] = block_cls()
         self.blocks = blocks
@@ -1449,7 +1451,7 @@ def outputs(self) -> List[str]:
 
 
     def __init__(self):
-        blocks = OrderedDict()
+        blocks = InsertableOrderedDict()
        for block_name, block_cls in zip(self.block_names, self.block_classes):
             blocks[block_name] = block_cls()
         self.blocks = blocks
@@ -1662,6 +1664,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
 
     """
     config_name = "modular_model_index.json"
+    hf_device_map = None
 
 
     def register_components(self, **kwargs):
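(`hf_device_map` is presumably added so that the `DiffusionPipeline` logic copied in below, which consults that attribute, also works on a loader that was never dispatched with an accelerate device map. A paraphrased sketch of the kind of guard this class-level default satisfies, not the verbatim upstream check:)

```python
class LoaderSketch:
    # Class-level default, as added in this hunk: no accelerate device map.
    hf_device_map = None

    def to(self, device):
        # DiffusionPipeline.to refuses to move pipelines that were
        # dispatched with a device map; the None default makes this pass.
        if self.hf_device_map is not None:
            raise ValueError("`.to()` is not supported after device-map dispatch.")
        return self
```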
@@ -2013,7 +2016,26 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs):
         # Register all components at once
         self.register_components(**components_to_register)
 
-    # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to
+    # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline._maybe_raise_error_if_group_offload_active
+    def _maybe_raise_error_if_group_offload_active(
+        self, raise_error: bool = False, module: Optional[torch.nn.Module] = None
+    ) -> bool:
+        from ..hooks.group_offloading import _is_group_offload_enabled
+
+        components = self.components.values() if module is None else [module]
+        components = [component for component in components if isinstance(component, torch.nn.Module)]
+        for component in components:
+            if _is_group_offload_enabled(component):
+                if raise_error:
+                    raise ValueError(
+                        "You are trying to apply model/sequential CPU offloading to a pipeline that contains components "
+                        "with group offloading enabled. This is not supported. Please disable group offloading for "
+                        "components of the pipeline to use other offloading methods."
+                    )
+                return True
+        return False
+
+    # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to
     def to(self, *args, **kwargs) -> Self:
         r"""
         Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
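(In `DiffusionPipeline` this helper guards the CPU-offload entry points, since group offloading cannot be combined with model/sequential offloading. A sketch of a typical call site, with `enable_model_cpu_offload` assumed from the `DiffusionPipeline` convention rather than shown in this diff:)

```python
def enable_model_cpu_offload(self, gpu_id=None, device="cuda"):
    # Fail loudly before installing any offload hooks if a component
    # already uses group offloading (the two schemes are incompatible).
    self._maybe_raise_error_if_group_offload_active(raise_error=True)
    ...  # proceed with accelerate-based per-model CPU offloading
```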
@@ -2050,6 +2072,10 @@ def to(self, *args, **kwargs) -> Self:
         Returns:
             [`DiffusionPipeline`]: The pipeline converted to the specified `dtype` and/or `device`.
         """
+        from ..pipelines.pipeline_utils import _check_bnb_status, DiffusionPipeline
+        from ..utils import is_accelerate_available, is_accelerate_version, is_hpu_available, is_transformers_version
+
+
         dtype = kwargs.pop("dtype", None)
         device = kwargs.pop("device", None)
         silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False)
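(The new imports are deferred into the method body, presumably because `modular_pipelines` and `pipelines` would otherwise import each other at module load time. The pattern in isolation, with generic module names:)

```python
# foo.py: importing bar only at call time means "import foo" succeeds
# even though bar.py imports foo at its own top level.
def convert(x):
    import bar  # resolved only when convert() actually runs
    return bar.helper(x)
```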
@@ -2152,8 +2178,7 @@ def module_is_offloaded(module):
             os.environ["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1"
             logger.debug("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1")
 
-        module_names, _ = self._get_signature_keys(self)
-        modules = [getattr(self, n, None) for n in module_names]
+        modules = self.components.values()
         modules = [m for m in modules if isinstance(m, torch.nn.Module)]
 
         is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
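(Switching from `_get_signature_keys` to `self.components` matters because `ModularLoader` registers components at runtime via `register_components`, so they never appear in the `__init__` signature that the old introspection walked. Schematically, with illustrative names:)

```python
import inspect


class Loader:
    def __init__(self):          # note: no component parameters declared here
        self.components = {}

    def register_components(self, **kwargs):
        self.components.update(kwargs)


loader = Loader()
loader.register_components(unet=object(), vae=object())
print(list(inspect.signature(Loader.__init__).parameters))  # ['self'] only
print(list(loader.components))  # ['unet', 'vae'], what .to() must actually move
```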
@@ -2431,4 +2456,12 @@ def save_pretrained(self, save_directory: Optional[Union[str, os.PathLike]] = No
 
     @property
     def doc(self):
-        return self.blocks.doc
+        return self.blocks.doc
+
+    def to(self, *args, **kwargs):
+        self.loader.to(*args, **kwargs)
+        return self
+
+    @property
+    def components(self):
+        return self.loader.components
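(Together these give the wrapper class, presumably `ModularPipeline` judging from `self.blocks` and `self.loader`, the familiar `DiffusionPipeline` surface. Hypothetical usage, with the pipeline construction assumed rather than shown in this diff:)

```python
import torch

# Assumes `pipe` was already built; how that happens is outside this diff.
pipe = pipe.to("cuda", torch.float16)  # forwards to pipe.loader.to(), returns pipe
unet = pipe.components["unet"]         # reads through to pipe.loader.components
```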