Skip to content

Commit 0916ca5

Browse files
Commit message: removed is_meta input
Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
1 parent 53b63b1 commit 0916ca5

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -378,14 +378,12 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
378378

379379
# ----- model memory compression/decompression pathways ----- #
380380

381-
def compress_model(self, model: Module, is_meta: bool = False):
381+
def compress_model(self, model: Module):
382382
"""
383383
Compress a model in memory. Because the model structure is modified in place,
384384
this method is more memory-efficient than `self.compress`
385385
386386
:param model: model containing parameters to compress
387-
:param is_meta: whether the model is on the meta device, in which case
388-
we do not need move parameters to CPU
389387
"""
390388
module_to_scheme = map_module_to_scheme(model)
391389
sparse_compression_targets: Set[str] = expand_target_names(
@@ -395,9 +393,13 @@ def compress_model(self, model: Module, is_meta: bool = False):
395393
)
396394

397395
for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
396+
398397
if prefix in module_to_scheme or prefix in sparse_compression_targets:
398+
module_device = get_execution_device(module)
399+
is_meta = (module_device == torch.device("meta"))
400+
399401
exec_device = "meta" if is_meta else "cpu"
400-
onloading_device = "meta" if is_meta else get_execution_device(module)
402+
onloading_device = "meta" if is_meta else module_device
401403

402404
# in the future, support compression on same device
403405
with align_module_device(module, execution_device=exec_device):

0 commit comments

Comments (0)