Skip to content

Commit 5753365

Browse files
feat(nodes): remove siglip from flux_redux, dl it jit when needed if we cannot find it
This follows the same pattern for IP Adapter w/ its CLIP Vision model. The SigLIP model is unlikely to ever change and we don't want to force the user to select it anywhere. Hardcoding it is safe and makes the UX much nicer. The alternative is a model dropdown that will likely only ever have one valid choice in it.
1 parent e35537e commit 5753365

File tree

1 file changed

+35
-11
lines changed

1 file changed

+35
-11
lines changed

invokeai/app/invocations/flux_redux.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,11 @@
2020
)
2121
from invokeai.app.invocations.model import ModelIdentifierField
2222
from invokeai.app.invocations.primitives import ImageField
23+
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
2324
from invokeai.app.services.shared.invocation_context import InvocationContext
2425
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
26+
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
27+
from invokeai.backend.model_manager.starter_models import siglip
2528
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
2629
from invokeai.backend.util.devices import TorchDevice
2730

@@ -35,16 +38,12 @@ class FluxReduxOutput(BaseInvocationOutput):
3538
)
3639

3740

38-
SIGLIP_STARTER_MODEL_NAME = "SigLIP - google/siglip-so400m-patch14-384"
39-
FLUX_REDUX_STARTER_MODEL_NAME = "FLUX Redux"
40-
41-
4241
@invocation(
4342
"flux_redux",
4443
title="FLUX Redux",
4544
tags=["ip_adapter", "control"],
4645
category="ip_adapter",
47-
version="1.0.0",
46+
version="2.0.0",
4847
classification=Classification.Prototype,
4948
)
5049
class FluxReduxInvocation(BaseInvocation):
@@ -61,11 +60,6 @@ class FluxReduxInvocation(BaseInvocation):
6160
title="FLUX Redux Model",
6261
ui_type=UIType.FluxReduxModel,
6362
)
64-
siglip_model: ModelIdentifierField = InputField(
65-
description="The SigLIP model to use.",
66-
title="SigLIP Model",
67-
ui_type=UIType.SigLipModel,
68-
)
6963

7064
def invoke(self, context: InvocationContext) -> FluxReduxOutput:
7165
image = context.images.get_pil(self.image.image_name, "RGB")
@@ -80,7 +74,8 @@ def invoke(self, context: InvocationContext) -> FluxReduxOutput:
8074

8175
@torch.no_grad()
8276
def _siglip_encode(self, context: InvocationContext, image: Image.Image) -> torch.Tensor:
83-
with context.models.load(self.siglip_model).model_on_device() as (_, siglip_pipeline):
77+
siglip_model_config = self._get_siglip_model(context)
78+
with context.models.load(siglip_model_config.key).model_on_device() as (_, siglip_pipeline):
8479
assert isinstance(siglip_pipeline, SigLipPipeline)
8580
return siglip_pipeline.encode_image(
8681
x=image, device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
@@ -93,3 +88,32 @@ def _flux_redux_encode(self, context: InvocationContext, encoded_x: torch.Tensor
9388
dtype = next(flux_redux.parameters()).dtype
9489
encoded_x = encoded_x.to(dtype=dtype)
9590
return flux_redux(encoded_x)
91+
92+
def _get_siglip_model(self, context: InvocationContext) -> AnyModelConfig:
93+
siglip_models = context.models.search_by_attrs(name=siglip.name, base=BaseModelType.Any, type=ModelType.SigLIP)
94+
95+
if not len(siglip_models) > 0:
96+
context.logger.warning(
97+
f"The SigLIP model required by FLUX Redux ({siglip.name}) is not installed. Downloading and installing now. This may take a while."
98+
)
99+
100+
# TODO(psyche): Can the probe reliably determine the type of the model? Just hardcoding it bc I don't want to experiment now
101+
config_overrides = ModelRecordChanges(name=siglip.name, type=ModelType.SigLIP)
102+
103+
# Queue the job
104+
job = context._services.model_manager.install.heuristic_import(siglip.source, config=config_overrides)
105+
106+
# Wait for up to 10 minutes - model is ~3.5GB
107+
context._services.model_manager.install.wait_for_job(job, timeout=600)
108+
109+
siglip_models = context.models.search_by_attrs(
110+
name=siglip.name,
111+
base=BaseModelType.Any,
112+
type=ModelType.SigLIP,
113+
)
114+
115+
if len(siglip_models) == 0:
116+
context.logger.error("Error while fetching SigLIP for FLUX Redux")
117+
assert len(siglip_models) == 1
118+
119+
return siglip_models[0]

0 commit comments

Comments
 (0)