diff --git a/neural_compressor/common/utils/__init__.py b/neural_compressor/common/utils/__init__.py
index 5b4a8043ff1..0ded3dcc90d 100644
--- a/neural_compressor/common/utils/__init__.py
+++ b/neural_compressor/common/utils/__init__.py
@@ -28,8 +28,8 @@
 VLLM_TP_SIZE = int(os.getenv("VLLM_TP_SIZE", "8"))
 VLLM_EP_SIZE = int(os.getenv("VLLM_EP_SIZE", VLLM_TP_SIZE))
 NUM_EXPERTS_PER_EP_RANK = DEEPSEEK_EXPERTS // VLLM_EP_SIZE  # 32
-NUM_EXPERTS_GROUPS = 8
-NUM_EXPERTS_PER_GROUP_PER_RANK = NUM_EXPERTS_PER_EP_RANK // NUM_EXPERTS_GROUPS  # 4
+VLLM_MOE_N_SLICE = int(os.getenv("VLLM_MOE_N_SLICE", 8))
+NUM_EXPERTS_PER_GROUP_PER_RANK = NUM_EXPERTS_PER_EP_RANK // VLLM_MOE_N_SLICE  # 4
 FUSED_MOE_EXPERTS = NUM_EXPERTS_PER_GROUP_PER_RANK  # 4
 
 logger.warning_once(
@@ -37,7 +37,7 @@
     f"INC uses VLLM_TP_SIZE={VLLM_TP_SIZE},\n"
     f"VLLM_EP_SIZE={VLLM_EP_SIZE},\n"
     f"NUM_EXPERTS_PER_EP_RANK={NUM_EXPERTS_PER_EP_RANK},\n"
-    f"NUM_EXPERTS_GROUPS={NUM_EXPERTS_GROUPS},\n"
+    f"VLLM_MOE_N_SLICE={VLLM_MOE_N_SLICE},\n"
     f"NUM_EXPERTS_PER_GROUP_PER_RANK={NUM_EXPERTS_PER_GROUP_PER_RANK},\n"
     f"FUSED_MOE_EXPERTS={FUSED_MOE_EXPERTS}"
 )
diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/common.py b/neural_compressor/torch/algorithms/fp8_quant/_core/common.py
index 11992ec8d63..99a31ffe4e4 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/_core/common.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_core/common.py
@@ -42,6 +42,14 @@
 INFO_INTERVAL = 30  # seconds
 
 
+def maybe_dequant_original_fp8_weight(mod: torch.nn.Module, param: torch.Tensor):
+    if param.dtype in [torch.float8_e4m3fn]:
+        if hasattr(mod, "get_dequant_weights_func"):
+            dequant_weights_func = mod.get_dequant_weights_func()
+            if dequant_weights_func is not None:
+                param = dequant_weights_func(mod)
+    return param
+
 _mod_types = {
     "linear": ModuleType(1, ["weight"], 1, False),
     "matmul": ModuleType(2, [], 1, False),
@@ -222,15 +230,17 @@ def convert_scales_to_tensors_dict(scales_obj, scales_file_format, hp_dtype, dev
     "Softmax": ModuleInfo("softmax", PatchedSoftmax),
     "ModuleFusedSDPA": ModuleInfo("fused_sdpa", PatchedModuleFusedSDPA),
     "MoeMatmul": ModuleInfo("linear", PatchedMoeMatmul),
+    "MoeFP8Matmul": ModuleInfo("linear", PatchedMoeFP8Matmul),
     "ReplicatedLinear": ModuleInfo("linear", PatchedReplicatedLinear),
+    "VllmMixtureOfExpertsOpFP8": ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpFP8),
     # FIXME (Yi) revert change
     "FusedMoE": ModuleInfo("linear", PatchedMixtralMoE, False),
-    "GaudiMixtralSparseMoeBlock": ModuleInfo("dynamic_moe", PatchedGaudiMixtralSparseMoeBlock),
-    "VllmMixtureOfExpertsOp": (
-        ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpV2)
-        if os.getenv("LOW_CPU_MEM", "0") == "1"
-        else ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpV1)
-    ),
+    # "GaudiMixtralSparseMoeBlock": ModuleInfo("dynamic_moe", PatchedGaudiMixtralSparseMoeBlock),
+    # "VllmMixtureOfExpertsOp": (
+    #     ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpV2)
+    #     if os.getenv("LOW_CPU_MEM", "0") == "1"
+    #     else ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpV1)
+    # ),
 }
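Editor's note: the sketch below (not part of the diff) illustrates how the new maybe_dequant_original_fp8_weight helper is expected to behave: when a module already stores an fp8 weight and exposes get_dequant_weights_func, measurement and scale calculation operate on the dequantized high-precision tensor instead of the raw fp8 data. The helper body is copied from the hunk above; DummyFP8Linear and its weight_scale attribute are hypothetical stand-ins for a serving-stack layer that ships pre-quantized fp8 weights, so treat this as an assumption-laden example rather than the real integration.

import torch


def maybe_dequant_original_fp8_weight(mod: torch.nn.Module, param: torch.Tensor):
    # Copied from the common.py hunk above.
    if param.dtype in [torch.float8_e4m3fn]:
        if hasattr(mod, "get_dequant_weights_func"):
            dequant_weights_func = mod.get_dequant_weights_func()
            if dequant_weights_func is not None:
                param = dequant_weights_func(mod)
    return param


class DummyFP8Linear(torch.nn.Module):
    """Hypothetical layer that already holds a pre-quantized fp8 weight plus a per-tensor scale."""

    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(8, 8).to(torch.float8_e4m3fn), requires_grad=False)
        self.weight_scale = torch.tensor(0.5, dtype=torch.bfloat16)

    def get_dequant_weights_func(self):
        # Callback the helper looks for: rebuild a high-precision weight on demand.
        return lambda mod: mod.weight.to(torch.bfloat16) * mod.weight_scale


mod = DummyFP8Linear()
hp_weight = maybe_dequant_original_fp8_weight(mod, mod.weight)
print(hp_weight.dtype)  # torch.bfloat16 -- observers and scale calculators see real values, not fp8 bit patterns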
diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/fp_utils.py b/neural_compressor/torch/algorithms/fp8_quant/_core/fp_utils.py
index 0ea9492e6f3..5f229ea8a60 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/_core/fp_utils.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_core/fp_utils.py
@@ -17,6 +17,9 @@
 from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator, INCAcceleratorType
 
 cur_accelerator = auto_detect_accelerator()
+from neural_compressor.torch.utils import environ
+from neural_compressor.common.utils import logger
+
 descale_fcn = lambda x, scale: torch.mul(x, scale)
 scale_fcn = lambda x, scale: torch.div(x, scale)
 cast_fcn = lambda x, dtype: x.to(dtype=dtype)
@@ -106,6 +109,9 @@ def get_fp8_hw_alligned_scales(dtype, device):
 }
 
 def calc_maxabs_scale(xmaxabs, fullscale, backoff=1):
+    if environ.INC_FORCE_NAIVE_SCALING:
+        backoff = 1.0
+        logger.warning_once(f"Enabled naive scaling, backoff is set to {backoff}")
     scale = xmaxabs / (fullscale * backoff)
     return scale
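Editor's note: a small, self-contained sketch (not part of the diff) of what the INC_FORCE_NAIVE_SCALING switch does to calc_maxabs_scale. In the real code the flag is read once at import time in neural_compressor/torch/utils/environ.py; here the environment check is inlined with os.getenv so the snippet runs standalone, and the numbers are made up.

import os

import torch


def calc_maxabs_scale(xmaxabs, fullscale, backoff=1):
    # Mirrors the patched function above, with the environ.INC_FORCE_NAIVE_SCALING
    # lookup inlined for the sake of a standalone example.
    if os.getenv("INC_FORCE_NAIVE_SCALING", "0").lower() in ["1", "true"]:
        backoff = 1.0
    return xmaxabs / (fullscale * backoff)


xmaxabs = torch.tensor(120.0)  # hypothetical measured abs-max of a tensor
fullscale = 240.0              # hypothetical representable max of the fp8 format

print(calc_maxabs_scale(xmaxabs, fullscale, backoff=0.5))  # tensor(1.) -- backoff applied
os.environ["INC_FORCE_NAIVE_SCALING"] = "1"
print(calc_maxabs_scale(xmaxabs, fullscale, backoff=0.5))  # tensor(0.5000) -- backoff forced to 1.0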
diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/measure.py b/neural_compressor/torch/algorithms/fp8_quant/_core/measure.py
index 9b1db9d0906..ef9bdc1c7b4 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/_core/measure.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_core/measure.py
@@ -22,7 +22,7 @@
 import time
 from .._quant_common.quant_config import MeasureExclude, QuantMode, ScaleMethod, get_hqt_config, set_hqt_config
 # from ..utils.logger import logger
-from neural_compressor.torch.algorithms.fp8_quant._core.common import INFO_INTERVAL
+from neural_compressor.torch.algorithms.fp8_quant._core.common import INFO_INTERVAL, maybe_dequant_original_fp8_weight
 from .common import *
 from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
 from neural_compressor.torch.algorithms.fp8_quant.model_configs import (
@@ -149,6 +149,10 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=N
             logger.info(f"Patching measure module {name} {mod.__class__}")
             num_info += 1
         set_hqt_config(mod, top_level_config)  # set config in the module, as it consumed by the patched module
+        if mod_type == "dynamic_moe" and hasattr(mod, "num_experts"):
+            # override default number of outputs for dynamic moe
+            mod_types[mod_type].num_outputs = mod.num_experts+1
+            logger.warning(f"Dynamic moe num_outputs set to {mod.num_experts+1}")
         mod_extra_config = (
             init_measure_object(
                 mod,
@@ -167,7 +171,10 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=N
         # logger.info(f"Pacthed module pmod: {pmod}")
         if pmod._mod_extra_config:
             for param_name in pmod._mod_extra_config.params:
+                # if torch.distributed.get_rank() == 0:
+                #     import pdb; pdb.set_trace()
                 param = getattr(pmod, param_name)
+                param = maybe_dequant_original_fp8_weight(pmod.orig_mod, param)
                 if config["measure_on_hpu"]:
                     param = param.to(cur_accelerator.name())
                 pmod._mod_extra_config.params[param_name].measure(param)
diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py b/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py
index 83ae9f7ea27..201172475c5 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py
@@ -33,7 +33,7 @@
 import time
 cur_accelerator = auto_detect_accelerator()
 
-from neural_compressor.torch.algorithms.fp8_quant._core.common import INFO_INTERVAL
+from neural_compressor.torch.algorithms.fp8_quant._core.common import INFO_INTERVAL, maybe_dequant_original_fp8_weight
 
 
 @torch.no_grad()
@@ -78,9 +78,13 @@ def quantize_params(mod, mod_extra_config):
         param = getattr(mod, param_name)
         if param.dtype == torch.float16:
            param = param.to(torch.bfloat16)
+        logger.debug(f"Quantizing parameter {param_name} of module {mod.__class__.__name__}")
+        param = maybe_dequant_original_fp8_weight(mod, param)
         quantized_param = quantizer(param.to(cur_accelerator.name()))
         delattr(mod, param_name)
         setattr(mod, param_name, nn.Parameter(quantized_param))
+        # Note: in case of re-quantizing the fp8 weights, we need to set `updated_fp8_weight` to True
+        mod.updated_fp8_weight = True
         quantized_param = getattr(mod, param_name)
         quantized_param.requires_grad_(False)
         cur_accelerator.synchronize()
@@ -165,27 +169,38 @@ def prepare_model(model, mod_list, measurement, scale_file, scaling_method_name, scale_config, save_file)
             if not config.cfg["fake_quant"] and mod_default_dict[mod_type_str].should_measure_and_quant:
                 quantize_params(mod, mod_extra_config)
-            logger.debug(f"patching module {name}")
+            # logger.debug(f"patching module {name}")
             patch_module(mod, mod_extra_config, mod_default_dict)
             name = origin_name
             patched_modules.append(name)
             patched_module_types.add(type(mod))
             htcore.mark_step()
             logger.debug("Patched module name: %s", name)
+            cur_accelerator.synchronize()
     if save_file:  # cache calculated scales
         save_scales(model, scales_obj, scales_file_format, scale_file + ".npz")
         save_scales(model, scales_obj, scales_file_format, scale_file + ".json")
     logger.debug("Patched module types: %s", patched_module_types)
     logger.debug("Patched modules: %s", patched_modules)
     logger.debug("Total patched modules: %d", len(patched_modules))
+
+    show_mem_info("before move all")
     model = model.to(cur_accelerator.name())
-    for _, mod in model.named_modules():
-        if hasattr(mod, "post_process"):
-            mod.post_process()
-    torch.distributed.barrier()
+    show_mem_info("after move all")
+    postprocess_after_convert_(model)
+    show_mem_info("after post process")
     convert_fp16_to_bf16(model)
+    show_mem_info("after convert_fp16_to_bf16")
     cur_accelerator.synchronize()
+    show_mem_info("after synchronize")
+    torch.distributed.barrier()
+
+
+def postprocess_after_convert_(model):
+    for _, mod in model.named_modules():
+        if hasattr(mod, "post_process"):
+            mod.post_process()
+            # Note: It is very important to synchronize after each post_process to avoid OoM.
+            cur_accelerator.synchronize()
+
+
 def prepare_model_with_dummy_measurement(model, mod_list, scaling_method_name, scale_config):
     """Aim for loading, replace module with patched module for model on meta device.
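Editor's note: the sketch below (not part of the diff) shows the shape of the new post-processing step in prepare_model and why it synchronizes after every post_process call: each patched MoE weight is rebuilt as a transposed contiguous parameter, and flushing the accelerator queue per module lets the old buffers be released before the next module is handled, which is what keeps peak memory down. DummyExpert and the synchronize stand-in are hypothetical so the example runs on CPU-only torch; the real code uses cur_accelerator.synchronize() and htcore.

import torch


class DummyExpert(torch.nn.Module):
    """Hypothetical stand-in for a patched MoE weight holder."""

    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(1, 16, 32), requires_grad=False)

    def post_process(self):
        # Same reshaping the patched modules perform: squeeze + transpose + contiguous.
        self.weight = torch.nn.Parameter(self.weight.squeeze().t().contiguous())


def synchronize():
    # Stand-in for cur_accelerator.synchronize(); on HPU this flushes queued ops so the
    # pre-transpose buffers can actually be freed before the next module is processed.
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def postprocess_after_convert_(model):
    for _, mod in model.named_modules():
        if hasattr(mod, "post_process"):
            mod.post_process()
            synchronize()  # synchronize per module to keep peak memory low


model = torch.nn.ModuleList(DummyExpert() for _ in range(4))
postprocess_after_convert_(model)
print(model[0].weight.shape)  # torch.Size([32, 16])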
diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/ops_quantizer.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/ops_quantizer.py
index 738c12e0ff7..697001ff818 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/ops_quantizer.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/ops_quantizer.py
@@ -20,7 +20,8 @@
 from .scales_method import QuantTensorType
 from ..quant_dequant import DequantOutput, QuantDequant, QuantDequantNone, QuantInput
 from neural_compressor.common import utils as inc_utils
-
+# from neural_compressor.torch.algorithms.fp8_quant.utils import
+from neural_compressor.torch.algorithms.fp8_quant._core.common import maybe_dequant_original_fp8_weight
 
 class BaseOpQuantizer:
     def __init__(self, config, mod, measurement, params, op_type):
@@ -94,9 +95,11 @@ def get_scales_module_config(self):
         input_scales = self.calc_input_scales(num_of_inputs=1)
         output_measurement = self.measurement.outputs[0] if self.measurement is not None else []
         rescaled_weight = self.mod.weight if hasattr(self.mod, 'weight') else None
+        if rescaled_weight is not None:
+            rescaled_weight = maybe_dequant_original_fp8_weight(self.mod, rescaled_weight)
         if self.weight_ich_scale_calc is not None:
             weight_scales_in_ch = self.weight_ich_scale_calc.calc_scales(input_scales[0], QuantTensorType.CONST)
-            rescaled_weight = torch.div(self.mod.weight, weight_scales_in_ch.reshape([1, -1]))
+            rescaled_weight = torch.div(rescaled_weight, weight_scales_in_ch.reshape([1, -1]))
         weights_scales_out_ch = self.weight_och_scale_calc.calc_scales(rescaled_weight, QuantTensorType.CONST)
         params_config = (
             {"weight": weights_scales_out_ch}
diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scale_method_factory.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scale_method_factory.py
index b1517c712b1..91d97fb6fda 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scale_method_factory.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scale_method_factory.py
@@ -16,7 +16,8 @@
 from .round_scales_function import *
 from ..common import get_device_type_for_scales
 from .scales_method import *
-
+from neural_compressor.torch.utils import environ
+from neural_compressor.common.utils import logger
 
 class QuantTensorName(Enum):
     INPUT = auto()
@@ -40,6 +41,9 @@ class ScaleValueType(Enum):
 
 def parse_rounding_method(config, device_for_scales):
     round_method = ScaleIdentity()
+    if environ.INC_FORCE_NAIVE_SCALING:
+        logger.warning_once("Enabled naive scaling")
+        return round_method
     if "single" in config and "hw" in config:
         round_method = ScaleHwAlignedFixed(device_for_scales)
     elif "unit" in config:
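Editor's note: a toy sketch (not part of the diff) of the weight-scale ordering that the ops_quantizer change preserves: the weight is first brought back to high precision (maybe_dequant_original_fp8_weight in the real code, a plain bf16 tensor here), then divided by the per-input-channel scales, and only then are the per-output-channel scales computed. The abs-max per row used below is a simplification standing in for the real scale calculators.

import torch

out_ch, in_ch = 4, 8

# stands in for the dequantized weight returned by maybe_dequant_original_fp8_weight
rescaled_weight = torch.randn(out_ch, in_ch, dtype=torch.bfloat16)

# hypothetical per-input-channel scales, shape [in_ch], from the input-scale calculator
weight_scales_in_ch = torch.rand(in_ch, dtype=torch.bfloat16) + 0.5

# same call shape as in the diff: divide by the per-input-channel scales, broadcast over rows
rescaled_weight = torch.div(rescaled_weight, weight_scales_in_ch.reshape([1, -1]))

# per-output-channel scales are computed on the already-rescaled weight (abs-max per row here)
weights_scales_out_ch = rescaled_weight.abs().amax(dim=1)
print(weights_scales_out_ch.shape)  # torch.Size([4])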
diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
index 245b3f08a16..182f14311fc 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
@@ -418,7 +418,8 @@ def forward_quant(self, input):
     def forward_measure(self, input):
         resolved_input = self.resolve_input(input)
         measure_input((resolved_input,), observer=self._mod_extra_config.inputs)
-        output = torch.matmul(resolved_input, self.weight.transpose(-1, -2))
+        # output = torch.matmul(resolved_input, self.weight.transpose(-1, -2))
+        output = self.orig_mod.quant_method.apply(self.orig_mod, resolved_input)
         measure_output((output,), self._mod_extra_config.outputs)
         if self.reduce_results:
             output = self.collective_func(output)
@@ -474,11 +475,20 @@ def forward_quant(self, input):
 
     def forward_measure(self, input):
         measure_input((input,), observer=self._mod_extra_config.inputs)
-        output = torch.matmul(input, self.weight.transpose(-1, -2))
+        output = self.orig_mod.quant_method.apply(self.orig_mod, input)
         measure_output((output,), self._mod_extra_config.outputs)
+        output, output_bias = self.add_bias(output)
         if self.gather_output:
             output = self.collective_func(output)
-        return self.post_all_reduce(output)
+        return output, output_bias
+
+    def add_bias(self, output):
+        if not self.skip_bias_add:
+            output = output + self.bias if self.bias is not None else output
+            output_bias = None
+        else:
+            output_bias = self.bias
+        return output, output_bias
 
     def post_all_reduce(self, output):
         if not self.skip_bias_add:
@@ -632,7 +642,16 @@ def extra_repr(self) -> str:
             get_current_repr(self, "scale_input", "scale_weight"),
         )
 
-
+class PatchedMoeFP8Matmul(PatchedMoeMatmul):
+    def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
+        super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
+        # if torch.distributed.get_rank() == 0:
+        #     import pdb; pdb.set_trace()
+        # torch.distributed.barrier()
+        # self.block_size = self.orig_mod.block_size
+        # self.scale_inv_fp8 = self.orig_mod.scale_inv_fp8
+        self.get_dequant_weight = self.orig_mod.get_dequant_weight
+
 class PatchedGaudiMixtralSparseMoeBlock(PatchedModuleBase):
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
@@ -724,8 +743,8 @@ def extra_repr(self) -> str:
 class PatchedVllmMixtureOfExpertsOpV1(PatchedModuleBase):
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
-        self.experts_min = self.orig_mod.experts_min
-        self.experts_max = self.orig_mod.experts_max
+        self.experts_min = self.orig_mod.experts_min if hasattr(self.orig_mod, "experts_min") else 0
+        self.experts_max = self.orig_mod.experts_max if hasattr(self.orig_mod, "experts_max") else 7
         if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
             self.forward = self.forward_quant
             self.dynamic_moe_op = get_quantized_func_wrapper(OP_TYPE.DYNAMIC_MOE_FUSED_WEIGHTS, self.scale_format)
@@ -737,11 +756,18 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
                 [mod_extra_config.scale.inputs[x] for x in range(1, self.num_experts+1)],
                 self.scale_format,
             )
-            for i in range(self.num_experts):
-                self.w13_list[i].weight = self.w13_list[i].weight.squeeze().t().contiguous()
-                self.w2_list[i].weight = self.w2_list[i].weight.squeeze().t().contiguous()
+            # if torch.distributed.get_rank() == 0:
+            #     import pdb; pdb.set_trace()
+            # torch.distributed.barrier()
+            self._post_init_for_quant()
+
         elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE):
             self.forward = self.forward_measure
+
+    def _post_init_for_quant(self):
+        for i in range(self.num_experts):
+            self.w13_list[i].weight = self.w13_list[i].weight.squeeze().t().contiguous()
+            self.w2_list[i].weight = self.w2_list[i].weight.squeeze().t().contiguous()
 
     def forward_quant(self,
                       hidden_states,
@@ -813,6 +839,100 @@ def extra_repr(self) -> str:
             f"quant_mode:{quant_mode}, {get_current_repr(self, *member_names)}",
         )
 
+class PatchedVllmMixtureOfExpertsOpFP8(PatchedVllmMixtureOfExpertsOpV1):
+    def _post_init_for_quant(self):
+        pass
+
+    def post_process(self):
+        # if torch.distributed.get_rank() == 0:
+        #     import pdb; pdb.set_trace()
+        # torch.distributed.barrier()
+        for i in range(self.num_experts):
+            self.w13_list[i].weight = torch.nn.Parameter(self.w13_list[i].weight.squeeze().t().contiguous())
+            self.w2_list[i].weight = torch.nn.Parameter(self.w2_list[i].weight.squeeze().t().contiguous())
+            htcore.mark_step()
+
+    def forward_measure(
+        self,
+        x,
+        topk_ids,
+        topk_weights,
+        moe_n_slice=None,
+        n_expert_slice=None,
+        ep_shift=None,
+    ):
+        hidden_states = x
+        measure_input((hidden_states,), observer=self._mod_extra_config.inputs)
+        # FIXME: (Yi) Assume moe_n_slice is 1, remove it?
+        # assert moe_n_slice == 1, f"moe_n_slice is {moe_n_slice}, expected 1"
+        min_expert = self.experts_min
+        max_expert = self.experts_max
+        w13_list_slice = []
+        w2_list_slice = []
+        for j in range(self.num_experts):
+            w13_list_slice.append(self.w13_list[j].get_dequant_weight())
+            w2_list_slice.append(self.w2_list[j].get_dequant_weight())
+
+        output, intermediate_amax = torch.ops.hpu.mixture_of_experts.fp8_measurement_fused_weights(
+            hidden_states=x,
+            expert_routing_table=topk_ids.to(torch.int64),
+            router_weights=topk_weights.to(x.dtype),
+            w12=w13_list_slice,
+            w3=w2_list_slice,
+            permuted_weights=True,
+            activation="silu",
+            experts_min=min_expert,
+            experts_max=max_expert,
+            measurement_mode=True,  # <=============
+        )
+        output_measure_list = [output]
+        # if torch.distributed.get_rank() == 0:
+        #     import pdb; pdb.set_trace()
+        # torch.distributed.barrier()
+        for i in range(self.num_experts):
+            output_measure_list.append(intermediate_amax[i])
+        measure_output(output_measure_list, self._mod_extra_config.outputs)
+        return output
+
+    def forward_quant(
+        self,
+        x,
+        topk_ids,
+        topk_weights,
+        moe_n_slice=None,
+        n_expert_slice=None,
+        ep_shift=None,
+    ):
+        hidden_states = x
+        expert_routing_table = topk_ids.to(torch.int64)
+        router_weights = topk_weights.to(x.dtype)
+        permuted_weights = True
+        activation = "silu"
+        # if torch.distributed.get_rank() == 0:
+        #     import pdb; pdb.set_trace()
+        # torch.distributed.barrier()
+        experts_range = range(self.num_experts)
+        w1_list = [self.w13_list[i].weight for i in experts_range]
+        w2_list = [self.w2_list[i].weight for i in experts_range]
+        scale_w1 = [self.w13_list[i].scale_weight for i in experts_range]
+        scale_w2 = [self.w2_list[i].scale_weight for i in experts_range]
+        qinput = self.quant_input(hidden_states)
+        output = self.dynamic_moe_op(
+            hidden_states=qinput,
+            expert_routing_table=expert_routing_table,
+            router_weights=router_weights,
+            w12=w1_list,
+            w3=w2_list,
+            d_scale_w12=scale_w1,
+            d_scale_w3=scale_w2,
+            d_scale_hidden_states=self.scale_input,
+            d_scale_intermediate_hidden_states=self.scale_intermediate,
+            permuted_weights=permuted_weights,
+            activation=activation,
+            experts_min=self.experts_min,
+            experts_max=self.experts_max,
+        )
+        return output
 
 class PatchedVllmMixtureOfExpertsOpV2(PatchedVllmMixtureOfExpertsOpV1):
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
@@ -954,10 +1074,10 @@ def forward_qdq(self, input, *args, **kwargs):
         output_cache = self.orig_mod(qinput, *args, **kwargs)
         return output_cache
 
-    # def forward_quant(self, input, *args, **kwargs):
-    #     qinput = self.quant_input(input)
-    #     output_cache = self.orig_mod(qinput, *args, **kwargs)
-    #     return self.dequant_output(output_cache)
+    def forward_quant(self, input, *args, **kwargs):
+        qinput = self.quant_input(input)
+        output_cache = self.orig_mod(qinput, *args, **kwargs)
+        return self.dequant_output(output_cache)
 
     def forward_measure(self, input, *args, **kwargs):
         measure_input((input, ), self._mod_extra_config.inputs)
@@ -965,22 +1085,12 @@ def forward_measure(self, input, *args, **kwargs):
         measure_output((output_cache, ), self._mod_extra_config.outputs)
         return output_cache
 
-    # def fetch_from_cache(self, cache, blocks, permutations=None):
-    #     # quant_cache = self.quant_input(cache)
-    #     quant_cache = cache
-    #     if permutations:
-    #         output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks, permutations)
-    #         for i in range(len(output_cache)):
-    #             output_cache[i] = self.dequant_output(output_cache[i])
-    #         return output_cache
-    #     output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks)
-    #     return self.dequant_output(output_cache)
-
-    def forward_quant(self, input, *args, **kwargs):
-        qinput = self.quant_input(input)
-        return self.orig_mod(qinput, *args, **kwargs)
-
-    def fetch_from_cache(self, quant_cache, blocks, permutations=None):
+    def fetch_from_cache(self, cache, blocks, permutations=None):
+        # TODO: Remove this workaround in next release [SW-221595]
+        if cache.dtype != self.lp_dtype:
+            quant_cache = self.quant_input(cache)
+        else:
+            quant_cache = cache
         if permutations:
             output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks, permutations)
             for i in range(len(output_cache)):
@@ -988,7 +1098,7 @@ def fetch_from_cache(self, quant_cache, blocks, permutations=None):
                 return output_cache
         output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks)
         return self.dequant_output(output_cache)
-
+
     def extra_repr(self) -> str:
         return f"PatchedVLLMKVCache"
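Editor's note: a minimal sketch (not part of the diff) of the fetch_from_cache workaround tagged SW-221595 above: the incoming cache is only quantized on the fly when it is not already in the low-precision dtype, so fp8 caches pass through untouched. quant_input here is a hypothetical per-tensor cast; the patched module uses its configured quantizer and self.lp_dtype.

import torch

lp_dtype = torch.float8_e4m3fn


def quant_input(cache: torch.Tensor) -> torch.Tensor:
    # hypothetical stand-in for the patched module's input quantizer
    return cache.to(lp_dtype)


def maybe_quant_cache(cache: torch.Tensor) -> torch.Tensor:
    # mirrors the dtype guard added to fetch_from_cache
    if cache.dtype != lp_dtype:
        return quant_input(cache)
    return cache


bf16_cache = torch.randn(2, 4, dtype=torch.bfloat16)
fp8_cache = bf16_cache.to(lp_dtype)
print(maybe_quant_cache(bf16_cache).dtype)  # torch.float8_e4m3fn (quantized on the fly)
print(maybe_quant_cache(fp8_cache).dtype)   # torch.float8_e4m3fn (already low precision, passed through)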
diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py
index e4ebdefaa11..88a703e7c6f 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py
@@ -24,7 +24,7 @@
 import torch
 
 from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator, INCAcceleratorType
-from ..utils.logger import logger
+from neural_compressor.torch.utils import logger
 
 try:
     world_size = torch.distributed.get_world_size()
@@ -225,6 +225,26 @@ def parse(custom_config: Mapping[str, str]) -> Fp8cfg:
             else:
                 measured_global_config[keys] = custom_config[keys]
 
+    INC_MEASUREMENT_DUMP_PATH_PREFIX = os.getenv("INC_MEASUREMENT_DUMP_PATH_PREFIX", None)
+    if INC_MEASUREMENT_DUMP_PATH_PREFIX is not None:
+        dump_stats_path = os.path.join(INC_MEASUREMENT_DUMP_PATH_PREFIX, measured_global_config["dump_stats_path"])
+        measured_global_config["dump_stats_path"] = dump_stats_path
+        logger.info(
+            f"INC_MEASUREMENT_DUMP_PATH_PREFIX is set to {INC_MEASUREMENT_DUMP_PATH_PREFIX}, dump_stats_path is set to {dump_stats_path}"
+        )
+        # check if the directory exists
+
+        dir_path = os.path.dirname(measured_global_config["dump_stats_path"])
+        abs_path = os.path.abspath(dir_path)
+        if not (os.path.exists(dir_path) or os.path.exists(abs_path)):
+            raise ValueError(
+                (
+                    f"The measurement dump directory '{dir_path}' does not exist,"
+                    f" the path is determined by the environment variable INC_MEASUREMENT_DUMP_PATH_PREFIX"
+                    f" and the dump_stats_path in the quantization config file."
+                )
+            )
+
     # If seperate_measure_files is True (default value), then it is assumed that there are multiple distinct measure and scale files
     # and they are stored in / loaded from paths with the correct index as a suffix. Else, only one is searched for.
     measured_global_config["local_rank"] = (
diff --git a/neural_compressor/torch/utils/environ.py b/neural_compressor/torch/utils/environ.py
index e60dcc7ad88..7c40af92710 100644
--- a/neural_compressor/torch/utils/environ.py
+++ b/neural_compressor/torch/utils/environ.py
@@ -32,6 +32,10 @@
 
 world_size = int(os.getenv("WORLD_SIZE", "-1"))
 
+
+INC_FORCE_NAIVE_SCALING = os.getenv("INC_FORCE_NAIVE_SCALING", "0").lower() in ["1", "true"]
+
+
 ################ Check imported sys.module first to decide behavior #################
 def is_ipex_imported() -> bool:
     """Check whether intel_extension_for_pytorch is imported."""
@@ -235,15 +239,18 @@ def is_tbb_available():  # pragma: no cover
         return False
     return True
 
-def show_mem_info(loglevel="info"):
+def show_mem_info(msg="", loglevel="info"):
     hpu_mem_mb = get_used_hpu_mem_MB()
     from neural_compressor.common.utils import logger
     show_fn = getattr(logger, loglevel)
     rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else -1
-    show_fn(f"[Rank {rank}] Used HPU memory: {hpu_mem_mb // 1000} GB {hpu_mem_mb % 1000} MB")
+    # show_fn(f"[Rank {rank}] Used HPU memory: {hpu_mem_mb // 1000} GB {hpu_mem_mb % 1000} MB")
     cpu_mem_mb = get_used_cpu_mem_MB()
-    show_fn(f"[Rank {rank}] Used CPU memory: {cpu_mem_mb // 1000} GB {cpu_mem_mb % 1000} MB")
-
+    # show_fn(f"[Rank {rank}] Used CPU memory: {cpu_mem_mb // 1000} GB {cpu_mem_mb % 1000} MB")
+    show_fn(
+        f"[Rank {rank}] {msg}, HPU: {hpu_mem_mb // 1000} GB {hpu_mem_mb % 1000:.2f} MB; CPU: {cpu_mem_mb // 1000} GB {cpu_mem_mb % 1000:.2f} MB"
+    )
+
 
 def get_used_hpu_mem_MB():
     """Get HPU used memory: MiB."""
diff --git a/setup.py b/setup.py
index d0be2291d6f..5d2fe1f9f4a 100644
--- a/setup.py
+++ b/setup.py
@@ -75,7 +75,12 @@ def get_build_version():
             ],
         ),
         "package_data": {"": ["*.json"]},
-        "install_requires": fetch_requirements("requirements_pt.txt"),
+        # FIXME: (Yi) force install neural_compressor_pt
+        # "install_requires": fetch_requirements("requirements_pt.txt"),
+        "install_requires": fetch_requirements("requirements.txt"),
+        "extras_require": {
+            "pt": fetch_requirements("requirements_pt.txt"),
+        }
     },
     # 3.x tf binary build config, pip install neural-compressor-tf, install 3.x TensorFlow API.
     "neural_compressor_tf": {
@@ -102,7 +107,9 @@
     # https://github.com/pytorch/pytorch/pull/114662
     ext_modules = []
     cmdclass = {}
-
+
+
+
     if "pt" in sys.argv:
         sys.argv.remove("pt")
         cfg_key = "neural_compressor_pt"
@@ -110,7 +117,9 @@
     if "tf" in sys.argv:
         sys.argv.remove("tf")
         cfg_key = "neural_compressor_tf"
-
+    # FIXME: (Yi) force install neural_compressor_pt
+    print(f"Forcing install neural_compressor_pt")
+    cfg_key = "neural_compressor_pt"
     project_name = PKG_INSTALL_CFG[cfg_key].get("project_name")
     include_packages = PKG_INSTALL_CFG[cfg_key].get("include_packages") or {}
     package_data = PKG_INSTALL_CFG[cfg_key].get("package_data") or {}
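Editor's note: a standalone sketch (not part of the diff) of the measurement dump-path handling added to quant_config.py: when INC_MEASUREMENT_DUMP_PATH_PREFIX is set, it is prepended to dump_stats_path and the resulting directory is validated up front. The resolve_dump_stats_path helper and the example paths are hypothetical; the real logic runs inside the parse() function shown above and mutates measured_global_config.

import os


def resolve_dump_stats_path(dump_stats_path: str) -> str:
    prefix = os.getenv("INC_MEASUREMENT_DUMP_PATH_PREFIX", None)
    if prefix is not None:
        dump_stats_path = os.path.join(prefix, dump_stats_path)
        dir_path = os.path.dirname(dump_stats_path)
        if not (os.path.exists(dir_path) or os.path.exists(os.path.abspath(dir_path))):
            # fail early, before measurement starts, instead of at dump time
            raise ValueError(f"The measurement dump directory '{dir_path}' does not exist")
    return dump_stats_path


os.environ["INC_MEASUREMENT_DUMP_PATH_PREFIX"] = "/tmp"
print(resolve_dump_stats_path("measure"))  # /tmp/measure (raises ValueError if the prefixed directory is missing)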