@@ -64,7 +64,7 @@ def _load_model(
64
64
with SilenceWarnings ():
65
65
model = AutoEncoder (params )
66
66
sd = load_file (model_path )
67
- model .load_state_dict (sd , strict = False , assign = True )
67
+ model .load_state_dict (sd , assign = True )
68
68
69
69
return model
70
70
@@ -83,11 +83,11 @@ def _load_model(
83
83
84
84
match submodel_type :
85
85
case SubModelType .Tokenizer :
86
- return CLIPTokenizer .from_pretrained (config .path , max_length = 77 )
86
+ return CLIPTokenizer .from_pretrained (config .path )
87
87
case SubModelType .TextEncoder :
88
88
return CLIPTextModel .from_pretrained (config .path )
89
89
90
- raise ValueError ("Only Tokenizer and TextEncoder submodels are currently supported." )
90
+ raise ValueError (f "Only Tokenizer and TextEncoder submodels are currently supported. Received: { submodel_type . value if submodel_type else 'None' } " )
91
91
92
92
93
93
@ModelLoaderRegistry .register (base = BaseModelType .Any , type = ModelType .T5Encoder , format = ModelFormat .T5Encoder8b )
@@ -108,7 +108,7 @@ def _load_model(
108
108
case SubModelType .TextEncoder2 :
109
109
return FastQuantizedTransformersModel .from_pretrained (Path (config .path ) / "text_encoder_2" )
110
110
111
- raise ValueError ("Only Tokenizer and TextEncoder submodels are currently supported." )
111
+ raise ValueError (f "Only Tokenizer and TextEncoder submodels are currently supported. Received: { submodel_type . value if submodel_type else 'None' } " )
112
112
113
113
114
114
@ModelLoaderRegistry .register (base = BaseModelType .Any , type = ModelType .T5Encoder , format = ModelFormat .T5Encoder )
@@ -131,7 +131,7 @@ def _load_model(
131
131
Path (config .path ) / "text_encoder_2"
132
132
) # TODO: Fix hf subfolder install
133
133
134
- raise ValueError ("Only Tokenizer and TextEncoder submodels are currently supported." )
134
+ raise ValueError (f "Only Tokenizer and TextEncoder submodels are currently supported. Received: { submodel_type . value if submodel_type else 'None' } " )
135
135
136
136
137
137
@ModelLoaderRegistry .register (base = BaseModelType .Flux , type = ModelType .Main , format = ModelFormat .Checkpoint )
@@ -154,15 +154,14 @@ def _load_model(
154
154
case SubModelType .Transformer :
155
155
return self ._load_from_singlefile (config , flux_conf )
156
156
157
- raise ValueError ("Only Transformer submodels are currently supported." )
157
+ raise ValueError (f "Only Transformer submodels are currently supported. Received: { submodel_type . value if submodel_type else 'None' } " )
158
158
159
159
def _load_from_singlefile (
160
160
self ,
161
161
config : AnyModelConfig ,
162
162
flux_conf : Any ,
163
163
) -> AnyModel :
164
164
assert isinstance (config , MainCheckpointConfig )
165
- params = None
166
165
model_path = Path (config .path )
167
166
dataclass_fields = {f .name for f in fields (FluxParams )}
168
167
filtered_data = {k : v for k , v in flux_conf ["params" ].items () if k in dataclass_fields }
@@ -171,7 +170,7 @@ def _load_from_singlefile(
171
170
with SilenceWarnings ():
172
171
model = Flux (params )
173
172
sd = load_file (model_path )
174
- model .load_state_dict (sd , strict = False , assign = True )
173
+ model .load_state_dict (sd , assign = True )
175
174
return model
176
175
177
176
@@ -195,15 +194,14 @@ def _load_model(
195
194
case SubModelType .Transformer :
196
195
return self ._load_from_singlefile (config , flux_conf )
197
196
198
- raise ValueError ("Only Transformer submodels are currently supported." )
197
+ raise ValueError (f "Only Transformer submodels are currently supported. Received: { submodel_type . value if submodel_type else 'None' } " )
199
198
200
199
def _load_from_singlefile (
201
200
self ,
202
201
config : AnyModelConfig ,
203
202
flux_conf : Any ,
204
203
) -> AnyModel :
205
204
assert isinstance (config , MainBnbQuantized4bCheckpointConfig )
206
- params = None
207
205
model_path = Path (config .path )
208
206
dataclass_fields = {f .name for f in fields (FluxParams )}
209
207
filtered_data = {k : v for k , v in flux_conf ["params" ].items () if k in dataclass_fields }
@@ -214,5 +212,5 @@ def _load_from_singlefile(
214
212
model = Flux (params )
215
213
model = quantize_model_nf4 (model , modules_to_not_convert = set (), compute_dtype = torch .bfloat16 )
216
214
sd = load_file (model_path )
217
- model .load_state_dict (sd , strict = False , assign = True )
215
+ model .load_state_dict (sd , assign = True )
218
216
return model
0 commit comments