[Refactor] Collect scattered w8a8-dynamic quantization operations #1111

Closed · wants to merge 2 commits
2 changes: 2 additions & 0 deletions .github/workflows/vllm_ascend_test.yaml
@@ -188,6 +188,7 @@ jobs:
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py
fi

@@ -218,5 +219,6 @@ jobs:
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py
fi
19 changes: 18 additions & 1 deletion tests/multicard/test_offline_inference_distributed.py
@@ -23,7 +23,7 @@
import os
from unittest.mock import patch

import vllm # noqa: F401
from modelscope import snapshot_download # type: ignore
from vllm import SamplingParams

from tests.conftest import VllmRunner
@@ -95,3 +95,20 @@ def test_models_distributed_DeepSeek_dbo():
distributed_executor_backend="mp",
) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)


def test_models_distributed_DeepSeek_W8A8():
example_prompts = [
"Hello, my name is",
]
max_tokens = 5

with VllmRunner(
snapshot_download("vllm-ascend/DeepSeek-V2-Lite-W8A8"),
max_model_len=8192,
enforce_eager=True,
dtype="auto",
tensor_parallel_size=4,
quantization="ascend",
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
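
For reference, the checkpoint exercised by this test can be loaded the same way through vLLM's public `LLM` API. The sketch below is not part of this diff: the model name, tensor-parallel size, and `quantization="ascend"` flag are copied from the test above, and the rest is ordinary vLLM usage.

from modelscope import snapshot_download
from vllm import LLM, SamplingParams

# Fetch the W8A8-quantized DeepSeek-V2-Lite checkpoint used by the CI test.
model_path = snapshot_download("vllm-ascend/DeepSeek-V2-Lite-W8A8")

llm = LLM(
    model=model_path,
    quantization="ascend",  # route linear layers through vllm-ascend quant methods
    tensor_parallel_size=4,
    max_model_len=8192,
    enforce_eager=True,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)
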
110 changes: 4 additions & 106 deletions vllm_ascend/models/deepseek_dbo.py
@@ -29,7 +29,7 @@

import torch
import torch.distributed as dist
import torch_npu
import torch_npu # noqa: F401
import vllm.envs as envs
from torch import nn
from transformers import PretrainedConfig
@@ -40,13 +40,10 @@
get_tp_group, tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
UnquantizedLinearMethod)
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -67,6 +64,7 @@

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP
from vllm_ascend.multistream.base import MSEventKey
from vllm_ascend.multistream.context import (
advance_step_multistream_layer_context, get_multistream_comm_context,
@@ -78,117 +76,17 @@
make_multistream_metadata_ds)
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import dispose_tensor

VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2


class CustomDeepseekDBOMLP(nn.Module):

def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj")
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()

# NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant
self.is_dynamic_quant = not isinstance(
self.gate_up_proj.quant_method,
UnquantizedLinearMethod) and isinstance(
self.gate_up_proj.quant_method.quant_method,
AscendW8A8DynamicLinearMethod)

def forward(self, x):
if self.is_dynamic_quant:
x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
x = torch_npu.npu_quant_matmul(
x,
self.gate_up_proj.weight,
self.gate_up_proj.weight_scale,
output_dtype=torch.int32,
)
x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
x=x,
weight_scale=self.gate_up_proj.weight_scale_fp32,
activation_scale=dynamic_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=None,
activate_left=True,
quant_mode=1)
x = torch_npu.npu_quant_matmul(
x,
self.down_proj.weight,
self.down_proj.weight_scale,
pertoken_scale=dynamic_scale,
output_dtype=torch.bfloat16,
)
if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
x = tensor_model_parallel_all_reduce(x)
return x
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class CustomDeepseekDBOMLP(CustomDeepseekV2MLP):

def _forward_ms_mlp(self, x):
current_ms_metadata = get_multistream_comm_context()
assert current_ms_metadata is not None
if self.is_dynamic_quant:
x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
x = torch_npu.npu_quant_matmul(
x,
self.gate_up_proj.weight,
self.gate_up_proj.weight_scale,
output_dtype=torch.int32,
)
x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
x=x,
weight_scale=self.gate_up_proj.weight_scale_fp32,
activation_scale=dynamic_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=None,
activate_left=True,
quant_mode=1)
x = torch_npu.npu_quant_matmul(
x,
self.down_proj.weight,
self.down_proj.weight_scale,
pertoken_scale=dynamic_scale,
output_dtype=torch.bfloat16,
)
if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
x = tensor_model_parallel_all_reduce(x)
current_ms_metadata.after_comm_event.record()
return x
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
current_ms_metadata.before_comm_event.record()
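
With the duplicated dynamic-quant forward gone, `CustomDeepseekDBOMLP` inherits the regular forward path from `CustomDeepseekV2MLP` and keeps only the multistream override, whose distinguishing piece is the event-fenced all-reduce on a separate communication stream. A minimal sketch of that pattern, using the names visible in the retained lines above (`torch.npu` streams assume a torch_npu build):

import torch
import torch_npu  # noqa: F401  (registers the torch.npu backend)
from vllm.distributed import tensor_model_parallel_all_reduce

def reduce_on_comm_stream(x, ms_metadata):
    # Mark the compute stream as ready, replay the TP all-reduce on the
    # dedicated communication stream, then record completion so later ops
    # can wait on it.
    ms_metadata.before_comm_event.record()
    with torch.npu.stream(ms_metadata.comm_stream):
        ms_metadata.before_comm_event.wait()
        x = tensor_model_parallel_all_reduce(x)
    ms_metadata.after_comm_event.record()
    return x
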
92 changes: 56 additions & 36 deletions vllm_ascend/models/deepseek_v2.py
@@ -25,7 +25,7 @@
# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
# """Inference-only DeepseekV2/DeepseekV3 model."""

from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
@@ -69,12 +69,38 @@
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import dispose_tensor

VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2


class CustomDeepseekV2SiluAndMul(SiluAndMul):

def __init__(self,
*,
weight_scale: Optional[Callable[[], torch.Tensor]] = None):
super().__init__()
self.weight_scale = weight_scale

def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
torch.Tensor]]):
if isinstance(x, tuple):
assert self.weight_scale is not None
# For AscendW8A8DynamicLinearMethod:
# a dynamic scale is passed along with the quantized value.
quantized_x, dynamic_scale = x
return torch_npu.npu_dequant_swiglu_quant(
x=quantized_x,
weight_scale=self.weight_scale(),
activation_scale=dynamic_scale,
activate_left=True,
quant_mode=1)
else:
return super().forward_oot(x)


class CustomDeepseekV2MLP(nn.Module):

def __init__(
@@ -101,44 +127,38 @@ def __init__(
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()

# NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant
self.is_dynamic_quant = not isinstance(
self.gate_up_proj.quant_method,
UnquantizedLinearMethod) and isinstance(
self.gate_up_proj.quant_method.quant_method,
AscendW8A8DynamicLinearMethod)
quant_method = self.gate_up_proj.quant_method
if isinstance(quant_method, UnquantizedLinearMethod):
self.act_fn = CustomDeepseekV2SiluAndMul()
elif (isinstance(quant_method, AscendLinearMethod) and isinstance(
quant_method.quant_method, AscendW8A8DynamicLinearMethod)):
# TODO(sdmyzlp): Currently preserved as before:
# 1. The only quantization supported for silu is W8A8Dynamic
# 2. Output dtype of gate_up/down is fixed to be int32/bfloat16
#
# Maybe one can implement a better and more general configuration
# scheme, e.g. by somehow passing around the tweaked `quant_config`
self.act_fn = CustomDeepseekV2SiluAndMul(
# Use lazy binding, for `weight_scale_fp32` is accessible
# only after `process_weights_after_loading`.
weight_scale=lambda: self.gate_up_proj.weight_scale_fp32)
# To be consumed by AscendW8A8DynamicLinearMethod.apply()
self.gate_up_proj._ascend_quant_config = {
"output_dtype": torch.int32,
"pertoken_scale": False,
"return_scale": True,
}
self.down_proj._ascend_quant_config = {
"output_dtype": torch.bfloat16,
"pertoken_scale": True,
"return_scale": False,
}
else:
raise NotImplementedError(
f"Quantization with [{type(quant_method)}] is NOT supported")

def forward(self, x):
if self.is_dynamic_quant:
x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
x = torch_npu.npu_quant_matmul(
x,
self.gate_up_proj.weight,
self.gate_up_proj.weight_scale,
output_dtype=torch.int32,
)
x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
x=x,
weight_scale=self.gate_up_proj.weight_scale_fp32,
activation_scale=dynamic_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=None,
activate_left=True,
quant_mode=1)
x = torch_npu.npu_quant_matmul(
x,
self.down_proj.weight,
self.down_proj.weight_scale,
pertoken_scale=dynamic_scale,
output_dtype=torch.bfloat16,
)
if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
x = tensor_model_parallel_all_reduce(x)
return x
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
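
The `_ascend_quant_config` dictionaries attached in `CustomDeepseekV2MLP.__init__` are only produced here; the code that consumes them lives in `vllm_ascend/quantization/w8a8_dynamic.py`, which is not shown in these hunks. As a hypothetical illustration of the intended contract (not the actual `AscendW8A8DynamicLinearMethod.apply` implementation), the sketch below reads the config using the same torch_npu kernels and keyword arguments that appear in the deleted `forward()` above:

import torch
import torch_npu

def apply_w8a8_dynamic(layer, x, per_token_scale=None):
    # Hypothetical consumer of the per-layer `_ascend_quant_config`;
    # the real apply() in w8a8_dynamic.py may differ.
    cfg = getattr(layer, "_ascend_quant_config", {})
    if per_token_scale is None:
        # Quantize the activation on the fly (per-token dynamic quantization).
        x, per_token_scale = torch_npu.npu_dynamic_quant(x)
    if cfg.get("pertoken_scale"):
        # down_proj case: dequantize with the per-token scale, emit bfloat16.
        out = torch_npu.npu_quant_matmul(x,
                                         layer.weight,
                                         layer.weight_scale,
                                         pertoken_scale=per_token_scale,
                                         output_dtype=cfg["output_dtype"])
    else:
        # gate_up_proj case: keep the int32 output for npu_dequant_swiglu_quant.
        out = torch_npu.npu_quant_matmul(x,
                                         layer.weight,
                                         layer.weight_scale,
                                         output_dtype=cfg["output_dtype"])
    # gate_up_proj returns the scale so CustomDeepseekV2SiluAndMul can pass it
    # to npu_dequant_swiglu_quant; down_proj returns a plain tensor.
    return (out, per_token_scale) if cfg.get("return_scale") else out
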