
Commit 0899978

xin3he and kdamaszk authored
[SW-234750] Fix reading distributed data in quant_config (#284) (#2255)
Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
1 parent a3da6e8 commit 0899978

4 files changed: +47 -16 lines changed

neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py

Lines changed: 3 additions & 6 deletions
@@ -25,15 +25,10 @@
 
 from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator, INCAcceleratorType
 from ..utils.logger import logger
+from ..prepare_quant.prepare_model import get_world_size, get_local_rank
 from .._core.scale_methods.scale_method_parser import parse_scale_method, validate_and_populate_scale_method, convert_scale_method_strings_to_enum
 from .._core.scale_methods.scale_method_config import get_scale_method_from_config, check_scale_method_fields, ScaleMethodString, CfgStr, ScaleGranularity, ScaleValueType, ScaleRoundMethod
 
-try:
-    world_size = torch.distributed.get_world_size()
-    local_rank = torch.distributed.get_rank()
-except:
-    local_rank = int(os.getenv("LOCAL_RANK", "-1"))
-    world_size = int(os.getenv("WORLD_SIZE", "-1"))
 
 class QuantMode(Enum):
     NONE = 0
@@ -153,6 +148,8 @@ class Fp8cfg:
     cfg: Mapping[str, Any]
 
     def parse(custom_config: Mapping[str, str]) -> Fp8cfg:
+        world_size = get_world_size()
+        local_rank = get_local_rank()
         measured_global_config = {
             "dump_stats_path": "stats",
             "fp8_config": torch.float8_e4m3fn,  # The parameters of the chosen Quantization methed

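Previously, world_size and local_rank were resolved once at module import time inside a bare try/except; if no process group was initialized when quant_config was imported, the values silently fell back to the WORLD_SIZE/LOCAL_RANK environment variables and never reflected the real topology. The commit defers the lookup to Fp8cfg.parse via the helpers added in prepare_model.py. A minimal standalone sketch of the difference (not the library code; only torch is assumed):

import os
import torch

def eager_world_size():
    # Old pattern: evaluated once, typically at import time, before
    # torch.distributed.init_process_group() has been called.
    try:
        return torch.distributed.get_world_size()
    except (RuntimeError, ValueError):
        # get_world_size() raises when no default process group exists
        return int(os.getenv("WORLD_SIZE", "-1"))

def lazy_world_size():
    # New pattern: evaluated on demand, so it sees a process group that
    # was initialized after this module was imported.
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    return -1

if __name__ == "__main__":
    frozen = eager_world_size()   # captured now; stays -1 in a plain single process
    print(frozen, lazy_world_size())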
neural_compressor/torch/algorithms/fp8_quant/prepare_quant/prepare_model.py

Lines changed: 28 additions & 4 deletions
@@ -13,14 +13,34 @@
 # limitations under the License.
 
 import os
+import torch
 from typing import Optional
 
-from .._core.save_measure import save_measurements
-from .._core.utils import prepare_model
-from .._quant_common.quant_config import Fp8cfg, _read_config_from_file, set_hqt_config
 
+_world_size = -1
+_local_rank = -1
+
+
+def get_world_size():
+    global _world_size
+    if _world_size == -1:
+        if torch.distributed.is_initialized():
+            _world_size = torch.distributed.get_world_size()
+    return _world_size
+
+
+def get_local_rank():
+    global _local_rank
+    if _local_rank == -1:
+        if torch.distributed.is_initialized():
+            _local_rank = torch.distributed.get_rank()
+    return _local_rank
+
+
+def _prep_model_with_predefined_config(model, *, config):
+    from .._core.utils import prepare_model
+    from .._quant_common.quant_config import set_hqt_config
 
-def _prep_model_with_predefined_config(model, *, config: Fp8cfg):
     set_hqt_config(model, config)
     prepare_model(model)
 
@@ -31,6 +51,8 @@ def prep_model(model, config_path: Optional[str] = None):
     If `config_path` is not given or `None`,
     instead perform the legacy behavior of checking for env variable `QUANT_CONFIG`.
     """
+    from .._quant_common.quant_config import Fp8cfg, _read_config_from_file
+
     if config_path is None:
         config_path = os.getenv("QUANT_CONFIG")
     if config_path is None:
@@ -44,4 +66,6 @@
 
 
 def finish_measurements(model):
+    from .._core.save_measure import save_measurements
+
     save_measurements(model)
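The new helpers cache the answer in module-level globals, but only once torch.distributed reports an initialized process group; until then they return -1 and re-check on every call. A usage sketch (assuming the package is installed and the gloo backend is available; the single-process group here is only for illustration):

import torch.distributed as dist

from neural_compressor.torch.algorithms.fp8_quant.prepare_quant.prepare_model import (
    get_local_rank,
    get_world_size,
)

print(get_world_size(), get_local_rank())   # -1 -1: no process group yet

dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)
print(get_world_size(), get_local_rank())   # 1 0: looked up and cached

dist.destroy_process_group()
print(get_world_size(), get_local_rank())   # still 1 0: the cached values persist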

neural_compressor/torch/algorithms/fp8_quant/save_load.py

Lines changed: 11 additions & 5 deletions
@@ -17,7 +17,7 @@
 
 import torch
 
-from ._quant_common.quant_config import local_rank, world_size, HpDtype
+from ._quant_common.quant_config import HpDtype
 from ._core.quant_dequant import QuantDequantBase
 from ._core.scale_handler import update_state_dict_method, ScaleFormat
 from ._core.quantized_func_wrappers import (
@@ -26,6 +26,7 @@
     get_quantized_func_wrapper,
     OP_TYPE,
 )
+from .prepare_quant.prepare_model import get_world_size, get_local_rank
 from .utils.logger import logger
 from neural_compressor.common import options
 from neural_compressor.torch.utils import (
@@ -75,7 +76,7 @@ def save_rank_model(model, folder_prefix="", **kwargs):
     """Save state_dict for model from each rank."""
     # workaround for [SW-199005] [HQT] casted fp8 tensor cannot get data pointer
     cur_accelerator.synchronize()
-    save_directory = add_rank_suffix(folder_prefix, local_rank, world_size)
+    save_directory = add_rank_suffix(folder_prefix, get_local_rank(), get_world_size())
     os.makedirs(save_directory, exist_ok=True)
     safe_serialization = kwargs.get("safe_serialization", True)
     max_shard_size = kwargs.get("max_shard_size", f"{MAX_FILE_SIZE}GB")
@@ -96,6 +97,8 @@ def gather_state_dict(folder_prefix, file_name, tp_mod_list=[]):
     """Gather state_dict from files saved by each rank."""
     from safetensors.torch import load_file as safe_load_file
 
+    world_size = get_world_size()
+
     def _is_in_list(name, tp_mod_list):
         for tp_name in tp_mod_list:
             if tp_name in name:
@@ -122,6 +125,7 @@ def _is_in_list(name, tp_mod_list):
 
 def clean_rank_files(folder_prefix, file_name=None):
     """Clean files saved by each rank after gathering."""
+    world_size = get_world_size()
     for i in range(world_size):  # TODO: assuming tp_size == world_size
         folder_name = add_rank_suffix(folder_prefix, i, world_size)
         if file_name is None:
@@ -375,6 +379,8 @@ def save(model, checkpoint_dir="saved_results", format="huggingface", **kwargs):
         checkpoint_dir (str, optional): path to checkpoint. Defaults to "saved_results".
         format (str, optional): defaults to 'huggingface'.
     """
+    world_size = get_world_size()
+    local_rank = get_local_rank()
     format = get_enum_from_format(format)
     model = process_model_for_scalar_scale(model)
     if world_size > 1:
@@ -455,6 +461,7 @@ def load_empty_raw_model(model_name_or_path, **kwargs):
     if model is None:
         with init_empty_weights(include_buffers=False):
             model = transformers.AutoModelForCausalLM.from_config(config, torch_dtype=hp_dtype)
+    world_size = get_world_size()
     if world_size > 1:
         import deepspeed
         from neural_compressor.torch.utils import get_non_persistent_buffers, load_non_persistent_buffers
@@ -604,8 +611,7 @@ def load(model_name_or_path, format="huggingface", device="hpu", **kwargs):
         FP8 model.
     """
     format = get_enum_from_format(format)
-    global world_size
-    world_size = kwargs.get("world_size", world_size)
+    world_size = kwargs.get("world_size", get_world_size())
     assert format == SaveLoadFormat.HUGGINGFACE, "Currently, only huggingface models are supported."
     assert device in ["hpu", "cpu"], "Currently, only hpu & cpu device is supported for FP8 model."
 
@@ -781,7 +787,7 @@ def load_scale_params(model, new_scale_params):
         param.data = new_scale
 
 
-def get_new_rank_state_dict(all_rank_state_dict, model, world_size=world_size, local_rank=local_rank):
+def get_new_rank_state_dict(all_rank_state_dict, model, world_size=get_world_size(), local_rank=get_local_rank()):
     """Get new rank state_dict for world_size.
 
     Args:
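One detail worth keeping in mind when reading the last hunk: Python evaluates default parameter values once, when the def statement runs, so world_size=get_world_size() in the signature of get_new_rank_state_dict captures whatever the helper returns at the time save_load.py is imported. Callers that need the post-initialization values can pass them explicitly. A small general-Python sketch of that behavior (hypothetical names, not library code):

calls = []

def fake_get_world_size():
    # Pretend the process group only becomes available after the first lookup.
    calls.append("lookup")
    return -1 if len(calls) == 1 else 8

def new_rank_state_dict(world_size=fake_get_world_size()):  # default bound here, once
    return world_size

print(new_rank_state_dict())                        # -1: the definition-time default
print(new_rank_state_dict(fake_get_world_size()))   # 8: value read at call time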

test/3x/torch/quantization/fp8_quant/test_save_load.py

Lines changed: 5 additions & 1 deletion
@@ -4,7 +4,7 @@
 import torch
 import transformers
 
-from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import local_rank, world_size
+from neural_compressor.torch.algorithms.fp8_quant.prepare_quant.prepare_model import get_world_size, get_local_rank
 from neural_compressor.torch.quantization import FP8Config, convert, load, prepare, save
 from neural_compressor.torch.algorithms.fp8_quant._quant_common.helper_modules import PatchedLinear
 from neural_compressor.torch.utils import get_used_hpu_mem_MB
@@ -45,6 +45,7 @@ def calib_func(model):
 
 def test_save_vllm_compatible_model():
     name = "Qwen/Qwen2-0.5B-Instruct"
+    world_size = get_world_size()
     if world_size > 0:
         # Do not use random weights since multi-processes will get different weights for Embedding
         model = transformers.AutoModelForCausalLM.from_pretrained(name)
@@ -77,6 +78,7 @@ def test_save_vllm_compatible_model():
 
 @pytest.mark.skip(reason="[SW-226589] Skip this test since the model was updated")
 def test_load_model_provided_by_neuralmagic():
+    world_size = get_world_size()
     model_name_or_path = "neuralmagic/Qwen2-0.5B-Instruct-FP8"
     hpu_mem0 = get_used_hpu_mem_MB()
     model = load(model_name_or_path, format="huggingface", device="hpu")
@@ -117,6 +119,8 @@ def init_model(world_size):
 @torch.no_grad()
 @pytest.mark.parametrize("scale_method", ["maxabs_hw", "act_maxabs_hw_weights_pcs_maxabs_pow2"])
 def test_default_save_load(scale_method):
+    world_size = get_world_size()
+    local_rank = get_local_rank()
     example_inputs = torch.tensor([[10, 20]], dtype=torch.long).to("hpu")
     model = init_model(world_size)
     # The default value of model.generation_config.max_length in transformers is 20
