@@ -49,29 +49,24 @@ def _load_model(
49
49
config : AnyModelConfig ,
50
50
submodel_type : Optional [SubModelType ] = None ,
51
51
) -> AnyModel :
52
- if isinstance (config , VAECheckpointConfig ):
53
- model_path = Path (config .path )
54
- load_class = AutoEncoder
55
- legacy_config_path = app_config .legacy_conf_path / config .config_path
56
- config_path = legacy_config_path .as_posix ()
57
- with open (config_path , "r" ) as stream :
58
- try :
59
- flux_conf = yaml .safe_load (stream )
60
- except :
61
- raise
62
-
63
- dataclass_fields = {f .name for f in fields (AutoEncoderParams )}
64
- filtered_data = {k : v for k , v in flux_conf ["params" ].items () if k in dataclass_fields }
65
- params = AutoEncoderParams (** filtered_data )
66
-
67
- with SilenceWarnings ():
68
- model = load_class (params )
69
- sd = load_file (model_path )
70
- model .load_state_dict (sd , strict = False , assign = True )
71
-
72
- return model
73
- else :
74
- return super ()._load_model (config , submodel_type )
52
+ if not isinstance (config , VAECheckpointConfig ):
53
+ raise ValueError ("Only VAECheckpointConfig models are currently supported here." )
54
+ model_path = Path (config .path )
55
+ legacy_config_path = app_config .legacy_conf_path / config .config_path
56
+ config_path = legacy_config_path .as_posix ()
57
+ with open (config_path , "r" ) as stream :
58
+ flux_conf = yaml .safe_load (stream )
59
+
60
+ dataclass_fields = {f .name for f in fields (AutoEncoderParams )}
61
+ filtered_data = {k : v for k , v in flux_conf ["params" ].items () if k in dataclass_fields }
62
+ params = AutoEncoderParams (** filtered_data )
63
+
64
+ with SilenceWarnings ():
65
+ model = AutoEncoder (params )
66
+ sd = load_file (model_path )
67
+ model .load_state_dict (sd , strict = False , assign = True )
68
+
69
+ return model
75
70
76
71
77
72
@ModelLoaderRegistry .register (base = BaseModelType .Any , type = ModelType .CLIPEmbed , format = ModelFormat .Diffusers )
@@ -84,15 +79,15 @@ def _load_model(
84
79
submodel_type : Optional [SubModelType ] = None ,
85
80
) -> AnyModel :
86
81
if not isinstance (config , CLIPEmbedDiffusersConfig ):
87
- raise Exception ("Only CLIPEmbedDiffusersConfig models are currently supported here." )
82
+ raise ValueError ("Only CLIPEmbedDiffusersConfig models are currently supported here." )
88
83
89
84
match submodel_type :
90
85
case SubModelType .Tokenizer :
91
86
return CLIPTokenizer .from_pretrained (config .path , max_length = 77 )
92
87
case SubModelType .TextEncoder :
93
88
return CLIPTextModel .from_pretrained (config .path )
94
89
95
- raise Exception ("Only Tokenizer and TextEncoder submodels are currently supported." )
90
+ raise ValueError ("Only Tokenizer and TextEncoder submodels are currently supported." )
96
91
97
92
98
93
@ModelLoaderRegistry .register (base = BaseModelType .Any , type = ModelType .T5Encoder , format = ModelFormat .T5Encoder8b )
@@ -105,15 +100,15 @@ def _load_model(
105
100
submodel_type : Optional [SubModelType ] = None ,
106
101
) -> AnyModel :
107
102
if not isinstance (config , T5Encoder8bConfig ):
108
- raise Exception ("Only T5Encoder8bConfig models are currently supported here." )
103
+ raise ValueError ("Only T5Encoder8bConfig models are currently supported here." )
109
104
110
105
match submodel_type :
111
106
case SubModelType .Tokenizer2 :
112
107
return T5Tokenizer .from_pretrained (Path (config .path ) / "tokenizer_2" , max_length = 512 )
113
108
case SubModelType .TextEncoder2 :
114
109
return FastQuantizedTransformersModel .from_pretrained (Path (config .path ) / "text_encoder_2" )
115
110
116
- raise Exception ("Only Tokenizer and TextEncoder submodels are currently supported." )
111
+ raise ValueError ("Only Tokenizer and TextEncoder submodels are currently supported." )
117
112
118
113
119
114
@ModelLoaderRegistry .register (base = BaseModelType .Any , type = ModelType .T5Encoder , format = ModelFormat .T5Encoder )
@@ -126,7 +121,7 @@ def _load_model(
126
121
submodel_type : Optional [SubModelType ] = None ,
127
122
) -> AnyModel :
128
123
if not isinstance (config , T5EncoderConfig ):
129
- raise Exception ("Only T5EncoderConfig models are currently supported here." )
124
+ raise ValueError ("Only T5EncoderConfig models are currently supported here." )
130
125
131
126
match submodel_type :
132
127
case SubModelType .Tokenizer2 :
@@ -136,7 +131,7 @@ def _load_model(
136
131
Path (config .path ) / "text_encoder_2"
137
132
) # TODO: Fix hf subfolder install
138
133
139
- raise Exception ("Only Tokenizer and TextEncoder submodels are currently supported." )
134
+ raise ValueError ("Only Tokenizer and TextEncoder submodels are currently supported." )
140
135
141
136
142
137
@ModelLoaderRegistry .register (base = BaseModelType .Flux , type = ModelType .Main , format = ModelFormat .Checkpoint )
@@ -149,36 +144,32 @@ def _load_model(
149
144
submodel_type : Optional [SubModelType ] = None ,
150
145
) -> AnyModel :
151
146
if not isinstance (config , CheckpointConfigBase ):
152
- raise Exception ("Only CheckpointConfigBase models are currently supported here." )
147
+ raise ValueError ("Only CheckpointConfigBase models are currently supported here." )
153
148
legacy_config_path = app_config .legacy_conf_path / config .config_path
154
149
config_path = legacy_config_path .as_posix ()
155
150
with open (config_path , "r" ) as stream :
156
- try :
157
- flux_conf = yaml .safe_load (stream )
158
- except :
159
- raise
151
+ flux_conf = yaml .safe_load (stream )
160
152
161
153
match submodel_type :
162
154
case SubModelType .Transformer :
163
155
return self ._load_from_singlefile (config , flux_conf )
164
156
165
- raise Exception ("Only Transformer submodels are currently supported." )
157
+ raise ValueError ("Only Transformer submodels are currently supported." )
166
158
167
159
def _load_from_singlefile (
168
160
self ,
169
161
config : AnyModelConfig ,
170
162
flux_conf : Any ,
171
163
) -> AnyModel :
172
164
assert isinstance (config , MainCheckpointConfig )
173
- load_class = Flux
174
165
params = None
175
166
model_path = Path (config .path )
176
167
dataclass_fields = {f .name for f in fields (FluxParams )}
177
168
filtered_data = {k : v for k , v in flux_conf ["params" ].items () if k in dataclass_fields }
178
169
params = FluxParams (** filtered_data )
179
170
180
171
with SilenceWarnings ():
181
- model = load_class (params )
172
+ model = Flux (params )
182
173
sd = load_file (model_path )
183
174
model .load_state_dict (sd , strict = False , assign = True )
184
175
return model
@@ -194,28 +185,24 @@ def _load_model(
194
185
submodel_type : Optional [SubModelType ] = None ,
195
186
) -> AnyModel :
196
187
if not isinstance (config , CheckpointConfigBase ):
197
- raise Exception ("Only CheckpointConfigBase models are currently supported here." )
188
+ raise ValueError ("Only CheckpointConfigBase models are currently supported here." )
198
189
legacy_config_path = app_config .legacy_conf_path / config .config_path
199
190
config_path = legacy_config_path .as_posix ()
200
191
with open (config_path , "r" ) as stream :
201
- try :
202
- flux_conf = yaml .safe_load (stream )
203
- except :
204
- raise
192
+ flux_conf = yaml .safe_load (stream )
205
193
206
194
match submodel_type :
207
195
case SubModelType .Transformer :
208
196
return self ._load_from_singlefile (config , flux_conf )
209
197
210
- raise Exception ("Only Transformer submodels are currently supported." )
198
+ raise ValueError ("Only Transformer submodels are currently supported." )
211
199
212
200
def _load_from_singlefile (
213
201
self ,
214
202
config : AnyModelConfig ,
215
203
flux_conf : Any ,
216
204
) -> AnyModel :
217
205
assert isinstance (config , MainBnbQuantized4bCheckpointConfig )
218
- load_class = Flux
219
206
params = None
220
207
model_path = Path (config .path )
221
208
dataclass_fields = {f .name for f in fields (FluxParams )}
@@ -224,7 +211,7 @@ def _load_from_singlefile(
224
211
225
212
with SilenceWarnings ():
226
213
with accelerate .init_empty_weights ():
227
- model = load_class (params )
214
+ model = Flux (params )
228
215
model = quantize_model_nf4 (model , modules_to_not_convert = set (), compute_dtype = torch .bfloat16 )
229
216
sd = load_file (model_path )
230
217
model .load_state_dict (sd , strict = False , assign = True )
0 commit comments