
Commit f6cf92e

[quant][bugfix] fix deepseek quant bug (#478)
see #465

Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: zzzzwwjj <1183291235@qq.com>
1 parent 579d858 commit f6cf92e

File tree

2 files changed: +7 -108 lines changed

vllm_ascend/models/deepseek_v2.py

Lines changed: 3 additions & 108 deletions
```diff
@@ -24,7 +24,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only DeepseekV2/DeepseekV3 model."""
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import List, Optional, Union
 
 import torch
 from torch import nn
@@ -40,14 +40,12 @@
 from vllm.model_executor.layers.sampler import get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.deepseek_v2 import (  # noqa
     DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
     DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2MoE)
 from vllm.model_executor.models.utils import (
-    PPMissingLayer, is_pp_missing_parameter,
-    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
+    PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
+    maybe_prefix)
 from vllm.sequence import IntermediateTensors
 
 
@@ -282,109 +280,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts)
-
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
-            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
-            if spec_layer is not None:
-                continue  # skip spec decode layers for main model
-
-            # w8a8 weight from modelslim need flatten before load_weight
-            if "scale" in name or "offset" in name:
-                loaded_weight = loaded_weight.flatten()
-
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                # Skip non-stacked layers and experts (experts handled below).
-                if weight_name not in name:
-                    continue
-                # We have mlp.experts[0].gate_proj in the checkpoint.
-                # Since we handle the experts below in expert_params_mapping,
-                # we need to skip here BEFORE we update the name, otherwise
-                # name will be updated to mlp.experts[0].gate_up_proj, which
-                # will then be updated below in expert_params_mapping
-                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                if (("mlp.experts." in name) and name not in params_dict):
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-
-                if is_pp_missing_parameter(name, self):
-                    continue
-
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-
-                    if is_pp_missing_parameter(name, self):
-                        continue
-
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param,
-                                  loaded_weight,
-                                  name,
-                                  shard_id=shard_id,
-                                  expert_id=expert_id)
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-
-                    # Remapping the name of FP8 kv-scale.
-                    name = maybe_remap_kv_scale_name(name, params_dict)
-                    if name is None:
-                        continue
-
-                    if is_pp_missing_parameter(name, self):
-                        continue
-
-                    param = params_dict[name]
-                    weight_loader = getattr(param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
-
 
 class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
     pass
-
-
-def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
-                                        weight_name: str) -> Optional[int]:
-    if hasattr(config, "num_nextn_predict_layers") and (
-            config.num_nextn_predict_layers > 0):
-        layer_idx = config.num_hidden_layers
-        for i in range(config.num_nextn_predict_layers):
-            if weight_name.startswith(f"model.layers.{layer_idx+i}."):
-                return layer_idx + i
-    return None
```
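With the override deleted, the custom class no longer carries its own copy of the weight-loading loop; ordinary attribute lookup falls through to the upstream vLLM loader instead. A minimal, illustrative-only sketch of that fallback, assuming CustomDeepseekV2ForCausalLM subclasses the upstream DeepseekV2ForCausalLM (the class names below are stand-ins, not the real vLLM / vllm-ascend classes):

```python
# Illustrative-only sketch: stand-in classes showing the method-resolution
# fallback once the subclass stops overriding load_weights.
class UpstreamDeepseekV2ForCausalLM:
    def load_weights(self, weights):
        # Stand-in for the upstream vLLM loader: returns the names it loaded.
        return {name for name, _ in weights}


class CustomDeepseekV2ForCausalLM(UpstreamDeepseekV2ForCausalLM):
    # After this commit the subclass defines no load_weights of its own,
    # so calls resolve to the inherited upstream implementation.
    pass


model = CustomDeepseekV2ForCausalLM()
print(model.load_weights([("lm_head.weight", None)]))  # {'lm_head.weight'}
```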

vllm_ascend/quantization/quant_config.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -374,3 +374,7 @@ def apply(
                                          num_expert_group,
                                          custom_routing_function, scoring_func,
                                          e_score_correction_bias)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if hasattr(self.quant_method, "process_weights_after_loading"):
+            self.quant_method.process_weights_after_loading(layer)
```
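The added method forwards vLLM's post-load hook to the wrapped Ascend quant backend, and only when that backend actually implements it. A minimal sketch of this delegation pattern; the wrapper and backend classes below are hypothetical stand-ins, not the real vllm-ascend quantization classes:

```python
import torch


class FakeW8A8QuantMethod:
    """Hypothetical backend that post-processes weights after loading."""

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Stand-in for whatever repacking/flattening the real backend does.
        layer.weights_processed = True


class FusedMoEMethodWrapper:
    """Hypothetical wrapper mirroring the delegation added in this commit."""

    def __init__(self, quant_method) -> None:
        self.quant_method = quant_method

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Forward the hook only if the backend defines it; backends without
        # the hook are left untouched.
        if hasattr(self.quant_method, "process_weights_after_loading"):
            self.quant_method.process_weights_after_loading(layer)


layer = torch.nn.Linear(8, 8)
FusedMoEMethodWrapper(FakeW8A8QuantMethod()).process_weights_after_loading(layer)
assert layer.weights_processed  # the hook was forwarded to the backend
```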
