R1 woq #2148


Draft: wants to merge 28 commits into base branch dev/ds_r1
Changes from all commits
6 changes: 3 additions & 3 deletions neural_compressor/common/utils/__init__.py
@@ -28,16 +28,16 @@
VLLM_TP_SIZE = int(os.getenv("VLLM_TP_SIZE", "8"))
VLLM_EP_SIZE = int(os.getenv("VLLM_EP_SIZE", VLLM_TP_SIZE))
NUM_EXPERTS_PER_EP_RANK = DEEPSEEK_EXPERTS // VLLM_EP_SIZE # 32
NUM_EXPERTS_GROUPS = 8
NUM_EXPERTS_PER_GROUP_PER_RANK = NUM_EXPERTS_PER_EP_RANK // NUM_EXPERTS_GROUPS # 4
VLLM_MOE_N_SLICE = int(os.getenv("VLLM_MOE_N_SLICE", 8))
NUM_EXPERTS_PER_GROUP_PER_RANK = NUM_EXPERTS_PER_EP_RANK // VLLM_MOE_N_SLICE # 4
FUSED_MOE_EXPERTS = NUM_EXPERTS_PER_GROUP_PER_RANK # 4

logger.warning_once(
    (
        f"INC uses VLLM_TP_SIZE={VLLM_TP_SIZE},\n"
        f"VLLM_EP_SIZE={VLLM_EP_SIZE},\n"
        f"NUM_EXPERTS_PER_EP_RANK={NUM_EXPERTS_PER_EP_RANK},\n"
        f"NUM_EXPERTS_GROUPS={NUM_EXPERTS_GROUPS},\n"
        f"VLLM_MOE_N_SLICE={VLLM_MOE_N_SLICE},\n"
        f"NUM_EXPERTS_PER_GROUP_PER_RANK={NUM_EXPERTS_PER_GROUP_PER_RANK},\n"
        f"FUSED_MOE_EXPERTS={FUSED_MOE_EXPERTS}"
    )
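
For reference, the expert-slicing arithmetic above can be read as a standalone sketch. DEEPSEEK_EXPERTS = 256 is an assumption here (consistent with the "# 32" comment), and the defaults mirror the hunk:

import os

DEEPSEEK_EXPERTS = 256  # assumed; matches DEEPSEEK_EXPERTS // 8 == 32 in the comment above
VLLM_TP_SIZE = int(os.getenv("VLLM_TP_SIZE", "8"))           # tensor-parallel world size
VLLM_EP_SIZE = int(os.getenv("VLLM_EP_SIZE", VLLM_TP_SIZE))  # expert-parallel world size
VLLM_MOE_N_SLICE = int(os.getenv("VLLM_MOE_N_SLICE", 8))     # MoE slices per EP rank

NUM_EXPERTS_PER_EP_RANK = DEEPSEEK_EXPERTS // VLLM_EP_SIZE                     # 256 // 8 = 32
NUM_EXPERTS_PER_GROUP_PER_RANK = NUM_EXPERTS_PER_EP_RANK // VLLM_MOE_N_SLICE   # 32 // 8 = 4
FUSED_MOE_EXPERTS = NUM_EXPERTS_PER_GROUP_PER_RANK                             # each fused MoE op covers 4 experts

print(NUM_EXPERTS_PER_EP_RANK, NUM_EXPERTS_PER_GROUP_PER_RANK, FUSED_MOE_EXPERTS)  # 32 4 4
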
22 changes: 16 additions & 6 deletions neural_compressor/torch/algorithms/fp8_quant/_core/common.py
@@ -42,6 +42,14 @@

INFO_INTERVAL = 30 # seconds

def maybe_dequant_original_fp8_weight(mod: torch.nn.Module, param: torch.Tensor):
    if param.dtype in [torch.float8_e4m3fn]:
        if hasattr(mod, "get_dequant_weights_func"):
            dequant_weights_func = mod.get_dequant_weights_func()
            if dequant_weights_func is not None:
                param = dequant_weights_func(mod)
    return param

_mod_types = {
    "linear": ModuleType(1, ["weight"], 1, False),
    "matmul": ModuleType(2, [], 1, False),
@@ -222,15 +230,17 @@ def convert_scales_to_tensors_dict(scales_obj, scales_file_format, hp_dtype, dev
"Softmax": ModuleInfo("softmax", PatchedSoftmax),
"ModuleFusedSDPA": ModuleInfo("fused_sdpa", PatchedModuleFusedSDPA),
"MoeMatmul": ModuleInfo("linear", PatchedMoeMatmul),
"MoeFP8Matmul": ModuleInfo("linear", PatchedMoeFP8Matmul),
"ReplicatedLinear": ModuleInfo("linear", PatchedReplicatedLinear),
"VllmMixtureOfExpertsOpFP8": ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpFP8),
# FIXME (Yi) revert change
"FusedMoE": ModuleInfo("linear", PatchedMixtralMoE, False),
"GaudiMixtralSparseMoeBlock": ModuleInfo("dynamic_moe", PatchedGaudiMixtralSparseMoeBlock),
"VllmMixtureOfExpertsOp": (
ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpV2)
if os.getenv("LOW_CPU_MEM", "0") == "1"
else ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpV1)
),
# "GaudiMixtralSparseMoeBlock": ModuleInfo("dynamic_moe", PatchedGaudiMixtralSparseMoeBlock),
# "VllmMixtureOfExpertsOp": (
# ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpV2)
# if os.getenv("LOW_CPU_MEM", "0") == "1"
# else ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpV1)
# ),
}


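The new maybe_dequant_original_fp8_weight helper covers modules whose checkpoint weights are already fp8 (for example the MoeFP8Matmul and VllmMixtureOfExpertsOpFP8 entries registered above): if a module exposes get_dequant_weights_func, INC dequantizes the weight before measuring or re-quantizing it. A minimal sketch using the helper defined above, with a hypothetical ToyFP8Linear module that is not part of this PR:

import torch

class ToyFP8Linear(torch.nn.Module):
    """Hypothetical module holding an fp8 weight plus its dequantization scale."""
    def __init__(self, weight_fp8, scale):
        super().__init__()
        self.weight = torch.nn.Parameter(weight_fp8, requires_grad=False)
        self.scale = scale

    def get_dequant_weights_func(self):
        # The returned callable takes the module and yields a bf16 copy of the weight.
        return lambda mod: (mod.weight.to(torch.float32) * mod.scale).to(torch.bfloat16)

w_fp8 = (torch.randn(4, 4) * 0.1).to(torch.float8_e4m3fn)
mod = ToyFP8Linear(w_fp8, torch.tensor(2.0))
hp_weight = maybe_dequant_original_fp8_weight(mod, mod.weight)  # bf16 tensor, no longer fp8
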
@@ -17,6 +17,9 @@
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator, INCAcceleratorType
cur_accelerator = auto_detect_accelerator()

from neural_compressor.torch.utils import environ
from neural_compressor.common.utils import logger

descale_fcn = lambda x, scale: torch.mul(x, scale)
scale_fcn = lambda x, scale: torch.div(x, scale)
cast_fcn = lambda x, dtype: x.to(dtype=dtype)
@@ -106,6 +109,9 @@ def get_fp8_hw_alligned_scales(dtype, device):
}

def calc_maxabs_scale(xmaxabs, fullscale, backoff=1):
    if environ.INC_FORCE_NAIVE_SCALING:
        backoff = 1.0
        logger.warning_once(f"Enabled naive scaling, backoff is set to {backoff}")
    scale = xmaxabs / (fullscale * backoff)
    return scale

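To make the backoff concrete (the numbers below are illustrative, not taken from this PR): with a measured max-abs of 12.0 and the fp8-e4m3 full scale of 448.0, a backoff of 0.5 leaves headroom by mapping the observed maximum to only half of the fp8 range, while INC_FORCE_NAIVE_SCALING forces backoff to 1.0 and uses the full range:

FP8_E4M3_FULLSCALE = 448.0

def calc_maxabs_scale(xmaxabs, fullscale, backoff=1):
    return xmaxabs / (fullscale * backoff)

print(calc_maxabs_scale(12.0, FP8_E4M3_FULLSCALE, backoff=0.5))  # ~0.0536: the max value quantizes to ~224
print(calc_maxabs_scale(12.0, FP8_E4M3_FULLSCALE, backoff=1.0))  # ~0.0268: the max value quantizes to ~448
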
@@ -22,7 +22,7 @@
import time
from .._quant_common.quant_config import MeasureExclude, QuantMode, ScaleMethod, get_hqt_config, set_hqt_config
# from ..utils.logger import logger
from neural_compressor.torch.algorithms.fp8_quant._core.common import INFO_INTERVAL
from neural_compressor.torch.algorithms.fp8_quant._core.common import INFO_INTERVAL, maybe_dequant_original_fp8_weight
from .common import *
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
from neural_compressor.torch.algorithms.fp8_quant.model_configs import (
@@ -149,6 +149,10 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=N
logger.info(f"Patching measure module {name} {mod.__class__}")
num_info += 1
set_hqt_config(mod, top_level_config)  # set config in the module, as it is consumed by the patched module
if mod_type == "dynamic_moe" and hasattr(mod, "num_experts"):
# override default number of outputs for dynamic moe
mod_types[mod_type].num_outputs = mod.num_experts+1
logger.warning(f"Dynamic moe num_outputs set to {mod.num_experts+1}")
mod_extra_config = (
init_measure_object(
mod,
@@ -167,7 +171,10 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=N
# logger.info(f"Pacthed module pmod: {pmod}")
if pmod._mod_extra_config:
for param_name in pmod._mod_extra_config.params:
# if torch.distributed.get_rank() == 0:
# import pdb; pdb.set_trace()
param = getattr(pmod, param_name)
param = maybe_dequant_original_fp8_weight(pmod.orig_mod, param)
if config["measure_on_hpu"]:
param = param.to(cur_accelerator.name())
pmod._mod_extra_config.params[param_name].measure(param)
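The num_outputs override in the measurement loop accounts for dynamic-MoE modules reporting more than one output tensor. The counting below is a toy illustration, under the assumption that the patched op exposes the combined result plus one tensor per local expert:

import torch

num_experts = 4                                      # e.g. mod.num_experts
expert_outputs = [torch.randn(2, 8) for _ in range(num_experts)]
combined = torch.stack(expert_outputs).sum(dim=0)    # routed/combined hidden states
measured_outputs = [combined, *expert_outputs]       # what the output observers would see
assert len(measured_outputs) == num_experts + 1      # matches num_outputs = mod.num_experts + 1
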
27 changes: 21 additions & 6 deletions neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py
@@ -33,7 +33,7 @@
import time
cur_accelerator = auto_detect_accelerator()

from neural_compressor.torch.algorithms.fp8_quant._core.common import INFO_INTERVAL
from neural_compressor.torch.algorithms.fp8_quant._core.common import INFO_INTERVAL, maybe_dequant_original_fp8_weight


@torch.no_grad()
@@ -78,9 +78,13 @@ def quantize_params(mod, mod_extra_config):
param = getattr(mod, param_name)
if param.dtype == torch.float16:
param = param.to(torch.bfloat16)
logger.debug(f"Quantizing parameter {param_name} of module {mod.__class__.__name__}")
param = maybe_dequant_original_fp8_weight(mod, param)
quantized_param = quantizer(param.to(cur_accelerator.name()))
delattr(mod, param_name)
setattr(mod, param_name, nn.Parameter(quantized_param))
# Note: when re-quantizing fp8 weights, we need to set `updated_fp8_weight` to True
mod.updated_fp8_weight = True
quantized_param = getattr(mod, param_name)
quantized_param.requires_grad_(False)
cur_accelerator.synchronize()
@@ -165,27 +169,38 @@ def prepare_model(model, mod_list, measurement, scale_file, scaling_method_name,
scale_config, save_file)
if not config.cfg["fake_quant"] and mod_default_dict[mod_type_str].should_measure_and_quant:
quantize_params(mod, mod_extra_config)
logger.debug(f"patching module {name}")
# logger.debug(f"patching module {name}")
patch_module(mod, mod_extra_config, mod_default_dict)
name = origin_name
patched_modules.append(name)
patched_module_types.add(type(mod))
htcore.mark_step()
logger.debug("Patched module name: %s", name)
cur_accelerator.synchronize()
if save_file: # cache calculated scales
save_scales(model, scales_obj, scales_file_format, scale_file + ".npz")
save_scales(model, scales_obj, scales_file_format, scale_file + ".json")
logger.debug("Patched module types: %s", patched_module_types)
logger.debug("Patched modules: %s", patched_modules)
logger.debug("Total patched modules: %d", len(patched_modules))

show_mem_info("before move all")
model = model.to(cur_accelerator.name())
for _, mod in model.named_modules():
if hasattr(mod, "post_process"):
mod.post_process()
torch.distributed.barrier()
show_mem_info("after move all")
postporcess_after_convert_(model)
show_mem_info("after post process")
convert_fp16_to_bf16(model)
show_mem_info("after convert_fp16_to_bf16")
cur_accelerator.synchronize()
show_mem_info("after synchronize")
torch.distributed.barrier()

def postporcess_after_convert_(model):
    for _, mod in model.named_modules():
        if hasattr(mod, "post_process"):
            mod.post_process()
            # Note: It is very important to synchronize after each post_process to avoid OoM.
            cur_accelerator.synchronize()

def prepare_model_with_dummy_measurement(model, mod_list, scaling_method_name, scale_config):
"""Aim for loading, replace module with patched module for model on meta device.
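Putting the quantize_params changes together, a simplified toy of the re-quantization path for a weight that already sits in fp8 (a sketch of the flow, not the PR's exact code):

import torch

orig_scale = torch.tensor(0.02)
w_fp8 = (torch.randn(4, 4) / orig_scale).clamp(-448, 448).to(torch.float8_e4m3fn)  # pretend checkpoint weight

w_hp = w_fp8.to(torch.bfloat16) * orig_scale     # what maybe_dequant_original_fp8_weight recovers
new_scale = w_hp.abs().max().float() / 448.0     # measurement-derived scale (no backoff here)
w_requant = (w_hp / new_scale).to(torch.float8_e4m3fn)

updated_fp8_weight = True  # flag set on the module so downstream code knows `weight` now uses the new scale
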
@@ -20,7 +20,8 @@
from .scales_method import QuantTensorType
from ..quant_dequant import DequantOutput, QuantDequant, QuantDequantNone, QuantInput
from neural_compressor.common import utils as inc_utils

# from neural_compressor.torch.algorithms.fp8_quant.utils import
from neural_compressor.torch.algorithms.fp8_quant._core.common import maybe_dequant_original_fp8_weight
class BaseOpQuantizer:

def __init__(self, config, mod, measurement, params, op_type):
@@ -94,9 +95,11 @@ def get_scales_module_config(self):
input_scales = self.calc_input_scales(num_of_inputs=1)
output_measurement = self.measurement.outputs[0] if self.measurement is not None else []
rescaled_weight = self.mod.weight if hasattr(self.mod, 'weight') else None
if rescaled_weight is not None:
rescaled_weight = maybe_dequant_original_fp8_weight(self.mod, rescaled_weight)
if self.weight_ich_scale_calc is not None:
weight_scales_in_ch = self.weight_ich_scale_calc.calc_scales(input_scales[0], QuantTensorType.CONST)
rescaled_weight = torch.div(self.mod.weight, weight_scales_in_ch.reshape([1, -1]))
rescaled_weight = torch.div(rescaled_weight, weight_scales_in_ch.reshape([1, -1]))
weights_scales_out_ch = self.weight_och_scale_calc.calc_scales(rescaled_weight, QuantTensorType.CONST)
params_config = (
{"weight": weights_scales_out_ch}
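The fix above matters when per-input-channel scaling is enabled: previously the division read self.mod.weight directly, which for fp8 checkpoints is still the raw fp8 tensor, whereas now it reuses the (possibly dequantized) rescaled_weight. An illustrative sketch of the two-stage scaling, with stand-ins for the scale calculators:

import torch

weight = torch.randn(8, 16)                   # [out_channels, in_channels], already dequantized
weight_scales_in_ch = torch.rand(16) + 0.5    # stand-in for weight_ich_scale_calc.calc_scales(...)

rescaled_weight = torch.div(weight, weight_scales_in_ch.reshape([1, -1]))
weight_scales_out_ch = rescaled_weight.abs().amax(dim=1) / 448.0   # stand-in for weight_och_scale_calc
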
@@ -16,7 +16,8 @@
from .round_scales_function import *
from ..common import get_device_type_for_scales
from .scales_method import *

from neural_compressor.torch.utils import environ
from neural_compressor.common.utils import logger

class QuantTensorName(Enum):
INPUT = auto()
@@ -40,6 +41,9 @@ class ScaleValueType(Enum):

def parse_rounding_method(config, device_for_scales):
    round_method = ScaleIdentity()
    if environ.INC_FORCE_NAIVE_SCALING:
        logger.warning_once("Enabled naive scaling")
        return round_method
    if "single" in config and "hw" in config:
        round_method = ScaleHwAlignedFixed(device_for_scales)
    elif "unit" in config:
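With INC_FORCE_NAIVE_SCALING set, parse_rounding_method returns ScaleIdentity and skips hardware-aligned scale rounding. As intuition only (the real aligned values come from get_fp8_hw_alligned_scales; the power-of-two snapping below is just an assumption for illustration):

import math

raw_scale = 0.0268                               # e.g. a calc_maxabs_scale result
hw_aligned = 2.0 ** round(math.log2(raw_scale))  # snap to the nearest power of two
print(raw_scale, hw_aligned)                     # naive keeps 0.0268; "aligned" gives 0.03125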