
Support multistream of shared experts in FusedMoE #997


Merged
merged 7 commits into vllm-project:main from br_infer_opt
Jun 11, 2025

Conversation

sdmyzlp
Contributor

@sdmyzlp sdmyzlp commented May 29, 2025

Contains #1111 for completeness.

What this PR does / why we need it?

Implement multi-stream parallelism for MoE layers with shared experts, where the computation of shared experts is overlapped with expert token dispatch and combine. Also, when multi-stream is enabled, the weights of shared experts are forced to be replicated across all cards, regardless of the tensor parallelism configuration, to avoid AllReduce operations.

The expected overlap is:

```
| shared gate_up | shared act |              | shared down |
|    dispatch    | routed gate_up, act, down |   combine   |
```
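
For intuition, here is a minimal sketch of that two-stream schedule, assuming the `torch.npu` stream/event API mirrors `torch.cuda`; `dispatch`, `routed_experts`, `combine`, `shared_gate_up_act`, and `shared_down` are illustrative placeholders, not the kernels used in this PR.

```python
# Illustrative sketch only: overlap shared-expert compute with dispatch/combine
# on a secondary NPU stream. All callables are placeholders (assumptions).
import torch
import torch_npu  # noqa: F401  # registers the torch.npu backend


def moe_forward_with_overlap(hidden_states, dispatch, routed_experts, combine,
                             shared_gate_up_act, shared_down):
    main = torch.npu.current_stream()
    side = torch.npu.Stream()

    # Mark hidden_states as ready, so the side stream only waits for this
    # tensor and not for the dispatch communication launched right after.
    hs_ready = torch.npu.Event()
    hs_ready.record(main)

    dispatched = dispatch(hidden_states)                 # main: token dispatch

    with torch.npu.stream(side):
        side.wait_event(hs_ready)
        shared_mid = shared_gate_up_act(hidden_states)   # overlaps dispatch

    routed_out = routed_experts(dispatched)              # main: gate_up, act, down
    combined = combine(routed_out)                       # main: token combine

    with torch.npu.stream(side):
        shared_out = shared_down(shared_mid)             # overlaps combine

    main.wait_stream(side)                               # join before summing
    return combined + shared_out
```

On device, the shared gate_up/act kernels then run concurrently with the dispatch communication, and the shared down projection runs concurrently with combine, matching the diagram above.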

Does this PR introduce any user-facing change?

No.

How was this patch tested?

Tested on a single node with 16 Ascend 910 NPUs (1x16), using a tailored 2-layer DeepSeek-V2 model.

@sdmyzlp sdmyzlp force-pushed the br_infer_opt branch 8 times, most recently from 6ed1be1 to db76771 on May 30, 2025 01:50
@MengqingCao MengqingCao requested a review from ganyi1996ppo May 30, 2025 01:57
@MengqingCao
Collaborator

You can run bash format.sh locally to fix lint failures

@sdmyzlp sdmyzlp force-pushed the br_infer_opt branch 4 times, most recently from d6286aa to 52ba3d3 on May 30, 2025 07:13
@sdmyzlp
Contributor Author

sdmyzlp commented May 30, 2025

You can run bash format.sh locally to fix lint failures

Done, with the mypy error code `import-not-found` disabled locally for vllm_ascend/utils.py.

@ganyi1996ppo
Collaborator

@sdmyzlp Can you upload the profiling graph for this part, so we can have a more intuitive perspective on this PR?

@@ -83,11 +85,20 @@ def fused_experts_with_mc2(
}
kwargs.update(stage1_kwargs)

if shared_experts is not None:
Collaborator

Have you tried launching this after the dispatch op? If the dispatch blocks long enough on the first stream, launching this after it seems like it would better overlap the host-side execution, right?

Contributor Author

@sdmyzlp sdmyzlp Jun 1, 2025

Fixed, with all secondary-stream operations launched after the corresponding main-stream operation.

By the way, this patch only implements multi-stream shared experts for graph-mode decode; operations will still be executed sequentially otherwise. One may extend npu_switch_stream / npu_wait_tensor in the future to support eager-mode multi-stream functionality.
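
For context, a hedged sketch of how tensor-granularity helpers of this kind could be expressed on top of generic stream/event primitives, again assuming `torch.npu` mirrors `torch.cuda`; the actual vllm-ascend utilities may differ in both name and implementation:

```python
# Illustrative only; not the vllm-ascend implementation of these helpers.
from contextlib import contextmanager

import torch
import torch_npu  # noqa: F401  # registers the torch.npu backend


@contextmanager
def switch_stream(stream: "torch.npu.Stream"):
    """Temporarily enqueue subsequent ops on `stream` instead of the current one."""
    with torch.npu.stream(stream):
        yield stream


def wait_tensor(tensor: torch.Tensor,
                consumer_stream: "torch.npu.Stream") -> torch.Tensor:
    """Record an event on the current (producer) stream right after `tensor` is
    produced, and make `consumer_stream` wait on that event only, rather than on
    all work queued on the producer stream."""
    event = torch.npu.Event()
    event.record(torch.npu.current_stream())
    consumer_stream.wait_event(event)
    return tensor
```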

@sdmyzlp sdmyzlp force-pushed the br_infer_opt branch 4 times, most recently from e2542aa to 7b28633 on June 1, 2025 03:33
@sdmyzlp
Contributor Author

sdmyzlp commented Jun 3, 2025

@sdmyzlp Can you upload the profiling graph for this part, so we can have a more intuitive perspective on this PR?

I described the expected overlap using ASCII art in the commit message; I have trouble uploading a screenshot of the profiling.

@sdmyzlp sdmyzlp force-pushed the br_infer_opt branch 2 times, most recently from ec0553e to 1ef0f68 on June 4, 2025 06:20

github-actions bot commented Jun 9, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@sdmyzlp sdmyzlp force-pushed the br_infer_opt branch 2 times, most recently from dbfdfe8 to fb0cdf2 on June 9, 2025 13:07

This pull request has conflicts, please resolve those before we can evaluate the pull request.

sdmyzlp added 7 commits June 11, 2025 01:02
AscendW8A8DynamicLinearMethod is integrated into CustomDeepseekV2MLP
in a very awkward way, causing scattered quantization operations all
over the model scripts. Refactor to solve this problem.

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
The chosen model is vllm-ascend/DeepSeek-V2-Lite-W8A8.

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
1. Concentrate the usage of `enable_multistream_moe` in a single place,
   and sink the computation of shared experts into `self.experts()` when
   multistream MoE is enabled, regardless of decode or prefill.

2. Move computation of shared experts out of `apply_mlp`.

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
Helps to unify the paths where multistream is turned on or off.

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
As the replicated version of MergedColumnParallelLinear, this aims at
removing the TP communication of DeepSeek-V2's `gate_up_proj` linear.
Also, with replicated weights, the chunked input hidden_states can be
consumed directly by the shared experts (see the sketch after the
commit list).

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
With the expected overlap being:
```
| shared gate_up | shared act |              | shared down |
|    dispatch    | routed gate_up, act, down |   combine   |
```

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
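
The replicated MergedColumnParallelLinear mentioned above removes the TP collective from `gate_up_proj` by keeping the full merged weight on every rank. A minimal sketch of the idea follows; the class and parameter names are hypothetical, not the actual vllm layer.

```python
# Hypothetical sketch: a replicated merged gate/up projection that needs no
# tensor-parallel communication because every rank holds the full weight.
import torch
import torch.nn as nn


class ReplicatedMergedGateUp(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
        super().__init__()
        # Full gate and up projections, merged into a single matmul:
        # [hidden_size] -> [2 * intermediate_size], replicated on every rank.
        self.proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=bias)
        self.intermediate_size = intermediate_size

    def forward(self, x: torch.Tensor):
        gate_up = self.proj(x)
        # Split back into gate and up halves; no AllReduce / AllGather involved,
        # so the per-rank chunk of hidden_states can be fed in directly.
        return gate_up.split(self.intermediate_size, dim=-1)
```

The trade-off is extra weight memory on each rank in exchange for removing the collective from the shared-expert path.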
@ganyi1996ppo ganyi1996ppo merged commit 7bdc606 into vllm-project:main Jun 11, 2025
19 checks passed
ganyi1996ppo added a commit that referenced this pull request Jun 11, 2025
raindaywhu pushed a commit to raindaywhu/vllm-ascend that referenced this pull request Jun 16, 2025
momo609 pushed a commit to momo609/vllm-ascend that referenced this pull request Jun 17, 2025
momo609 pushed a commit to momo609/vllm-ascend that referenced this pull request Jun 17, 2025
momo609 pushed a commit to momo609/vllm-ascend that referenced this pull request Jun 17, 2025
shiyuan680 pushed a commit to raindaywhu/vllm-ascend that referenced this pull request Jul 7, 2025