Skip to content

[Modular] More Updates for Custom Code Loading #11969

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 65 additions & 116 deletions src/diffusers/modular_pipelines/modular_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
</Tip>
"""

config_name = "config.json"
config_name = "modular_config.json"
Copy link
Collaborator Author

@DN6 DN6 Jul 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency with modular_model_index.json. Also could be cases where a repo contains model weights/config file and a modular pipeline block to load the model. We can avoid conflicts with the configs this way.

model_name = None

@classmethod
Expand All @@ -344,6 +344,16 @@ def expected_components(self) -> List[ComponentSpec]:
def expected_configs(self) -> List[ConfigSpec]:
return []

@property
def intermediate_inputs(self) -> List[OutputParam]:
"""List of intermediate output parameters. Must be implemented by subclasses."""
return []

@property
def intermediate_outputs(self) -> List[OutputParam]:
"""List of intermediate output parameters. Must be implemented by subclasses."""
return []

@classmethod
def from_pretrained(
cls,
Expand Down Expand Up @@ -425,6 +435,60 @@ def init_pipeline(
)
return modular_pipeline

def get_block_state(self, state: PipelineState) -> dict:
"""Get all inputs and intermediates in one dictionary"""
data = {}
state_inputs = self.inputs + self.intermediate_inputs

# Check inputs
for input_param in state_inputs:
if input_param.name:
value = state.get_input(input_param.name) or state.get_intermediate(input_param.name)
if input_param.required and value is None:
raise ValueError(f"Required input '{input_param.name}' is missing")
elif value is not None or (value is None and input_param.name not in data):
data[input_param.name] = value

elif input_param.kwargs_type:
# if kwargs_type is provided, get all inputs with matching kwargs_type
if input_param.kwargs_type not in data:
data[input_param.kwargs_type] = {}
inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) or state.get_intermediate_kwargs(
input_param.kwargs_type
)
if inputs_kwargs:
for k, v in inputs_kwargs.items():
if v is not None:
data[k] = v
data[input_param.kwargs_type][k] = v

return BlockState(**data)

def set_block_state(self, state: PipelineState, block_state: BlockState):
for output_param in self.intermediate_outputs:
if not hasattr(block_state, output_param.name):
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
param = getattr(block_state, output_param.name)
state.set_intermediate(output_param.name, param, output_param.kwargs_type)

for input_param in self.intermediate_inputs:
if input_param.name and hasattr(block_state, input_param.name):
param = getattr(block_state, input_param.name)
# Only add if the value is different from what's in the state
current_value = state.get_intermediate(input_param.name)
if current_value is not param: # Using identity comparison to check if object was modified
state.set_intermediate(input_param.name, param, input_param.kwargs_type)
elif input_param.kwargs_type:
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
# we need to first find out which inputs are and loop through them.
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
for param_name, current_value in intermediate_kwargs.items():
if not hasattr(block_state, param_name):
continue
param = getattr(block_state, param_name)
if current_value is not param: # Using identity comparison to check if object was modified
state.set_intermediate(param_name, param, input_param.kwargs_type)

@staticmethod
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
"""
Expand Down Expand Up @@ -654,51 +718,6 @@ def doc(self):
expected_configs=self.expected_configs,
)

# YiYi TODO: input and inteermediate inputs with same name? should warn?
def get_block_state(self, state: PipelineState) -> dict:
"""Get all inputs and intermediates in one dictionary"""
data = {}

# Check inputs
for input_param in self.inputs:
if input_param.name:
value = state.get_input(input_param.name)
if input_param.required and value is None:
raise ValueError(f"Required input '{input_param.name}' is missing")
elif value is not None or (value is None and input_param.name not in data):
data[input_param.name] = value
elif input_param.kwargs_type:
# if kwargs_type is provided, get all inputs with matching kwargs_type
if input_param.kwargs_type not in data:
data[input_param.kwargs_type] = {}
inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type)
if inputs_kwargs:
for k, v in inputs_kwargs.items():
if v is not None:
data[k] = v
data[input_param.kwargs_type][k] = v

