20
20
)
21
21
from invokeai .app .invocations .model import ModelIdentifierField
22
22
from invokeai .app .invocations .primitives import ImageField
23
+ from invokeai .app .services .model_records .model_records_base import ModelRecordChanges
23
24
from invokeai .app .services .shared .invocation_context import InvocationContext
24
25
from invokeai .backend .flux .redux .flux_redux_model import FluxReduxModel
26
+ from invokeai .backend .model_manager .config import AnyModelConfig , BaseModelType , ModelType
27
+ from invokeai .backend .model_manager .starter_models import siglip
25
28
from invokeai .backend .sig_lip .sig_lip_pipeline import SigLipPipeline
26
29
from invokeai .backend .util .devices import TorchDevice
27
30
@@ -35,16 +38,12 @@ class FluxReduxOutput(BaseInvocationOutput):
35
38
)
36
39
37
40
38
- SIGLIP_STARTER_MODEL_NAME = "SigLIP - google/siglip-so400m-patch14-384"
39
- FLUX_REDUX_STARTER_MODEL_NAME = "FLUX Redux"
40
-
41
-
42
41
@invocation (
43
42
"flux_redux" ,
44
43
title = "FLUX Redux" ,
45
44
tags = ["ip_adapter" , "control" ],
46
45
category = "ip_adapter" ,
47
- version = "1 .0.0" ,
46
+ version = "2 .0.0" ,
48
47
classification = Classification .Prototype ,
49
48
)
50
49
class FluxReduxInvocation (BaseInvocation ):
@@ -61,11 +60,6 @@ class FluxReduxInvocation(BaseInvocation):
61
60
title = "FLUX Redux Model" ,
62
61
ui_type = UIType .FluxReduxModel ,
63
62
)
64
- siglip_model : ModelIdentifierField = InputField (
65
- description = "The SigLIP model to use." ,
66
- title = "SigLIP Model" ,
67
- ui_type = UIType .SigLipModel ,
68
- )
69
63
70
64
def invoke (self , context : InvocationContext ) -> FluxReduxOutput :
71
65
image = context .images .get_pil (self .image .image_name , "RGB" )
@@ -80,7 +74,8 @@ def invoke(self, context: InvocationContext) -> FluxReduxOutput:
80
74
81
75
@torch .no_grad ()
82
76
def _siglip_encode (self , context : InvocationContext , image : Image .Image ) -> torch .Tensor :
83
- with context .models .load (self .siglip_model ).model_on_device () as (_ , siglip_pipeline ):
77
+ siglip_model_config = self ._get_siglip_model (context )
78
+ with context .models .load (siglip_model_config .key ).model_on_device () as (_ , siglip_pipeline ):
84
79
assert isinstance (siglip_pipeline , SigLipPipeline )
85
80
return siglip_pipeline .encode_image (
86
81
x = image , device = TorchDevice .choose_torch_device (), dtype = TorchDevice .choose_torch_dtype ()
@@ -93,3 +88,32 @@ def _flux_redux_encode(self, context: InvocationContext, encoded_x: torch.Tensor
93
88
dtype = next (flux_redux .parameters ()).dtype
94
89
encoded_x = encoded_x .to (dtype = dtype )
95
90
return flux_redux (encoded_x )
91
+
92
+ def _get_siglip_model (self , context : InvocationContext ) -> AnyModelConfig :
93
+ siglip_models = context .models .search_by_attrs (name = siglip .name , base = BaseModelType .Any , type = ModelType .SigLIP )
94
+
95
+ if not len (siglip_models ) > 0 :
96
+ context .logger .warning (
97
+ f"The SigLIP model required by FLUX Redux ({ siglip .name } ) is not installed. Downloading and installing now. This may take a while."
98
+ )
99
+
100
+ # TODO(psyche): Can the probe reliably determine the type of the model? Just hardcoding it bc I don't want to experiment now
101
+ config_overrides = ModelRecordChanges (name = siglip .name , type = ModelType .SigLIP )
102
+
103
+ # Queue the job
104
+ job = context ._services .model_manager .install .heuristic_import (siglip .source , config = config_overrides )
105
+
106
+ # Wait for up to 10 minutes - model is ~3.5GB
107
+ context ._services .model_manager .install .wait_for_job (job , timeout = 600 )
108
+
109
+ siglip_models = context .models .search_by_attrs (
110
+ name = siglip .name ,
111
+ base = BaseModelType .Any ,
112
+ type = ModelType .SigLIP ,
113
+ )
114
+
115
+ if len (siglip_models ) == 0 :
116
+ context .logger .error ("Error while fetching SigLIP for FLUX Redux" )
117
+ assert len (siglip_models ) == 1
118
+
119
+ return siglip_models [0 ]
0 commit comments