@@ -400,7 +400,10 @@ def compress_model(self, model: Module):
400
400
401
401
# in the future, support compression on same device
402
402
with align_module_device (module , execution_device = exec_device ):
403
- state_dict = module .state_dict (prefix = f"{ prefix } ." )
403
+ state_dict = {
404
+ f"{ prefix } .{ name } " : param
405
+ for name , param in module .named_parameters (recurse = False )
406
+ }
404
407
405
408
# quantization first
406
409
if prefix in module_to_scheme :
@@ -421,7 +424,7 @@ def compress_model(self, model: Module):
421
424
422
425
# remove any existing parameters
423
426
offload_device = get_offloaded_device (module )
424
- for name , _ in list (module .named_parameters ()):
427
+ for name , _ in list (module .named_parameters (recurse = False )):
425
428
delete_offload_parameter (module , name )
426
429
427
430
# replace with compressed parameters
@@ -458,7 +461,10 @@ def decompress_model(self, model: Module):
458
461
if prefix in module_to_scheme or prefix in sparse_compression_targets :
459
462
# in the future, support decompression on same device
460
463
with align_module_device (module , execution_device = "cpu" ):
461
- state_dict = module .state_dict (prefix = f"{ prefix } ." )
464
+ state_dict = {
465
+ f"{ prefix } .{ name } " : param
466
+ for name , param in module .named_parameters (recurse = False )
467
+ }
462
468
463
469
# sparsity first
464
470
if prefix in sparse_compression_targets :
@@ -483,7 +489,7 @@ def decompress_model(self, model: Module):
483
489
# remove any existing parameters
484
490
exec_device = get_execution_device (module )
485
491
offload_device = get_offloaded_device (module )
486
- for name , _ in list (module .named_parameters ()):
492
+ for name , _ in list (module .named_parameters (recurse = False )):
487
493
delete_offload_parameter (module , name )
488
494
489
495
# replace with decompressed parameters
0 commit comments