 import torch
 
-from ._quant_common.quant_config import local_rank, world_size, HpDtype
+from ._quant_common.quant_config import HpDtype
 from ._core.quant_dequant import QuantDequantBase
 from ._core.scale_handler import update_state_dict_method, ScaleFormat
 from ._core.quantized_func_wrappers import (
     get_quantized_func_wrapper,
     OP_TYPE,
 )
+from .prepare_quant.prepare_model import get_world_size, get_local_rank
 from .utils.logger import logger
 from neural_compressor.common import options
 from neural_compressor.torch.utils import (
@@ -75,7 +76,7 @@ def save_rank_model(model, folder_prefix="", **kwargs):
     """Save state_dict for model from each rank."""
     # workaround for [SW-199005] [HQT] casted fp8 tensor cannot get data pointer
     cur_accelerator.synchronize()
-    save_directory = add_rank_suffix(folder_prefix, local_rank, world_size)
+    save_directory = add_rank_suffix(folder_prefix, get_local_rank(), get_world_size())
     os.makedirs(save_directory, exist_ok=True)
     safe_serialization = kwargs.get("safe_serialization", True)
     max_shard_size = kwargs.get("max_shard_size", f"{MAX_FILE_SIZE}GB")
@@ -96,6 +97,8 @@ def gather_state_dict(folder_prefix, file_name, tp_mod_list=[]):
     """Gather state_dict from files saved by each rank."""
     from safetensors.torch import load_file as safe_load_file
 
+    world_size = get_world_size()
+
     def _is_in_list(name, tp_mod_list):
         for tp_name in tp_mod_list:
             if tp_name in name:
@@ -122,6 +125,7 @@ def _is_in_list(name, tp_mod_list):
 
 def clean_rank_files(folder_prefix, file_name=None):
     """Clean files saved by each rank after gathering."""
+    world_size = get_world_size()
     for i in range(world_size):  # TODO: assuming tp_size == world_size
         folder_name = add_rank_suffix(folder_prefix, i, world_size)
         if file_name is None:
@@ -375,6 +379,8 @@ def save(model, checkpoint_dir="saved_results", format="huggingface", **kwargs):
         checkpoint_dir (str, optional): path to checkpoint. Defaults to "saved_results".
         format (str, optional): defaults to 'huggingface'.
     """
+    world_size = get_world_size()
+    local_rank = get_local_rank()
     format = get_enum_from_format(format)
     model = process_model_for_scalar_scale(model)
     if world_size > 1:
@@ -455,6 +461,7 @@ def load_empty_raw_model(model_name_or_path, **kwargs):
     if model is None:
         with init_empty_weights(include_buffers=False):
             model = transformers.AutoModelForCausalLM.from_config(config, torch_dtype=hp_dtype)
+    world_size = get_world_size()
     if world_size > 1:
         import deepspeed
         from neural_compressor.torch.utils import get_non_persistent_buffers, load_non_persistent_buffers
@@ -604,8 +611,7 @@ def load(model_name_or_path, format="huggingface", device="hpu", **kwargs):
         FP8 model.
     """
     format = get_enum_from_format(format)
-    global world_size
-    world_size = kwargs.get("world_size", world_size)
+    world_size = kwargs.get("world_size", get_world_size())
     assert format == SaveLoadFormat.HUGGINGFACE, "Currently, only huggingface models are supported."
     assert device in ["hpu", "cpu"], "Currently, only hpu & cpu device is supported for FP8 model."
 
@@ -781,7 +787,7 @@ def load_scale_params(model, new_scale_params):
             param.data = new_scale
 
 
-def get_new_rank_state_dict(all_rank_state_dict, model, world_size=world_size, local_rank=local_rank):
+def get_new_rank_state_dict(all_rank_state_dict, model, world_size=get_world_size(), local_rank=get_local_rank()):
     """Get new rank state_dict for world_size.
 
     Args:
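
The changes above replace the module-level local_rank / world_size globals (formerly imported from _quant_common.quant_config) with the get_local_rank() / get_world_size() accessors imported from prepare_quant.prepare_model, so the distributed topology can be read at call time rather than captured once at import. The accessor bodies are not part of this diff; the snippet below is only a minimal sketch of the assumed behavior (lazy lookup via torch.distributed with an environment-variable fallback), not the implementation this PR adds.

# Hypothetical sketch of the accessors this diff switches to; the real
# implementations live in prepare_quant/prepare_model.py and are not shown here.
import os

import torch.distributed as dist


def get_world_size() -> int:
    # Prefer an initialized process group; otherwise fall back to the
    # WORLD_SIZE variable set by launchers such as torchrun or deepspeed.
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return int(os.environ.get("WORLD_SIZE", "1"))


def get_local_rank() -> int:
    # LOCAL_RANK is set per process by the launcher; default to 0 for
    # single-process runs.
    return int(os.environ.get("LOCAL_RANK", "0"))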