
Commit 083479c

ordereddict -> insertableOrderedDict; make sure loader `to` method works

1 parent 04c16d0

File tree

1 file changed (+40, -7)

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 40 additions & 7 deletions
@@ -48,6 +48,7 @@
     format_inputs_short,
     format_intermediates_short,
     make_doc_string,
+    InsertableOrderedDict
 )
 from .components_manager import ComponentsManager
 from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
@@ -66,6 +67,7 @@
 )
 
 
+
 @dataclass
 class PipelineState:
     """
@@ -622,7 +624,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
     block_trigger_inputs = []
 
     def __init__(self):
-        blocks = OrderedDict()
+        blocks = InsertableOrderedDict()
         for block_name, block_cls in zip(self.block_names, self.block_classes):
             blocks[block_name] = block_cls()
         self.blocks = blocks
@@ -958,7 +960,7 @@ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlo
         return instance
 
     def __init__(self):
-        blocks = OrderedDict()
+        blocks = InsertableOrderedDict()
         for block_name, block_cls in zip(self.block_names, self.block_classes):
             blocks[block_name] = block_cls()
         self.blocks = blocks
@@ -1449,7 +1451,7 @@ def outputs(self) -> List[str]:
 
 
     def __init__(self):
-        blocks = OrderedDict()
+        blocks = InsertableOrderedDict()
        for block_name, block_cls in zip(self.block_names, self.block_classes):
             blocks[block_name] = block_cls()
         self.blocks = blocks
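
The commit swaps the plain `OrderedDict` for an `InsertableOrderedDict` in all three block containers, but the class itself is not shown in this diff (it is imported from the same utils module as `make_doc_string`). Below is a minimal sketch of what such a container presumably provides — insertion at an arbitrary position in an otherwise insertion-ordered mapping. The `insert` method name and signature are assumptions for illustration, not the actual diffusers API:

    from collections import OrderedDict

    class InsertableOrderedDict(OrderedDict):
        # Hypothetical sketch: an OrderedDict that can splice a new
        # (key, value) pair in at a given position.
        def insert(self, index, key, value):
            items = list(self.items())
            items.insert(index, (key, value))
            self.clear()
            self.update(items)

    blocks = InsertableOrderedDict()
    blocks["encode_step"] = "encode"
    blocks["decode_step"] = "decode"
    blocks.insert(1, "denoise_step", "denoise")
    print(list(blocks))  # ['encode_step', 'denoise_step', 'decode_step']

For pipeline blocks, the point would be that a custom block can be dropped between two existing ones without rebuilding the whole mapping.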
@@ -1662,6 +1664,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
 
     """
     config_name = "modular_model_index.json"
+    hf_device_map = None
 
 
     def register_components(self, **kwargs):
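
Defining `hf_device_map = None` at class level is presumably there so that device-map checks inherited from the `DiffusionPipeline.to` logic can read the attribute on any `ModularLoader` instance without raising an `AttributeError`. A hedged sketch of the kind of guard that relies on it (this mirrors the pipeline-side pattern and is not a line from this commit):

    # With the class attribute in place, a check of this shape can run
    # safely even on a freshly constructed loader:
    if self.hf_device_map is not None:
        raise ValueError(
            "Explicit device placement via `to()` is not supported while a device map is active."
        )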
@@ -2013,7 +2016,26 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs):
         # Register all components at once
         self.register_components(**components_to_register)
 
-    # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to
+    # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline._maybe_raise_error_if_group_offload_active
+    def _maybe_raise_error_if_group_offload_active(
+        self, raise_error: bool = False, module: Optional[torch.nn.Module] = None
+    ) -> bool:
+        from ..hooks.group_offloading import _is_group_offload_enabled
+
+        components = self.components.values() if module is None else [module]
+        components = [component for component in components if isinstance(component, torch.nn.Module)]
+        for component in components:
+            if _is_group_offload_enabled(component):
+                if raise_error:
+                    raise ValueError(
+                        "You are trying to apply model/sequential CPU offloading to a pipeline that contains components "
+                        "with group offloading enabled. This is not supported. Please disable group offloading for "
+                        "components of the pipeline to use other offloading methods."
+                    )
+                return True
+        return False
+
+    # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to
     def to(self, *args, **kwargs) -> Self:
         r"""
         Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
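
In `DiffusionPipeline`, this copied helper is invoked with `raise_error=True` at the start of the CPU-offloading entry points, so that group offloading is never layered under model or sequential offloading. A sketch of that calling pattern (the method body shown is illustrative, not part of this commit):

    def enable_model_cpu_offload(self, gpu_id=None, device="cuda"):
        # Refuse to combine accelerate-based CPU offload with group offloading.
        self._maybe_raise_error_if_group_offload_active(raise_error=True)
        ...  # proceed with the usual accelerate hook setup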
@@ -2050,6 +2072,10 @@ def to(self, *args, **kwargs) -> Self:
         Returns:
             [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`.
         """
+        from ..pipelines.pipeline_utils import _check_bnb_status, DiffusionPipeline
+        from ..utils import is_accelerate_available, is_accelerate_version, is_hpu_available, is_transformers_version
+
+
         dtype = kwargs.pop("dtype", None)
         device = kwargs.pop("device", None)
         silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False)
@@ -2152,8 +2178,7 @@ def module_is_offloaded(module):
             os.environ["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1"
             logger.debug("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1")
 
-        module_names, _ = self._get_signature_keys(self)
-        modules = [getattr(self, n, None) for n in module_names]
+        modules = self.components.values()
         modules = [m for m in modules if isinstance(m, torch.nn.Module)]
 
         is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
@@ -2431,4 +2456,12 @@ def save_pretrained(self, save_directory: Optional[Union[str, os.PathLike]] = No
 
     @property
     def doc(self):
-        return self.blocks.doc
+        return self.blocks.doc
+
+    def to(self, *args, **kwargs):
+        self.loader.to(*args, **kwargs)
+        return self
+
+    @property
+    def components(self):
+        return self.loader.components
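
With these two wrappers in place (the second half of the commit message), device and dtype moves on the modular pipeline are forwarded to its loader and return the pipeline itself, so the familiar chained form keeps working. A usage sketch, assuming `pipeline` is an instance of the class patched here and that a "unet" component is registered (both names are illustrative):

    import torch

    pipeline = pipeline.to("cuda", dtype=torch.float16)  # forwards to pipeline.loader.to(...)
    unet = pipeline.components["unet"]                    # resolved through loader.components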