# Check intermediates
for input_param in self.intermediate_inputs:
if input_param.name:
value = state.get_intermediate(input_param.name)
if input_param.required and value is None:
raise ValueError(f"Required intermediate input '{input_param.name}' is missing")
elif value is not None or (value is None and input_param.name not in data):
data[input_param.name] = value
elif input_param.kwargs_type:
# if kwargs_type is provided, get all intermediates with matching kwargs_type
if input_param.kwargs_type not in data:
data[input_param.kwargs_type] = {}
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
if intermediate_kwargs:
for k, v in intermediate_kwargs.items():
if v is not None:
if k not in data:
data[k] = v
data[input_param.kwargs_type][k] = v
return BlockState(**data)

def set_block_state(self, state: PipelineState, block_state: BlockState):
for output_param in self.intermediate_outputs:
if not hasattr(block_state, output_param.name):
Expand Down Expand Up @@ -1635,75 +1654,6 @@ def loop_step(self, components, state: PipelineState, **kwargs):
def __call__(self, components, state: PipelineState) -> PipelineState:
raise NotImplementedError("`__call__` method needs to be implemented by the subclass")

def get_block_state(self, state: PipelineState) -> dict:
"""Get all inputs and intermediates in one dictionary"""
data = {}

# Check inputs
for input_param in self.inputs:
if input_param.name:
value = state.get_input(input_param.name)
if input_param.required and value is None:
raise ValueError(f"Required input '{input_param.name}' is missing")
elif value is not None or (value is None and input_param.name not in data):
data[input_param.name] = value
elif input_param.kwargs_type:
# if kwargs_type is provided, get all inputs with matching kwargs_type
if input_param.kwargs_type not in data:
data[input_param.kwargs_type] = {}
inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type)
if inputs_kwargs:
for k, v in inputs_kwargs.items():
if v is not None:
data[k] = v
data[input_param.kwargs_type][k] = v

# Check intermediates
for input_param in self.intermediate_inputs:
if input_param.name:
value = state.get_intermediate(input_param.name)
if input_param.required and value is None:
raise ValueError(f"Required intermediate input '{input_param.name}' is missing")
elif value is not None or (value is None and input_param.name not in data):
data[input_param.name] = value
elif input_param.kwargs_type:
# if kwargs_type is provided, get all intermediates with matching kwargs_type
if input_param.kwargs_type not in data:
data[input_param.kwargs_type] = {}
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
if intermediate_kwargs:
for k, v in intermediate_kwargs.items():
if v is not None:
if k not in data:
data[k] = v
data[input_param.kwargs_type][k] = v
return BlockState(**data)

def set_block_state(self, state: PipelineState, block_state: BlockState):
for output_param in self.intermediate_outputs:
if not hasattr(block_state, output_param.name):
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
param = getattr(block_state, output_param.name)
state.set_intermediate(output_param.name, param, output_param.kwargs_type)

for input_param in self.intermediate_inputs:
if input_param.name and hasattr(block_state, input_param.name):
param = getattr(block_state, input_param.name)
# Only add if the value is different from what's in the state
current_value = state.get_intermediate(input_param.name)
if current_value is not param: # Using identity comparison to check if object was modified
state.set_intermediate(input_param.name, param, input_param.kwargs_type)
elif input_param.kwargs_type:
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
# we need to first find out which inputs are and loop through them.
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
for param_name, current_value in intermediate_kwargs.items():
if not hasattr(block_state, param_name):
continue
param = getattr(block_state, param_name)
if current_value is not param: # Using identity comparison to check if object was modified
state.set_intermediate(param_name, param, input_param.kwargs_type)

@property
def doc(self):
return make_doc_string(
Expand Down Expand Up @@ -1976,7 +1926,6 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] =

# Add inputs to state, using defaults if not provided in the kwargs or the state
# if same input already in the state, will override it if provided in the kwargs

intermediate_inputs = [inp.name for inp in self.blocks.intermediate_inputs]
for expected_input_param in self.blocks.inputs:
name = expected_input_param.name
Expand Down
Loading