@@ -256,9 +256,12 @@ def _maybe_map_weights(
256
256
# Extract weights from the policy module; LoRA adapters are merged via merge_and_unload when the model exposes it, otherwise the plain state_dict is used
257
257
if not hasattr (server_weights , "model" ):
258
258
raise ValueError ("TensorDictModuleBase must have a 'model' attribute" )
259
- if not hasattr (server_weights .model , "merge_and_unload" ):
260
- raise ValueError ("Model must have a 'merge_and_unload' method" )
261
- return TensorDict (server_weights .model .merge_and_unload ().state_dict (), [])
259
+ # Check if it's a LoRA model
260
+ if hasattr (server_weights .model , "merge_and_unload" ):
261
+ state_dict = server_weights .model .merge_and_unload ().state_dict ()
262
+ else :
263
+ state_dict = server_weights .model .state_dict ()
264
+ return TensorDict (state_dict , [])
262
265
elif isinstance (server_weights , TensorDictBase ):
263
266
return server_weights
264
267
elif isinstance (server_weights , dict ):
@@ -281,7 +284,11 @@ def get_model_metadata(
281
284
Returns:
282
285
dict[str, tuple[torch.dtype, torch.Size]]: The model metadata.
283
286
"""
284
- sd = model .model .merge_and_unload ().state_dict ()
287
+ # Check if the model has a LoRA adapter
288
+ if hasattr (model .model , "merge_and_unload" ):
289
+ sd = model .model .merge_and_unload ().state_dict ()
290
+ else :
291
+ sd = model .model .state_dict ()
285
292
model_metadata = {k : (v .dtype , v .shape ) for k , v in sd .items ()}
286
293
return model_metadata
287
294
0 commit comments