Skip to content

Commit ada483f

Browse files
committed
Various styling and exception type updates
1 parent 0913d06 commit ada483f

File tree

3 files changed

+35
-51
lines changed

3 files changed

+35
-51
lines changed

invokeai/app/invocations/model.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
183183
model_key = self.model.key
184184

185185
if not context.models.exists(model_key):
186-
raise Exception(f"Unknown model: {model_key}")
186+
raise ValueError(f"Unknown model: {model_key}")
187187
transformer = self._get_model(context, SubModelType.Transformer)
188188
tokenizer = self._get_model(context, SubModelType.Tokenizer)
189189
tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
@@ -203,10 +203,7 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
203203
legacy_config_path = context.config.get().legacy_conf_path / transformer_config.config_path
204204
config_path = legacy_config_path.as_posix()
205205
with open(config_path, "r") as stream:
206-
try:
207-
flux_conf = yaml.safe_load(stream)
208-
except:
209-
raise
206+
flux_conf = yaml.safe_load(stream)
210207

211208
return FluxModelLoaderOutput(
212209
transformer=TransformerField(transformer=transformer),

invokeai/app/services/shared/invocation_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def import_local_model(
484484
ModelInstallJob object defining the install job to be used in tracking the job
485485
"""
486486
if not model_path.exists():
487-
raise Exception("Models provided to import_local_model must already exist on disk")
487+
raise ValueError(f"Models provided to import_local_model must already exist on disk at {model_path.as_posix()}")
488488
return self._services.model_manager.install.heuristic_import(str(model_path), config=config, inplace=inplace)
489489

490490
def load_local_model(

invokeai/backend/model_manager/load/model_loaders/flux.py

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -49,29 +49,24 @@ def _load_model(
4949
config: AnyModelConfig,
5050
submodel_type: Optional[SubModelType] = None,
5151
) -> AnyModel:
52-
if isinstance(config, VAECheckpointConfig):
53-
model_path = Path(config.path)
54-
load_class = AutoEncoder
55-
legacy_config_path = app_config.legacy_conf_path / config.config_path
56-
config_path = legacy_config_path.as_posix()
57-
with open(config_path, "r") as stream:
58-
try:
59-
flux_conf = yaml.safe_load(stream)
60-
except:
61-
raise
62-
63-
dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
64-
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
65-
params = AutoEncoderParams(**filtered_data)
66-
67-
with SilenceWarnings():
68-
model = load_class(params)
69-
sd = load_file(model_path)
70-
model.load_state_dict(sd, strict=False, assign=True)
71-
72-
return model
73-
else:
74-
return super()._load_model(config, submodel_type)
52+
if not isinstance(config, VAECheckpointConfig):
53+
raise ValueError("Only VAECheckpointConfig models are currently supported here.")
54+
model_path = Path(config.path)
55+
legacy_config_path = app_config.legacy_conf_path / config.config_path
56+
config_path = legacy_config_path.as_posix()
57+
with open(config_path, "r") as stream:
58+
flux_conf = yaml.safe_load(stream)
59+
60+
dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
61+
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
62+
params = AutoEncoderParams(**filtered_data)
63+
64+
with SilenceWarnings():
65+
model = AutoEncoder(params)
66+
sd = load_file(model_path)
67+
model.load_state_dict(sd, strict=False, assign=True)
68+
69+
return model
7570

7671

7772
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
@@ -84,15 +79,15 @@ def _load_model(
8479
submodel_type: Optional[SubModelType] = None,
8580
) -> AnyModel:
8681
if not isinstance(config, CLIPEmbedDiffusersConfig):
87-
raise Exception("Only CLIPEmbedDiffusersConfig models are currently supported here.")
82+
raise ValueError("Only CLIPEmbedDiffusersConfig models are currently supported here.")
8883

8984
match submodel_type:
9085
case SubModelType.Tokenizer:
9186
return CLIPTokenizer.from_pretrained(config.path, max_length=77)
9287
case SubModelType.TextEncoder:
9388
return CLIPTextModel.from_pretrained(config.path)
9489

95-
raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.")
90+
raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")
9691

9792

9893
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder8b)
@@ -105,15 +100,15 @@ def _load_model(
105100
submodel_type: Optional[SubModelType] = None,
106101
) -> AnyModel:
107102
if not isinstance(config, T5Encoder8bConfig):
108-
raise Exception("Only T5Encoder8bConfig models are currently supported here.")
103+
raise ValueError("Only T5Encoder8bConfig models are currently supported here.")
109104

110105
match submodel_type:
111106
case SubModelType.Tokenizer2:
112107
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
113108
case SubModelType.TextEncoder2:
114109
return FastQuantizedTransformersModel.from_pretrained(Path(config.path) / "text_encoder_2")
115110

116-
raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.")
111+
raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")
117112

118113

119114
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
@@ -126,7 +121,7 @@ def _load_model(
126121
submodel_type: Optional[SubModelType] = None,
127122
) -> AnyModel:
128123
if not isinstance(config, T5EncoderConfig):
129-
raise Exception("Only T5EncoderConfig models are currently supported here.")
124+
raise ValueError("Only T5EncoderConfig models are currently supported here.")
130125

131126
match submodel_type:
132127
case SubModelType.Tokenizer2:
@@ -136,7 +131,7 @@ def _load_model(
136131
Path(config.path) / "text_encoder_2"
137132
) # TODO: Fix hf subfolder install
138133

139-
raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.")
134+
raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")
140135

141136

142137
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
@@ -149,36 +144,32 @@ def _load_model(
149144
submodel_type: Optional[SubModelType] = None,
150145
) -> AnyModel:
151146
if not isinstance(config, CheckpointConfigBase):
152-
raise Exception("Only CheckpointConfigBase models are currently supported here.")
147+
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
153148
legacy_config_path = app_config.legacy_conf_path / config.config_path
154149
config_path = legacy_config_path.as_posix()
155150
with open(config_path, "r") as stream:
156-
try:
157-
flux_conf = yaml.safe_load(stream)
158-
except:
159-
raise
151+
flux_conf = yaml.safe_load(stream)
160152

161153
match submodel_type:
162154
case SubModelType.Transformer:
163155
return self._load_from_singlefile(config, flux_conf)
164156

165-
raise Exception("Only Transformer submodels are currently supported.")
157+
raise ValueError("Only Transformer submodels are currently supported.")
166158

167159
def _load_from_singlefile(
168160
self,
169161
config: AnyModelConfig,
170162
flux_conf: Any,
171163
) -> AnyModel:
172164
assert isinstance(config, MainCheckpointConfig)
173-
load_class = Flux
174165
params = None
175166
model_path = Path(config.path)
176167
dataclass_fields = {f.name for f in fields(FluxParams)}
177168
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
178169
params = FluxParams(**filtered_data)
179170

180171
with SilenceWarnings():
181-
model = load_class(params)
172+
model = Flux(params)
182173
sd = load_file(model_path)
183174
model.load_state_dict(sd, strict=False, assign=True)
184175
return model
@@ -194,28 +185,24 @@ def _load_model(
194185
submodel_type: Optional[SubModelType] = None,
195186
) -> AnyModel:
196187
if not isinstance(config, CheckpointConfigBase):
197-
raise Exception("Only CheckpointConfigBase models are currently supported here.")
188+
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
198189
legacy_config_path = app_config.legacy_conf_path / config.config_path
199190
config_path = legacy_config_path.as_posix()
200191
with open(config_path, "r") as stream:
201-
try:
202-
flux_conf = yaml.safe_load(stream)
203-
except:
204-
raise
192+
flux_conf = yaml.safe_load(stream)
205193

206194
match submodel_type:
207195
case SubModelType.Transformer:
208196
return self._load_from_singlefile(config, flux_conf)
209197

210-
raise Exception("Only Transformer submodels are currently supported.")
198+
raise ValueError("Only Transformer submodels are currently supported.")
211199

212200
def _load_from_singlefile(
213201
self,
214202
config: AnyModelConfig,
215203
flux_conf: Any,
216204
) -> AnyModel:
217205
assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
218-
load_class = Flux
219206
params = None
220207
model_path = Path(config.path)
221208
dataclass_fields = {f.name for f in fields(FluxParams)}
@@ -224,7 +211,7 @@ def _load_from_singlefile(
224211

225212
with SilenceWarnings():
226213
with accelerate.init_empty_weights():
227-
model = load_class(params)
214+
model = Flux(params)
228215
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
229216
sd = load_file(model_path)
230217
model.load_state_dict(sd, strict=False, assign=True)

0 commit comments

Comments (0)