@@ -378,14 +378,12 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
378
378
379
379
# ----- model memory compression/decompression pathways ----- #
380
380
381
- def compress_model (self , model : Module , is_meta : bool = False ):
381
+ def compress_model (self , model : Module ):
382
382
"""
383
383
Compress a model in memory. Because the model structure is modified in place,
384
384
this method is more memory-efficient than `self.compress`
385
385
386
386
:param model: model containing parameters to compress
387
- :param is_meta: whether the model is on the meta device, in which case
388
- we do not need move parameters to CPU
389
387
"""
390
388
module_to_scheme = map_module_to_scheme (model )
391
389
sparse_compression_targets : Set [str ] = expand_target_names (
@@ -395,9 +393,13 @@ def compress_model(self, model: Module, is_meta: bool = False):
395
393
)
396
394
397
395
for prefix , module in tqdm (model .named_modules (), desc = "Compressing model" ):
396
+
398
397
if prefix in module_to_scheme or prefix in sparse_compression_targets :
398
+ module_device = get_execution_device (module )
399
+ is_meta = (module_device == torch .device ("meta" ))
400
+
399
401
exec_device = "meta" if is_meta else "cpu"
400
- onloading_device = "meta" if is_meta else get_execution_device ( module )
402
+ onloading_device = "meta" if is_meta else module_device
401
403
402
404
# in the future, support compression on same device
403
405
with align_module_device (module , execution_device = exec_device ):
0 commit comments