Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 5 additions & 64 deletions src/diffusers/modular_pipelines/mellon_node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,11 @@
"display": "output",
"type": "controlnet",
},
"doc": {
"label": "Doc",
"display": "output",
"type": "string",
},
}


Expand Down Expand Up @@ -697,67 +702,3 @@ def from_blocks(cls, blocks: ModularPipelineBlocks, node_type: str) -> "MellonNo
blocks_names=blocks_names,
node_type=node_type,
)


# Minimal modular registry for Mellon node configs
class ModularMellonNodeRegistry:
"""Registry mapping (pipeline class, blocks_name) -> list of MellonNodeConfig."""

def __init__(self):
self._registry = {}
self._initialized = False

def register(self, pipeline_cls: type, node_params: Dict[str, MellonNodeConfig]):
if not self._initialized:
_initialize_registry(self)
self._registry[pipeline_cls] = node_params

def get(self, pipeline_cls: type) -> MellonNodeConfig:
if not self._initialized:
_initialize_registry(self)
return self._registry.get(pipeline_cls, None)

def get_all(self) -> Dict[type, Dict[str, MellonNodeConfig]]:
if not self._initialized:
_initialize_registry(self)
return self._registry


def _register_preset_node_types(
pipeline_cls, params_map: Dict[str, Dict[str, Any]], registry: ModularMellonNodeRegistry
):
"""Register all node-type presets for a given pipeline class from a params map."""
node_configs = {}
for node_type, spec in params_map.items():
node_config = MellonNodeConfig(
inputs=spec.get("inputs", []),
model_inputs=spec.get("model_inputs", []),
outputs=spec.get("outputs", []),
blocks_names=spec.get("block_names", []),
node_type=node_type,
)
node_configs[node_type] = node_config
registry.register(pipeline_cls, node_configs)


def _initialize_registry(registry: ModularMellonNodeRegistry):
"""Initialize the registry and register all available pipeline configs."""
print("Initializing registry")

registry._initialized = True

try:
from .qwenimage.modular_pipeline import QwenImageModularPipeline
from .qwenimage.node_utils import QwenImage_NODE_TYPES_PARAMS_MAP

_register_preset_node_types(QwenImageModularPipeline, QwenImage_NODE_TYPES_PARAMS_MAP, registry)
except Exception:
raise Exception("Failed to register QwenImageModularPipeline")

try:
from .stable_diffusion_xl.modular_pipeline import StableDiffusionXLModularPipeline
from .stable_diffusion_xl.node_utils import SDXL_NODE_TYPES_PARAMS_MAP

_register_preset_node_types(StableDiffusionXLModularPipeline, SDXL_NODE_TYPES_PARAMS_MAP, registry)
except Exception:
raise Exception("Failed to register StableDiffusionXLModularPipeline")
95 changes: 0 additions & 95 deletions src/diffusers/modular_pipelines/qwenimage/node_utils.py

This file was deleted.

99 changes: 0 additions & 99 deletions src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py

This file was deleted.

Loading