
Conversation

anon189Ty
Contributor

@anon189Ty anon189Ty commented Oct 10, 2025

What this PR does / why we need it?

Currently, when execution reaches the Linear layer of models in vLLM-Ascend, the weight format is ND in both the unquantized case and the skipped-Ascend case.
This PR supplements the execution logic for the Linear layer. We introduce a new environment variable, VLLM_ASCEND_ENABLE_NZ. When VLLM_ASCEND_ENABLE_NZ=1 and the CANN version is 8.3, the weights of the Linear layer are converted to FRACTAL_NZ, in both the unquantized case and the skipped-Ascend case. VLLM_ASCEND_ENABLE_NZ also gates the existing NZ conversions, such as the w8a8-quantized case.
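The gating described here can be sketched roughly as follows. This is a minimal illustration only: the function name and reading the CANN version as a plain string are assumptions; the real vLLM-Ascend code reads the flag through its envs module and queries the CANN runtime for the version.

```python
import os

# Opt-in flag, read once at import time ("0" when unset, i.e. NZ
# conversion disabled). Mirrors how env flags are typically cached.
VLLM_ASCEND_ENABLE_NZ = os.environ.get("VLLM_ASCEND_ENABLE_NZ", "0") == "1"


def should_convert_to_nz(cann_version: str,
                         enable_nz: bool = VLLM_ASCEND_ENABLE_NZ) -> bool:
    """Return True when Linear-layer weights should be cast to
    FRACTAL_NZ: the user opted in AND the CANN toolkit is 8.3+."""
    major, minor = (int(part) for part in cann_version.split(".")[:2])
    return enable_nz and (major, minor) >= (8, 3)
```

Treating 8.3 as a minimum rather than an exact match is an assumption here; the PR text only states that the conversion happens when the CANN version is 8.3.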

Does this PR introduce any user-facing change?

Adds a new environment variable, VLLM_ASCEND_ENABLE_NZ. To use the NZ format, set VLLM_ASCEND_ENABLE_NZ=1.
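For example, from a shell (illustrative only):

```shell
# Opt in to FRACTAL_NZ weight conversion for vLLM-Ascend.
export VLLM_ASCEND_ENABLE_NZ=1

# Any vLLM process launched from this shell now sees the flag.
echo "VLLM_ASCEND_ENABLE_NZ=$VLLM_ASCEND_ENABLE_NZ"
```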

How was this patch tested?


👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description to help reviewers and future developers understand.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.


This pull request has conflicts, please resolve those before we can evaluate the pull request.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a new environment variable VLLM_ASCEND_ENABLE_NZ to control the conversion of weights to the FRACTAL_NZ format, which is a valuable addition for performance tuning on Ascend hardware. The changes are applied consistently across various quantization methods and models. However, I've identified a few critical issues in the test files that would prevent the test suite from running, and a potential logic bug in vllm_ascend/attention/mla_v1.py involving dead code and an incorrect format constant. These issues need to be addressed to ensure the correctness and stability of the codebase.

Comment on lines 141 to 146
linear = AscendReplicatedLinear(
    input_size=16,
    output_size=8,
)
self.assertTrue(isinstance(linear.quant_method,
                           AscendUnquantizedLinearMethod))
Contributor


critical

This code is at the class level, which will cause a NameError because self is not defined in this context. This code should be moved inside a test method, for example test_init.

Suggested change

    def test_init(self):
        linear = AscendReplicatedLinear(
            input_size=16,
            output_size=8,
        )
        self.assertTrue(isinstance(linear.quant_method,
                                   AscendUnquantizedLinearMethod))

Comment on lines 551 to 554
elif isinstance(layer.quant_method, AscendUnquantizedLinearMethod):
    if getattr(layer.quant_method, "unquant_to_nz", False):
        layer.weight.data = torch_npu.npu_format_cast(
            layer.weight.data, ACL_FORMAT_FRACTAL_ND)
Contributor


high

This block of code appears to be dead code. The condition getattr(layer.quant_method, "unquant_to_nz", False) will likely never be true because AscendUnquantizedLinearMethod.process_weights_after_loading sets self.unquant_to_nz = False.

Furthermore, even if this code were to be executed, it casts the weight to ACL_FORMAT_FRACTAL_ND, which is inconsistent with the pull request's goal of converting to FRACTAL_NZ.

If this logic is intended to be used, please correct the condition and the format. Otherwise, it should be removed.

layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
                                              ACL_FORMAT_FRACTAL_NZ)
if envs_ascend.VLLM_ASCEND_ENABLE_NZ:
Collaborator


This check can be extracted into a common function.
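One possible shape for such a shared helper, as a sketch only: the helper name is hypothetical, and the cast is injected as a callable so the example stays independent of torch_npu. In the real code it would wrap torch_npu.npu_format_cast with ACL_FORMAT_FRACTAL_NZ.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def maybe_cast_to_nz(weight: T, cast_fn: Callable[[T], T],
                     enable_nz: bool, cann_is_83: bool) -> T:
    """Centralize the repeated 'ENABLE_NZ and CANN is 8.3' check so
    each quant method calls one helper instead of duplicating it."""
    if enable_nz and cann_is_83:
        return cast_fn(weight)
    return weight
```

Each call site would then pass its own weight tensor and cast, keeping the env-var and version logic in one place.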

            "8,3"):
        layer.weight.data = torch_npu.npu_format_cast(
            layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
    self.unquant_to_nz = False
Collaborator

@weijinqian0 weijinqian0 Oct 10, 2025


This parameter can probably be removed.

@anon189Ty anon189Ty force-pushed the unquant_nz_and_control branch from 4b07f81 to 6a3575f Compare October 11, 2025 10:07
@anon189Ty anon189Ty force-pushed the unquant_nz_and_control branch 3 times, most recently from 24113ff to eb634f2 Compare October 13, 2025 06:03
@realliujiaxu
Contributor

Why not set NZ by default, instead of adding an environment variable for control?



class CustomRowParallelOp(CustomTensorParallelOp):

Collaborator

@weijinqian0 weijinqian0 Oct 13, 2025


This class doesn't need to inherit from this base class.


This pull request has conflicts, please resolve those before we can evaluate the pull request.

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
